119 lines
3.1 KiB
Python
119 lines
3.1 KiB
Python
"""
|
|
Data loading utilities for the curated PhysioNet EEG Motor Movement/Imagery
|
|
dataset (CSV format).
|
|
"""
|
|
|
|
import numbers
import warnings
from pathlib import Path

import mne
import numpy as np
import pandas as pd

from config import DATA_DIR, N_CHANNELS, CHANNEL_NAMES, SAMPLING_RATE
|
|
|
|
|
|
def load_signal_file(subject_num, run_num, data_dir=DATA_DIR):
    """
    Load one run's EEG signal from its headerless CSV file.

    The file is expected at ``<data_dir>/SUB_<subject:03d>_SIG_<run:02d>.csv``.

    Parameters
    ----------
    subject_num : int
        Subject number (zero-padded to 3 digits in the file name).
    run_num : int
        Run number (zero-padded to 2 digits in the file name).
    data_dir : str or os.PathLike, optional
        Directory containing the CSV files. Defaults to ``config.DATA_DIR``.

    Returns
    -------
    np.ndarray of shape (n_samples, N_CHANNELS)

    Raises
    ------
    FileNotFoundError
        If the signal file does not exist (or is not a regular file).
    ValueError
        If the file does not have exactly ``N_CHANNELS`` columns.
    """
    # Path() lets callers pass a plain string as well as a Path object.
    filepath = Path(data_dir) / f"SUB_{subject_num:03d}_SIG_{run_num:02d}.csv"

    # is_file() (not exists()) so a same-named directory is reported clearly
    # instead of failing later inside read_csv.
    if not filepath.is_file():
        raise FileNotFoundError(f"File not found: {filepath}")

    # Signal files have no header row: every line is one sample.
    data = pd.read_csv(filepath, header=None).to_numpy()

    if data.shape[1] != N_CHANNELS:
        raise ValueError(f"Expected {N_CHANNELS} channels, got {data.shape[1]}")

    return data
|
|
|
|
|
|
def load_annotation_file(subject_num, run_num, data_dir=DATA_DIR):
    """
    Load one run's event annotations from CSV.

    The file is expected at ``<data_dir>/SUB_<subject:03d>_ANN_<run:02d>.csv``
    and must have exactly five columns. If its header uses different column
    names, they are normalised to the expected ones. If the file has no header
    row at all, it is re-read so that the first annotation is not lost.

    Parameters
    ----------
    subject_num : int
        Subject number (zero-padded to 3 digits in the file name).
    run_num : int
        Run number (zero-padded to 2 digits in the file name).
    data_dir : str or os.PathLike, optional
        Directory containing the CSV files. Defaults to ``config.DATA_DIR``.

    Returns
    -------
    pd.DataFrame with columns ['label', 'duration_sec', 'n_rows', 'start_row', 'end_row']

    Raises
    ------
    FileNotFoundError
        If the annotation file does not exist (or is not a regular file).
    ValueError
        If the file does not have exactly five columns.
    """
    expected_cols = ['label', 'duration_sec', 'n_rows', 'start_row', 'end_row']

    # Path() lets callers pass a plain string as well as a Path object.
    filepath = Path(data_dir) / f"SUB_{subject_num:03d}_ANN_{run_num:02d}.csv"

    if not filepath.is_file():
        raise FileNotFoundError(f"File not found: {filepath}")

    annotations = pd.read_csv(filepath)

    if list(annotations.columns) == expected_cols:
        return annotations

    if len(annotations.columns) != len(expected_cols):
        raise ValueError(
            f"Expected 5 columns {expected_cols}, got {len(annotations.columns)}: "
            f"{list(annotations.columns)}"
        )

    # If every "header" cell after the label parses as a number, the file has
    # no header row and pandas consumed the first annotation as column names.
    # Renaming the columns would silently drop that event, so re-read instead.
    header_tail = pd.to_numeric(pd.Series(annotations.columns[1:]), errors='coerce')
    if header_tail.notna().all():
        return pd.read_csv(filepath, header=None, names=expected_cols)

    # A genuine header row with different spelling: normalise the names.
    annotations.columns = expected_cols
    return annotations
