1
0
Files
Mwata-Velu_et_al_2023/data_loading.py
2026-04-09 08:21:30 -07:00

119 lines
3.1 KiB
Python

"""
Data loading utilities for the curated PhysioNet EEG Motor Movement/Imagery
dataset (CSV format).
"""
import numpy as np
import pandas as pd
import mne
from config import DATA_DIR, N_CHANNELS, CHANNEL_NAMES, SAMPLING_RATE
def load_signal_file(subject_num, run_num, data_dir=DATA_DIR):
    """
    Read one run's EEG recording from its CSV file.

    The file is looked up as ``SUB_<subject>_SIG_<run>.csv`` inside
    ``data_dir`` (subject zero-padded to 3 digits, run to 2 digits) and is
    parsed without a header row.

    Parameters
    ----------
    subject_num : int
        Subject number used to build the file name.
    run_num : int
        Run number used to build the file name.
    data_dir : path object supporting ``/`` and ``.exists()``
        Directory holding the CSV files. Defaults to ``DATA_DIR``
        (presumably a ``pathlib.Path``; confirm in ``config``).

    Returns
    -------
    np.ndarray of shape (n_samples, N_CHANNELS)

    Raises
    ------
    FileNotFoundError
        If the expected CSV file does not exist.
    ValueError
        If the number of columns read differs from ``N_CHANNELS``.
    """
    csv_path = data_dir / f"SUB_{subject_num:03d}_SIG_{run_num:02d}.csv"
    if not csv_path.exists():
        raise FileNotFoundError(f"File not found: {csv_path}")

    # header=None: every row, including the first, is treated as data.
    samples = pd.read_csv(csv_path, header=None).values

    n_cols = samples.shape[1]
    if n_cols != N_CHANNELS:
        raise ValueError(f"Expected {N_CHANNELS} channels, got {n_cols}")
    return samples
def load_annotation_file(subject_num, run_num, data_dir=DATA_DIR):
    """
    Read one run's annotation table from its CSV file.

    The file is looked up as ``SUB_<subject>_ANN_<run>.csv`` inside
    ``data_dir`` (subject zero-padded to 3 digits, run to 2 digits). The
    first row of the file is parsed as the header. If the header does not
    match the expected names but there are exactly five columns, the columns
    are renamed positionally to the expected names.

    Parameters
    ----------
    subject_num : int
        Subject number used to build the file name.
    run_num : int
        Run number used to build the file name.
    data_dir : path object supporting ``/`` and ``.exists()``
        Directory holding the CSV files. Defaults to ``DATA_DIR``.

    Returns
    -------
    pd.DataFrame with columns ['label', 'duration_sec', 'n_rows', 'start_row', 'end_row']

    Raises
    ------
    FileNotFoundError
        If the expected CSV file does not exist.
    ValueError
        If the header differs from the expected names and the file does not
        have exactly five columns.
    """
    csv_path = data_dir / f"SUB_{subject_num:03d}_ANN_{run_num:02d}.csv"
    if not csv_path.exists():
        raise FileNotFoundError(f"File not found: {csv_path}")

    expected_cols = ['label', 'duration_sec', 'n_rows', 'start_row', 'end_row']
    table = pd.read_csv(csv_path)
    found_cols = list(table.columns)

    # Happy path: header already carries the expected names, in order.
    if found_cols == expected_cols:
        return table

    if len(found_cols) != 5:
        raise ValueError(
            f"Expected 5 columns {expected_cols}, got {len(found_cols)}: "
            f"{found_cols}"
        )

    # Same column count, different names: relabel by position.
    # NOTE(review): if a file has no header row at all, its first data row
    # has already been consumed as the header here — confirm the files
    # always carry a header line.
    table.columns = expected_cols
    return table
def create_mne_raw(signal_data, sfreq=SAMPLING_RATE):
    """
    Wrap a sample-major signal array in an MNE RawArray with a montage.

    The array is transposed to channel-major layout and scaled by 1e-6
    (the input is assumed to be in microvolts, so the result is in volts).
    Channels are named from ``CHANNEL_NAMES`` and typed as EEG. Any channel
    whose name is not found (exact, case-sensitive match) in the
    ``'standard_1005'`` montage is dropped before the montage is applied.

    Parameters
    ----------
    signal_data : np.ndarray of shape (n_samples, N_CHANNELS)
    sfreq : float
        Sampling frequency passed to ``mne.create_info``; defaults to
        ``SAMPLING_RATE``.

    Returns
    -------
    mne.io.RawArray
        Contains only the channels present in the montage.
    """
    # (samples, channels) -> (channels, samples), µV -> V
    volts = signal_data.T * 1e-6

    info = mne.create_info(ch_names=CHANNEL_NAMES, sfreq=sfreq, ch_types='eeg')
    raw = mne.io.RawArray(volts, info, verbose=False)

    # Keep only the channels the 'standard_1005' montage has a position for.
    montage = mne.channels.make_standard_montage('standard_1005')
    known_names = set(montage.ch_names)
    raw = raw.pick([name for name in raw.ch_names if name in known_names])

    raw.set_montage(montage, on_missing='ignore')
    return raw
def add_annotations_to_raw(raw, annotations_df):
    """
    Attach event annotations from a table to an MNE Raw object.

    Onsets are computed as ``start_row / sfreq`` using the sampling frequency
    stored in ``raw.info``. Durations are taken from ``duration_sec`` as-is.
    Each label is converted to an integer and then to a string to form the
    annotation description. The annotations are set on ``raw`` itself, and
    the same object is returned.

    Parameters
    ----------
    raw : mne.io.Raw
    annotations_df : pd.DataFrame with start_row, duration_sec, label columns

    Returns
    -------
    mne.io.Raw with annotations added
    """
    sampling_rate = raw.info['sfreq']

    event_annotations = mne.Annotations(
        # Row index -> onset, by dividing by the sampling frequency.
        onset=annotations_df['start_row'].values / sampling_rate,
        duration=annotations_df['duration_sec'].values,
        # Labels go through int() first, so e.g. 1.0 becomes "1".
        description=[str(int(code)) for code in annotations_df['label'].values],
    )

    raw.set_annotations(event_annotations)
    return raw