Initial commit
This commit is contained in:
118
data_loading.py
Normal file
118
data_loading.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""
|
||||
Data loading utilities for the curated PhysioNet EEG Motor Movement/Imagery
|
||||
dataset (CSV format).
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import mne
|
||||
|
||||
from config import DATA_DIR, N_CHANNELS, CHANNEL_NAMES, SAMPLING_RATE
|
||||
|
||||
|
||||
def load_signal_file(subject_num, run_num, data_dir=DATA_DIR):
    """
    Load one run of EEG signal data from its CSV file.

    The file is looked up as ``SUB_<subject_num:03d>_SIG_<run_num:02d>.csv``
    inside ``data_dir`` and is read without a header row.

    Parameters
    ----------
    subject_num : int
        Subject number; zero-padded to three digits in the file name.
    run_num : int
        Run number; zero-padded to two digits in the file name.
    data_dir : str or path-like, optional
        Directory containing the CSV files. Defaults to ``DATA_DIR``.

    Returns
    -------
    np.ndarray of shape (n_samples, N_CHANNELS)

    Raises
    ------
    FileNotFoundError
        If the expected CSV file does not exist.
    ValueError
        If the file does not have exactly ``N_CHANNELS`` columns.
    """
    from pathlib import Path

    # A plain string directory does not support the ``/`` operator, so wrap
    # it; anything else is used unchanged.
    base_dir = Path(data_dir) if isinstance(data_dir, str) else data_dir
    filepath = base_dir / f"SUB_{subject_num:03d}_SIG_{run_num:02d}.csv"

    if not filepath.exists():
        raise FileNotFoundError(f"File not found: {filepath}")

    data = pd.read_csv(filepath, header=None).values

    n_columns = data.shape[1]
    if n_columns != N_CHANNELS:
        raise ValueError(f"Expected {N_CHANNELS} channels, got {n_columns}")

    return data
|
||||
|
||||
|
||||
def load_annotation_file(subject_num, run_num, data_dir=DATA_DIR):
    """
    Load the annotation table for one run from its CSV file.

    The file is looked up as ``SUB_<subject_num:03d>_ANN_<run_num:02d>.csv``
    inside ``data_dir`` and is read with its first row as the header.

    Parameters
    ----------
    subject_num : int
        Subject number; zero-padded to three digits in the file name.
    run_num : int
        Run number; zero-padded to two digits in the file name.
    data_dir : str or path-like, optional
        Directory containing the CSV files. Defaults to ``DATA_DIR``.

    Returns
    -------
    pd.DataFrame with columns ['label', 'duration_sec', 'n_rows', 'start_row', 'end_row']

    Raises
    ------
    FileNotFoundError
        If the expected CSV file does not exist.
    ValueError
        If the header differs from the expected one and the file does not
        have exactly five columns.

    Notes
    -----
    A file with five columns but different header names is accepted: its
    columns are renamed positionally to the expected names.
    """
    from pathlib import Path

    # A plain string directory does not support the ``/`` operator, so wrap
    # it; anything else is used unchanged.
    base_dir = Path(data_dir) if isinstance(data_dir, str) else data_dir
    filepath = base_dir / f"SUB_{subject_num:03d}_ANN_{run_num:02d}.csv"

    if not filepath.exists():
        raise FileNotFoundError(f"File not found: {filepath}")

    annotations = pd.read_csv(filepath)

    expected_cols = ['label', 'duration_sec', 'n_rows', 'start_row', 'end_row']
    actual_cols = list(annotations.columns)
    if actual_cols == expected_cols:
        return annotations

    if len(actual_cols) != 5:
        raise ValueError(
            f"Expected 5 columns {expected_cols}, got {len(actual_cols)}: "
            f"{actual_cols}"
        )

    # Right column count, unexpected names: trust the order and relabel.
    annotations.columns = expected_cols
    return annotations
|
||||
|
||||
|
||||
def create_mne_raw(signal_data, sfreq=SAMPLING_RATE):
    """
    Wrap a samples-by-channels signal matrix in an MNE RawArray.

    Channels are named from ``CHANNEL_NAMES`` and typed as EEG. Any channel
    whose name is not part of the ``standard_1005`` montage is dropped
    before the montage is applied.

    Parameters
    ----------
    signal_data : np.ndarray of shape (n_samples, N_CHANNELS)
    sfreq : float, optional
        Sampling frequency passed to ``mne.create_info``. Defaults to
        ``SAMPLING_RATE``.

    Returns
    -------
    mne.io.RawArray
    """
    # MNE expects (channels, samples); the scale factor converts µV to V.
    volts = signal_data.T * 1e-6

    info = mne.create_info(ch_names=CHANNEL_NAMES, sfreq=sfreq, ch_types='eeg')
    raw = mne.io.RawArray(volts, info, verbose=False)

    montage = mne.channels.make_standard_montage('standard_1005')

    # Keep only the channels the montage has a position for.
    montage_names = set(montage.ch_names)
    kept_channels = [name for name in raw.ch_names if name in montage_names]
    raw = raw.pick(kept_channels)

    raw.set_montage(montage, on_missing='ignore')
    return raw
|
||||
|
||||
|
||||
def add_annotations_to_raw(raw, annotations_df):
    """
    Attach event annotations from a DataFrame to an MNE Raw object.

    Onsets are derived by dividing ``start_row`` by the sampling frequency
    stored in ``raw.info``; durations come from ``duration_sec``; each label
    is converted to an integer and then to its string form for use as the
    annotation description.

    Parameters
    ----------
    raw : mne.io.Raw
    annotations_df : pd.DataFrame with start_row, duration_sec, label columns

    Returns
    -------
    mne.io.Raw
        The same ``raw`` object, with its annotations set.
    """
    sampling_freq = raw.info['sfreq']

    # Row indices -> seconds from the start of the recording.
    onset_seconds = annotations_df['start_row'].values / sampling_freq
    span_seconds = annotations_df['duration_sec'].values
    label_values = annotations_df['label'].values

    raw.set_annotations(
        mne.Annotations(
            onset=onset_seconds,
            duration=span_seconds,
            description=[str(int(value)) for value in label_values],
        )
    )
    return raw
|
||||
Reference in New Issue
Block a user