|
|
|
|
|
|
def create_mne_raw(signal_data, sfreq=SAMPLING_RATE):
    """
    Create an MNE RawArray from a samples-by-channels signal matrix.

    Channels are named from ``config.CHANNEL_NAMES`` and typed as EEG.
    Channels whose names are not in MNE's 'standard_1005' montage are dropped,
    with a warning, and electrode positions are set for the remaining ones.

    Parameters
    ----------
    signal_data : array-like of shape (n_samples, N_CHANNELS)
        Signal values in microvolts; they are scaled by 1e-6 to volts.
    sfreq : float, optional
        Sampling frequency in Hz. Defaults to ``config.SAMPLING_RATE``.

    Returns
    -------
    mne.io.RawArray

    Raises
    ------
    ValueError
        If none of the channel names appears in the montage.
    """
    # MNE expects (n_channels, n_times) in volts; the input is (samples,
    # channels) in µV.
    data = np.asarray(signal_data, dtype=np.float64).T * 1e-6

    info = mne.create_info(
        ch_names=CHANNEL_NAMES,
        sfreq=sfreq,
        ch_types='eeg',
    )
    raw = mne.io.RawArray(data, info, verbose=False)

    # 'standard_1005' is the 10-05 electrode system, which also contains the
    # 10-20 positions.
    montage = mne.channels.make_standard_montage('standard_1005')
    known = set(montage.ch_names)
    valid_chs = [ch for ch in raw.ch_names if ch in known]
    dropped = [ch for ch in raw.ch_names if ch not in known]

    if not valid_chs:
        raise ValueError(
            "None of the channel names are present in the 'standard_1005' montage"
        )
    if dropped:
        # Dropping changes the channel count seen downstream, so make it visible.
        warnings.warn(
            f"Dropping {len(dropped)} channel(s) not in the 'standard_1005' "
            f"montage: {dropped}",
            stacklevel=2,
        )
        raw = raw.pick(valid_chs)

    raw.set_montage(montage, on_missing='ignore')

    return raw
|
|
|
|
|
|
def _label_to_description(label):
    """
    Convert one annotation label into an MNE description string.

    Integral numeric labels (``2``, ``2.0``, ``"2"``) become ``"2"``. Text
    labels that are not numeric (e.g. ``"T1"``) are kept, stripped of
    surrounding whitespace.

    Raises
    ------
    ValueError
        If the label is empty, NaN/inf, or a non-integral number (which
        ``int()`` would otherwise silently truncate).
    """
    if isinstance(label, numbers.Integral):
        return str(int(label))
    if isinstance(label, str):
        text = label.strip()
        if not text:
            raise ValueError("Annotation label is empty")
        try:
            value = float(text)
        except ValueError:
            return text
    else:
        value = float(label)
    if not value.is_integer():
        raise ValueError(f"Annotation label {label!r} is not an integer code")
    return str(int(value))


def add_annotations_to_raw(raw, annotations_df):
    """
    Attach event annotations from an annotation table to an MNE Raw object.

    Onsets are ``start_row / raw.info['sfreq']`` seconds, durations come from
    ``duration_sec``, and descriptions come from ``label``. Integral labels
    become strings such as ``"1"``.

    ``raw.set_annotations`` replaces any annotations already on ``raw``.
    ``raw`` is modified in place and also returned.

    Parameters
    ----------
    raw : mne.io.Raw
    annotations_df : pd.DataFrame
        Must contain 'start_row', 'duration_sec' and 'label' columns
        (as returned by load_annotation_file).

    Returns
    -------
    mne.io.Raw
        The same ``raw`` object, with its annotations set.

    Raises
    ------
    ValueError
        If a label is empty or is a non-integral number.
    """
    sfreq = raw.info['sfreq']

    # start_row is a sample index; dividing by sfreq gives seconds.
    onsets = annotations_df['start_row'].to_numpy(dtype=float) / sfreq
    durations = annotations_df['duration_sec'].to_numpy(dtype=float)
    descriptions = [_label_to_description(label) for label in annotations_df['label']]

    annot = mne.Annotations(onset=onsets, duration=durations,
                            description=descriptions)
    raw.set_annotations(annot)
    return raw
|