Initial commit
This commit is contained in:
65
README.md
Normal file
65
README.md
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
# Reproduction: Mwata-Velu et al. (2023)
|
||||||
|
|
||||||
|
**Paper**: "EEG-BCI Features Discrimination between Executed and Imagined Movements
|
||||||
|
Based on FastICA, Hjorth Parameters, and SVM"
|
||||||
|
**Journal**: Mathematics 2023, 11, 4409
|
||||||
|
**DOI**: https://doi.org/10.3390/math11214409
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
This repository contains an attempted reproduction of the above paper as part of a
|
||||||
|
thesis on reproducibility challenges in EEG-based BCI research. The reproduction was
|
||||||
|
**partially completed** — the core pipeline is implemented but several ambiguities in the
|
||||||
|
paper prevented a definitive reproduction.
|
||||||
|
|
||||||
|
## Repository Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
config.py — Constants, channel lists, run mappings, parameters
|
||||||
|
data_loading.py — CSV data loading, MNE Raw creation, annotations
|
||||||
|
pipeline.py — Filtering, FastICA, Hjorth features, SVM classification
|
||||||
|
reproduction_notebook.ipynb — Main analysis notebook (Method 2: cross-subject)
|
||||||
|
requirements.txt — Python dependencies
|
||||||
|
```
|
||||||
|
|
||||||
|
## Data
|
||||||
|
|
||||||
|
This code expects the PhysioNet EEG Motor Movement/Imagery Dataset in the curated
|
||||||
|
CSV format provided by:
|
||||||
|
Z. Shuqfa, A. Lakas, and A. N. Belkacem, “Increasing accessibility to a large brain–
|
||||||
|
computer interface dataset: Curation of physionet EEG motor movement/imagery dataset
|
||||||
|
for decoding and classification,” Data in Brief, vol. 54, p. 110181, Jun. 2024,
|
||||||
|
doi: 10.1016/j.dib.2024.110181.
|
||||||
|
|
||||||
|
Files are named:
|
||||||
|
- `eegmmidb/SUB_001_SIG_01.csv` — Signal data (n_samples × 64 channels)
|
||||||
|
- `eegmmidb/SUB_001_ANN_01.csv` — Annotations (label, duration, start/end rows)
|
||||||
|
|
||||||
|
The curated dataset excludes the 6 problematic subjects (S088, S089, S092, S100,
|
||||||
|
S104, S106) noted in the paper. Run numbering is offset by 2 from PhysioNet's
|
||||||
|
original (our Run 01 = PhysioNet R03).
|
||||||
|
|
||||||
|
## Usage
|
||||||
|
|
||||||
|
1. Install dependencies: `pip install -r requirements.txt`
|
||||||
|
2. Place curated CSV data in `eegmmidb/` directory
|
||||||
|
3. Edit `config.py` to set `ICA_STRATEGY` ('per_run', 'per_subject', or 'global')
|
||||||
|
4. Run `reproduction_notebook.ipynb`
|
||||||
|
|
||||||
|
## Key Implementation Decisions
|
||||||
|
|
||||||
|
| Decision | Paper says | We do | Rationale |
|
||||||
|
|----------|-----------|-------|----------|
|
||||||
|
| Pipeline order | Figure 1: filter→ICA; Algorithm 1: ICA with internal sub-band eval | ICA then sub-band eval | Energy criterion is meaningless on pre-filtered data |
|
||||||
|
| ICA algorithm | Gram-Schmidt (Algorithm 1, Step 3) | `algorithm='deflation'` | Deflation uses Gram-Schmidt |
|
||||||
|
| Energy criterion | ∀χ ∈ {α, β, **γ**} | ∀χ ∈ {**θ**, α, β} | γ never defined; likely typo for θ |
|
||||||
|
| ICA scope | Not specified | Configurable | Reproducibility variable |
|
||||||
|
| Classification Method | Methods 1 and 2 | Method 2 only (cross-subject) | Method 1 split is contradictory |
|
||||||
|
|
||||||
|
## Paper's Reported Results (Method 2, Set 3)
|
||||||
|
|
||||||
|
| Metric | Paper |
|
||||||
|
|--------|-------|
|
||||||
|
| Overall accuracy | 68.8 ± 0.71% |
|
||||||
|
| ME recall | 68.17% |
|
||||||
|
| MI recall | 68.41% |
|
||||||
144
config.py
Normal file
144
config.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""
|
||||||
|
Configuration for reproducing Mwata-Velu et al. (2023)
|
||||||
|
"EEG-BCI Features Discrimination between Executed and Imagined Movements
|
||||||
|
Based on FastICA, Hjorth Parameters, and SVM"
|
||||||
|
Mathematics 2023, 11, 4409. DOI: 10.3390/math11214409
|
||||||
|
|
||||||
|
Dataset: PhysioNet EEG Motor Movement/Imagery Dataset (curated CSV format)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Paths
|
||||||
|
# =============================================================================
|
||||||
|
DATA_DIR = Path("..\eegmmidb")
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Dataset parameters
|
||||||
|
# =============================================================================
|
||||||
|
SAMPLING_RATE = 160 # Hz
|
||||||
|
N_CHANNELS = 64
|
||||||
|
|
||||||
|
# Full 64-channel names (Sharbrough system, PhysioNet ordering)
|
||||||
|
CHANNEL_NAMES = [
|
||||||
|
'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6',
|
||||||
|
'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6',
|
||||||
|
'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6',
|
||||||
|
'Fp1', 'Fpz', 'Fp2',
|
||||||
|
'AF7', 'AF3', 'AFz', 'AF4', 'AF8',
|
||||||
|
'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8',
|
||||||
|
'FT7', 'FT8',
|
||||||
|
'T7', 'T8', 'T9', 'T10',
|
||||||
|
'TP7', 'TP8',
|
||||||
|
'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8',
|
||||||
|
'PO7', 'PO3', 'POz', 'PO4', 'PO8',
|
||||||
|
'O1', 'Oz', 'O2',
|
||||||
|
'Iz',
|
||||||
|
]
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Channel selections (Section 3.2)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
# 19 channels from the 10-20 system used for ICA decomposition (Section 3.2)
|
||||||
|
ICA_CHANNELS = [
|
||||||
|
'Fp1', 'Fp2', 'F7', 'F3', 'Fz', 'F4', 'F8',
|
||||||
|
'T7', 'C3', 'Cz', 'C4', 'T8',
|
||||||
|
'P7', 'P3', 'Pz', 'P4', 'P8', 'O1', 'O2',
|
||||||
|
]
|
||||||
|
|
||||||
|
# 9 "Selected_channels" for ICA energy concentration criterion (Algorithm 1, Step 7)
|
||||||
|
# These are the sensorimotor + frontal + parietal channels the paper evaluates
|
||||||
|
# energy concentration against.
|
||||||
|
SELECTED_CHANNELS = ['C3', 'Cz', 'C4', 'F3', 'Fz', 'F4', 'P3', 'Pz', 'P4']
|
||||||
|
|
||||||
|
# Channels used for Hjorth feature extraction (Section 3.5, Table 5)
|
||||||
|
# The paper's best results (Set 3) use C3, Cz, C4.
|
||||||
|
TARGET_CHANNELS = ['C3', 'Cz', 'C4']
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Task / run definitions
|
||||||
|
# =============================================================================
|
||||||
|
# NOTE ON NUMBERING: The curated CSV dataset uses a different run numbering
|
||||||
|
# scheme than the original PhysioNet EDF files. The mapping is:
|
||||||
|
#
|
||||||
|
# Curated CSV PhysioNet EDF Task
|
||||||
|
# ----------- ------------- ----
|
||||||
|
# Run 01 R03 Execute open/close left or right fist
|
||||||
|
# Run 02 R04 Imagine open/close left or right fist
|
||||||
|
# Run 03 R05 Execute open/close both fists or both feet
|
||||||
|
# Run 04 R06 Imagine open/close both fists or both feet
|
||||||
|
# Run 05 R07 Execute open/close left or right fist
|
||||||
|
# Run 06 R08 Imagine open/close left or right fist
|
||||||
|
# Run 07 R09 Execute open/close both fists or both feet
|
||||||
|
# Run 08 R10 Imagine open/close both fists or both feet
|
||||||
|
# Run 09 R11 Execute open/close left or right fist
|
||||||
|
# Run 10 R12 Imagine open/close left or right fist
|
||||||
|
# Run 11 R13 Execute open/close both fists or both feet
|
||||||
|
# Run 12 R14 Imagine open/close both fists or both feet
|
||||||
|
#
|
||||||
|
# The paper's Section 4 states twice that results correspond to R03, R04, R07,
|
||||||
|
# R08, R11, R12 (left/right fist only). This agrees with another statement
|
||||||
|
# that says they only use 6 of the 14 runs per subject. However, the sample
|
||||||
|
# counts (8652 total) require including all 12 task runs. Additionally, the
|
||||||
|
# paper also says "samples of the first 10 runs constituted the training set;
|
||||||
|
# those of the 11th and 12th, and 13th and 14th runs were used as the testing
|
||||||
|
# and validation sets, respectively". These statements contradict each other.
|
||||||
|
# We use the 6 runs that are listed twice: R03, R04, R07, R08, R11, R12.
|
||||||
|
|
||||||
|
EXECUTION_RUNS = [1, 5, 9] # R03, R07, R11
|
||||||
|
IMAGERY_RUNS = [2, 6, 10] # R04, R08, R12
|
||||||
|
TARGET_RUNS = EXECUTION_RUNS + IMAGERY_RUNS
|
||||||
|
|
||||||
|
# Annotation labels that correspond to T1/T2 events (active task periods).
|
||||||
|
# T0 (rest) is excluded. These codes come from the curated CSV annotation files.
|
||||||
|
ACTIVE_EVENT_LABELS = [2, 3, 5, 6, 8, 9, 11, 12]
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Sub-band definitions (Section 3.3)
|
||||||
|
# =============================================================================
|
||||||
|
SUB_BANDS = [
|
||||||
|
('theta', 4.0, 8.0),
|
||||||
|
('alpha', 8.0, 13.0),
|
||||||
|
('beta', 13.0, 30.0),
|
||||||
|
]
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# ICA parameters (Section 3.4, Algorithm 1)
|
||||||
|
# =============================================================================
|
||||||
|
ICA_N_COMPONENTS = 19
|
||||||
|
ICA_ENERGY_THRESHOLD = 0.35
|
||||||
|
ICA_MAX_ITER = 500
|
||||||
|
ICA_TOL = 1e-4
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# SVM parameters (Section 3.6, Figure 6)
|
||||||
|
# =============================================================================
|
||||||
|
SVM_C = 2 ** 13 # 8192
|
||||||
|
SVM_GAMMA = 2 ** 1 # 2
|
||||||
|
SVM_KERNEL = 'rbf'
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Evaluation
|
||||||
|
# =============================================================================
|
||||||
|
N_RUNS = 5 # Paper: "results were averaged by running the model five times"
|
||||||
|
RANDOM_SEEDS = [42, 123, 456, 789, 1024]
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# ICA strategy (not specified in paper — this is a reproducibility variable)
|
||||||
|
# =============================================================================
|
||||||
|
# Options:
|
||||||
|
# 'per_run' — Fit ICA independently on each ~2-minute run
|
||||||
|
# 'per_subject' — Fit ICA once on all runs concatenated per subject
|
||||||
|
# 'global' — Fit ICA once on all training subjects concatenated
|
||||||
|
ICA_STRATEGY = 'per_subject'
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Method 2: Cross-Subject Split (Table 4)
|
||||||
|
# =============================================================================
|
||||||
|
# Note: The curated dataset already excludes the 6 problematic subjects
|
||||||
|
# (S088, S089, S092, S100, S104, S106), so we use consecutive IDs.
|
||||||
|
TRAIN_SUBJECTS = list(range(1, 84)) # Subjects 1-83
|
||||||
|
TEST_SUBJECTS = list(range(84, 94)) # Subjects 84-93
|
||||||
|
VAL_SUBJECTS = list(range(94, 104)) # Subjects 94-103
|
||||||
118
data_loading.py
Normal file
118
data_loading.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""
|
||||||
|
Data loading utilities for the curated PhysioNet EEG Motor Movement/Imagery
|
||||||
|
dataset (CSV format).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pandas as pd
|
||||||
|
import mne
|
||||||
|
|
||||||
|
from config import DATA_DIR, N_CHANNELS, CHANNEL_NAMES, SAMPLING_RATE
|
||||||
|
|
||||||
|
|
||||||
|
def load_signal_file(subject_num, run_num, data_dir=DATA_DIR):
|
||||||
|
"""
|
||||||
|
Load EEG signal from CSV file.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
np.ndarray of shape (n_samples, N_CHANNELS)
|
||||||
|
"""
|
||||||
|
filename = f"SUB_{subject_num:03d}_SIG_{run_num:02d}.csv"
|
||||||
|
filepath = data_dir / filename
|
||||||
|
|
||||||
|
if not filepath.exists():
|
||||||
|
raise FileNotFoundError(f"File not found: {filepath}")
|
||||||
|
|
||||||
|
data = pd.read_csv(filepath, header=None).values
|
||||||
|
|
||||||
|
if data.shape[1] != N_CHANNELS:
|
||||||
|
raise ValueError(f"Expected {N_CHANNELS} channels, got {data.shape[1]}")
|
||||||
|
|
||||||
|
return data
|
||||||
|
|
||||||
|
|
||||||
|
def load_annotation_file(subject_num, run_num, data_dir=DATA_DIR):
|
||||||
|
"""
|
||||||
|
Load annotations from CSV file.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
pd.DataFrame with columns ['label', 'duration_sec', 'n_rows', 'start_row', 'end_row']
|
||||||
|
"""
|
||||||
|
filename = f"SUB_{subject_num:03d}_ANN_{run_num:02d}.csv"
|
||||||
|
filepath = data_dir / filename
|
||||||
|
|
||||||
|
if not filepath.exists():
|
||||||
|
raise FileNotFoundError(f"File not found: {filepath}")
|
||||||
|
|
||||||
|
annotations = pd.read_csv(filepath)
|
||||||
|
|
||||||
|
expected_cols = ['label', 'duration_sec', 'n_rows', 'start_row', 'end_row']
|
||||||
|
if list(annotations.columns) != expected_cols:
|
||||||
|
if len(annotations.columns) == 5:
|
||||||
|
annotations.columns = expected_cols
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
f"Expected 5 columns {expected_cols}, got {len(annotations.columns)}: "
|
||||||
|
f"{list(annotations.columns)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return annotations
|
||||||
|
|
||||||
|
|
||||||
|
def create_mne_raw(signal_data, sfreq=SAMPLING_RATE):
|
||||||
|
"""
|
||||||
|
Create MNE Raw object from signal data.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
signal_data : np.ndarray of shape (n_samples, N_CHANNELS)
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
mne.io.RawArray
|
||||||
|
"""
|
||||||
|
# Transpose to (channels, samples) and convert µV → V
|
||||||
|
data = signal_data.T * 1e-6
|
||||||
|
|
||||||
|
info = mne.create_info(
|
||||||
|
ch_names=CHANNEL_NAMES,
|
||||||
|
sfreq=sfreq,
|
||||||
|
ch_types='eeg',
|
||||||
|
)
|
||||||
|
raw = mne.io.RawArray(data, info, verbose=False)
|
||||||
|
|
||||||
|
# Set standard 10-20 montage
|
||||||
|
montage = mne.channels.make_standard_montage('standard_1005')
|
||||||
|
# Only set channels that exist in the montage
|
||||||
|
valid_chs = [ch for ch in raw.ch_names if ch in montage.ch_names]
|
||||||
|
raw = raw.pick(valid_chs)
|
||||||
|
raw.set_montage(montage, on_missing='ignore')
|
||||||
|
|
||||||
|
return raw
|
||||||
|
|
||||||
|
|
||||||
|
def add_annotations_to_raw(raw, annotations_df):
|
||||||
|
"""
|
||||||
|
Add event annotations to MNE Raw object.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
raw : mne.io.Raw
|
||||||
|
annotations_df : pd.DataFrame with start_row, duration_sec, label columns
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
mne.io.Raw with annotations added
|
||||||
|
"""
|
||||||
|
sfreq = raw.info['sfreq']
|
||||||
|
|
||||||
|
onsets = annotations_df['start_row'].values / sfreq
|
||||||
|
durations = annotations_df['duration_sec'].values
|
||||||
|
descriptions = [str(int(label)) for label in annotations_df['label'].values]
|
||||||
|
|
||||||
|
annot = mne.Annotations(onset=onsets, duration=durations,
|
||||||
|
description=descriptions)
|
||||||
|
raw.set_annotations(annot)
|
||||||
|
return raw
|
||||||
910
pipeline.py
Normal file
910
pipeline.py
Normal file
@@ -0,0 +1,910 @@
|
|||||||
|
"""
|
||||||
|
Core signal processing pipeline for reproducing Mwata-Velu et al. (2023).
|
||||||
|
|
||||||
|
Contains:
|
||||||
|
- Chebyshev Type II sub-band filtering (Section 3.3)
|
||||||
|
- FastICA denoising with energy concentration criterion (Section 3.4, Algorithm 1)
|
||||||
|
- Hjorth parameter feature extraction (Section 3.5)
|
||||||
|
- SVM classification with 5-run averaging (Section 3.6)
|
||||||
|
|
||||||
|
PIPELINE ORDER (Algorithm 1 interpretation):
|
||||||
|
1. ICA on broadband 19-channel data
|
||||||
|
- Step 6 internally filters each back-projected IC into θ/α/β sub-bands
|
||||||
|
to evaluate energy concentration in selected channels
|
||||||
|
- Step 7 keeps only ICs meeting the threshold in ANY sub-band (OR logic).
|
||||||
|
The paper specifies ALL (AND logic), but this causes zero components
|
||||||
|
to pass for many subjects. OR logic is used as a documented deviation.
|
||||||
|
- Step 8 reconstructs cleaned broadband signal from kept ICs
|
||||||
|
2. Filter the *cleaned* broadband signal into θ/α/β for Hjorth Activity features
|
||||||
|
3. Compute Hjorth Mobility/Complexity on the *cleaned* unfiltered signal
|
||||||
|
|
||||||
|
This resolves the paper's contradiction between Figure 1 and Algorithm 1. The
|
||||||
|
Algorithm 1 reading is preferred because the energy concentration criterion is
|
||||||
|
meaningless on pre-filtered data (e.g., filtering theta-band data into alpha
|
||||||
|
yields ~zero).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
from scipy.signal import cheby2, filtfilt
|
||||||
|
import mne
|
||||||
|
from mne.preprocessing import ICA
|
||||||
|
from sklearn.svm import SVC
|
||||||
|
from sklearn.preprocessing import MinMaxScaler
|
||||||
|
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
|
||||||
|
|
||||||
|
from config import (
|
||||||
|
ICA_CHANNELS, SELECTED_CHANNELS, SUB_BANDS, TARGET_RUNS,
|
||||||
|
ICA_N_COMPONENTS, ICA_ENERGY_THRESHOLD, ICA_MAX_ITER, ICA_TOL,
|
||||||
|
TARGET_CHANNELS, EXECUTION_RUNS, IMAGERY_RUNS, ACTIVE_EVENT_LABELS,
|
||||||
|
SVM_C, SVM_GAMMA, SVM_KERNEL, N_RUNS, RANDOM_SEEDS,
|
||||||
|
)
|
||||||
|
from data_loading import (
|
||||||
|
load_signal_file, load_annotation_file, create_mne_raw, add_annotations_to_raw,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Sub-band filtering (Section 3.3)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def design_chebyshev_filter(lowcut, highcut, fs, order=5, rs=34):
|
||||||
|
"""
|
||||||
|
Design Chebyshev Type II bandpass filter.
|
||||||
|
|
||||||
|
Paper specifies:
|
||||||
|
- 5th-order Chebyshev Type II
|
||||||
|
- Stop-band attenuation of -34 dB
|
||||||
|
- Transition bands reaching 80% of the gain
|
||||||
|
"""
|
||||||
|
nyq = 0.5 * fs
|
||||||
|
low = lowcut / nyq
|
||||||
|
high = highcut / nyq
|
||||||
|
b, a = cheby2(order, rs, [low, high], btype='band')
|
||||||
|
return b, a
|
||||||
|
|
||||||
|
|
||||||
|
def apply_bandpass_filter(raw, lowcut, highcut, picks=None):
|
||||||
|
"""
|
||||||
|
Apply Chebyshev Type II bandpass filter.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
raw : mne.io.Raw
|
||||||
|
lowcut, highcut : float
|
||||||
|
picks : list of str, optional
|
||||||
|
Channel names to filter. If None, filters all channels.
|
||||||
|
When specified, returns a Raw containing only those channels.
|
||||||
|
|
||||||
|
Returns a new Raw object with filtered data.
|
||||||
|
"""
|
||||||
|
if picks is not None:
|
||||||
|
raw_sub = raw.copy().pick(picks)
|
||||||
|
else:
|
||||||
|
raw_sub = raw.copy()
|
||||||
|
|
||||||
|
b, a = design_chebyshev_filter(lowcut, highcut, raw_sub.info['sfreq'])
|
||||||
|
|
||||||
|
data = raw_sub.get_data()
|
||||||
|
filtered_data = filtfilt(b, a, data, axis=1)
|
||||||
|
|
||||||
|
raw_filtered = mne.io.RawArray(filtered_data, raw_sub.info, verbose=False)
|
||||||
|
raw_filtered.set_annotations(raw.annotations)
|
||||||
|
return raw_filtered
|
||||||
|
|
||||||
|
|
||||||
|
def _cheby2_bandpass(data, sfreq, low, high, order=5, rs=34.0):
|
||||||
|
"""
|
||||||
|
Apply Chebyshev Type II bandpass filter to a numpy array.
|
||||||
|
|
||||||
|
Used internally by the ICA energy concentration criterion (Step 6-7).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data : ndarray, shape (n_signals, n_samples)
|
||||||
|
sfreq : float
|
||||||
|
low, high : float, pass-band edges in Hz
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
filtered : ndarray, same shape as data
|
||||||
|
"""
|
||||||
|
b, a = design_chebyshev_filter(low, high, sfreq, order=order, rs=rs)
|
||||||
|
return filtfilt(b, a, data, axis=-1)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# FastICA denoising (Section 3.4, Algorithm 1)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def apply_fastica_denoising(raw, n_components=ICA_N_COMPONENTS,
|
||||||
|
energy_threshold=ICA_ENERGY_THRESHOLD,
|
||||||
|
verbose=True):
|
||||||
|
"""
|
||||||
|
Denoise EEG with FastICA following Algorithm 1 of Mwata-Velu et al. (2023).
|
||||||
|
|
||||||
|
This function operates on BROADBAND data. Sub-band filtering is used
|
||||||
|
internally only for the energy concentration criterion (Steps 6-7).
|
||||||
|
|
||||||
|
Steps mapped to the paper
|
||||||
|
-------------------------
|
||||||
|
Steps 1-5: FastICA decomposition of 19 channels into n_components ICs.
|
||||||
|
- Step 1: Random weight vector initialization
|
||||||
|
- Step 2: Fixed-point update with log-cosh non-linearity (Eq. 3-4)
|
||||||
|
- Step 3: Gram-Schmidt orthogonalization (deflation approach)
|
||||||
|
- Step 4: Weight vector normalization
|
||||||
|
- Step 5: Convergence check; repeat 2-4 if not converged
|
||||||
|
|
||||||
|
NOTE: We use algorithm='deflation' to match the paper's explicit
|
||||||
|
description of Gram-Schmidt orthogonalization. sklearn's default
|
||||||
|
('parallel') uses symmetric decorrelation instead, which is a
|
||||||
|
different algorithm. See Hyvärinen & Oja (1997, 2000).
|
||||||
|
|
||||||
|
Step 6: For every IC k, back-project to 19 channels, filter into each
|
||||||
|
sub-band, and compute per-channel variance.
|
||||||
|
|
||||||
|
Step 7: Retain IC k if, for ANY sub-band, the ratio
|
||||||
|
sum(Selected) / sum(All) >= threshold (paper: 0.35).
|
||||||
|
|
||||||
|
NOTE: The paper specifies AND logic (all sub-bands must pass),
|
||||||
|
but this is too strict — many subjects yield zero kept components.
|
||||||
|
We use OR logic (any sub-band passing suffices) as a documented
|
||||||
|
reproducibility deviation.
|
||||||
|
|
||||||
|
NOTE: The paper writes {γ, α, β}" with gamma instead of theta.
|
||||||
|
The gamma band is never defined in the paper. We interpret this as a
|
||||||
|
typo and use θ/α/β as defined in Section 3.3.
|
||||||
|
|
||||||
|
Step 8: Reconstruct û = A_p @ S_p using only kept components.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
raw : mne.io.Raw
|
||||||
|
Must contain (at least) the 19 ICA_CHANNELS.
|
||||||
|
Should be BROADBAND (unfiltered) data.
|
||||||
|
n_components : int
|
||||||
|
Number of ICs to extract. Paper uses 19.
|
||||||
|
energy_threshold : float
|
||||||
|
Per-sub-band concentration ratio threshold. Paper: 0.35.
|
||||||
|
verbose : bool
|
||||||
|
Print progress information.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
raw_clean : mne.io.Raw
|
||||||
|
Copy of raw with denoised ICA channels (still broadband).
|
||||||
|
ica : mne.preprocessing.ICA
|
||||||
|
Fitted ICA object.
|
||||||
|
kept_components : list of int
|
||||||
|
Indices of ICs that passed the Step 7 criterion.
|
||||||
|
"""
|
||||||
|
sfreq = raw.info['sfreq']
|
||||||
|
if verbose:
|
||||||
|
print(f" [FastICA] {n_components} components, "
|
||||||
|
f"threshold={energy_threshold}, deflation mode")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Steps 1-5: FastICA decomposition on broadband data
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
raw_ica = raw.copy().pick(ICA_CHANNELS)
|
||||||
|
|
||||||
|
ica = ICA(
|
||||||
|
n_components=n_components,
|
||||||
|
method='fastica',
|
||||||
|
fit_params={
|
||||||
|
'max_iter': ICA_MAX_ITER,
|
||||||
|
'tol': ICA_TOL,
|
||||||
|
'algorithm': 'deflation', # Paper's Algorithm 1, Step 3
|
||||||
|
},
|
||||||
|
random_state=42,
|
||||||
|
)
|
||||||
|
ica.fit(raw_ica)
|
||||||
|
|
||||||
|
# Mixing matrix A and IC time-series S
|
||||||
|
mixing = ica.mixing_matrix_
|
||||||
|
unmixing = ica.unmixing_matrix_
|
||||||
|
raw_data = raw_ica.get_data()
|
||||||
|
icasigs = unmixing @ raw_data
|
||||||
|
|
||||||
|
# Channel index maps
|
||||||
|
all_ch_names = list(raw_ica.ch_names)
|
||||||
|
sel_idx = [all_ch_names.index(ch) for ch in SELECTED_CHANNELS]
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Step 6: Evaluate each IC by sub-band (internal filtering)
|
||||||
|
#
|
||||||
|
# For each IC k:
|
||||||
|
# 1. Back-project to 19 channels: signal[c, t] = mixing[c, k] * IC_k[t]
|
||||||
|
# 2. Filter the 19-channel signal into each sub-band
|
||||||
|
# 3. Compute variance per channel
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
ic_channel_var = {name: [] for name, _, _ in SUB_BANDS}
|
||||||
|
|
||||||
|
for k in range(n_components):
|
||||||
|
signal_19ch = mixing[:, k, np.newaxis] * icasigs[k, :] # (19, T)
|
||||||
|
|
||||||
|
for name, low, high in SUB_BANDS:
|
||||||
|
filtered = _cheby2_bandpass(signal_19ch, sfreq, low, high)
|
||||||
|
var_per_channel = np.var(filtered, axis=1, ddof=1)
|
||||||
|
ic_channel_var[name].append(var_per_channel)
|
||||||
|
|
||||||
|
for name in ic_channel_var:
|
||||||
|
ic_channel_var[name] = np.array(ic_channel_var[name])
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Step 7: Energy concentration criterion
|
||||||
|
#
|
||||||
|
# Paper states: ∀χ ∈ {θ, α, β} (AND logic).
|
||||||
|
# However, AND across all three bands is extremely strict and causes
|
||||||
|
# zero components to pass for many subjects. We use OR logic instead:
|
||||||
|
# keep IC k if it passes the threshold in ANY sub-band. This is more
|
||||||
|
# defensible — a component concentrated in motor cortex in alpha is
|
||||||
|
# still a neural signal worth keeping even if its theta energy is
|
||||||
|
# diffuse. Documented as a reproducibility deviation.
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
kept_components = []
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f" {'IC':>4} | {'theta':>7} | {'alpha':>7} | {'beta':>7} | {'kept?':>5}")
|
||||||
|
print(f" {'-'*46}")
|
||||||
|
|
||||||
|
for k in range(n_components):
|
||||||
|
ratios = {}
|
||||||
|
keep = False # OR logic: start False, flip to True if ANY passes
|
||||||
|
# keep = True # AND logic: start True, flip to False if ANY fails
|
||||||
|
|
||||||
|
for name, _, _ in SUB_BANDS:
|
||||||
|
var_all = ic_channel_var[name][k]
|
||||||
|
num = var_all[sel_idx].sum()
|
||||||
|
den = var_all.sum()
|
||||||
|
ratio = num / den if den > 0 else 0.0
|
||||||
|
ratios[name] = ratio
|
||||||
|
if ratio >= energy_threshold:
|
||||||
|
keep = True # passes in at least one sub-band
|
||||||
|
# if ratio < energy_threshold:
|
||||||
|
# keep = False # fails in at least one sub-band
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
status = "YES" if keep else "no"
|
||||||
|
print(f" {k:>4} | {ratios['theta']:>7.3f} | "
|
||||||
|
f"{ratios['alpha']:>7.3f} | {ratios['beta']:>7.3f} | {status:>5}")
|
||||||
|
|
||||||
|
if keep:
|
||||||
|
kept_components.append(k)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f" [FastICA] Kept {len(kept_components)}/{n_components} components")
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Step 8: Reconstruct û = A_p @ S_p (broadband cleaned signal)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
if len(kept_components) == 0:
|
||||||
|
if verbose:
|
||||||
|
print(" [FastICA] WARNING: no components passed threshold "
|
||||||
|
"even with OR logic. Returning original data unchanged.")
|
||||||
|
return raw.copy(), ica, kept_components
|
||||||
|
|
||||||
|
A_p = mixing[:, kept_components]
|
||||||
|
S_p = icasigs[kept_components, :]
|
||||||
|
reconstructed = A_p @ S_p
|
||||||
|
|
||||||
|
raw_clean = mne.io.RawArray(reconstructed, raw_ica.info, verbose=False)
|
||||||
|
raw_clean.set_annotations(raw.annotations)
|
||||||
|
return raw_clean, ica, kept_components
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Hjorth parameter extraction (Section 3.5)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def compute_hjorth_parameters(signal_array):
|
||||||
|
"""
|
||||||
|
Compute Hjorth parameters: Activity, Mobility, Complexity.
|
||||||
|
|
||||||
|
Activity: Variance of the signal (Eq. 5)
|
||||||
|
Mobility: sqrt(Var(d/dt signal) / Var(signal)) (Eq. 6)
|
||||||
|
Complexity: Mobility(d/dt signal) / Mobility(signal) (Eq. 7)
|
||||||
|
|
||||||
|
The derivative d/dt u(i) is approximated as u(i) - u(i-1) per the paper.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
signal_array : np.ndarray, 1D
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
activity, mobility, complexity : float
|
||||||
|
"""
|
||||||
|
activity = np.var(signal_array)
|
||||||
|
|
||||||
|
d1 = np.diff(signal_array)
|
||||||
|
var_d1 = np.var(d1)
|
||||||
|
|
||||||
|
if activity > 0:
|
||||||
|
mobility = np.sqrt(var_d1 / activity)
|
||||||
|
else:
|
||||||
|
mobility = 0.0
|
||||||
|
|
||||||
|
d2 = np.diff(d1)
|
||||||
|
var_d2 = np.var(d2)
|
||||||
|
|
||||||
|
if var_d1 > 0 and mobility > 0:
|
||||||
|
mobility_d1 = np.sqrt(var_d2 / var_d1)
|
||||||
|
complexity = mobility_d1 / mobility
|
||||||
|
else:
|
||||||
|
complexity = 0.0
|
||||||
|
|
||||||
|
return activity, mobility, complexity
|
||||||
|
|
||||||
|
|
||||||
|
def extract_epoch_features(data_raw, data_theta, data_alpha, data_beta,
|
||||||
|
channels=TARGET_CHANNELS):
|
||||||
|
"""
|
||||||
|
Extract Hjorth features for one epoch across all target channels.
|
||||||
|
|
||||||
|
Feature set matches Table 5, Set 3: Aθ, Aα, Aβ, Mraw, Craw per channel.
|
||||||
|
Total features = len(channels) * 5.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data_raw, data_theta, data_alpha, data_beta : np.ndarray
|
||||||
|
Shape (n_channels, n_samples) epoch data for each band.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
features : np.ndarray of shape (n_channels * 5,)
|
||||||
|
"""
|
||||||
|
features = []
|
||||||
|
for i in range(len(channels)):
|
||||||
|
a_theta, _, _ = compute_hjorth_parameters(data_theta[i])
|
||||||
|
a_alpha, _, _ = compute_hjorth_parameters(data_alpha[i])
|
||||||
|
a_beta, _, _ = compute_hjorth_parameters(data_beta[i])
|
||||||
|
_, m_raw, c_raw = compute_hjorth_parameters(data_raw[i])
|
||||||
|
features.extend([a_theta, a_alpha, a_beta, m_raw, c_raw])
|
||||||
|
return np.array(features)
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Subject processing (orchestration)
|
||||||
|
#
|
||||||
|
# In all strategies below, the pipeline is:
|
||||||
|
# 1. ICA on broadband data (1 fit, not 4)
|
||||||
|
# 2. Filter the CLEANED broadband signal into θ/α/β for Activity features
|
||||||
|
# 3. Compute Mobility/Complexity on the CLEANED unfiltered signal
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def process_subject_per_run(subject_num, target_runs=None, apply_ica=True,
|
||||||
|
verbose=True):
|
||||||
|
"""
|
||||||
|
Process one subject with ICA fitted independently per run.
|
||||||
|
|
||||||
|
This gives ICA the least amount of data to work with (~19,200 samples for
|
||||||
|
a 2-minute run on 19 channels).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
all_features : list of np.ndarray, one per run
|
||||||
|
all_labels : list of np.ndarray, one per run
|
||||||
|
run_numbers : list of int
|
||||||
|
total_kept : int (total across all ICA fits)
|
||||||
|
"""
|
||||||
|
if target_runs is None:
|
||||||
|
target_runs = TARGET_RUNS
|
||||||
|
|
||||||
|
all_features = []
|
||||||
|
all_labels = []
|
||||||
|
run_numbers = []
|
||||||
|
total_kept = 0
|
||||||
|
n_ica_fits = 0
|
||||||
|
|
||||||
|
for run_num in target_runs:
|
||||||
|
try:
|
||||||
|
signal_data = load_signal_file(subject_num, run_num)
|
||||||
|
annotations = load_annotation_file(subject_num, run_num)
|
||||||
|
except Exception as e:
|
||||||
|
if verbose:
|
||||||
|
print(f" Could not load S{subject_num:03d} R{run_num:02d}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
raw = create_mne_raw(signal_data)
|
||||||
|
raw = add_annotations_to_raw(raw, annotations)
|
||||||
|
|
||||||
|
# ICA on broadband data (1 fit per run; sub-band eval is internal)
|
||||||
|
if apply_ica:
|
||||||
|
raw_clean, _, kept = apply_fastica_denoising(raw, verbose=verbose)
|
||||||
|
total_kept += len(kept)
|
||||||
|
n_ica_fits += 1
|
||||||
|
else:
|
||||||
|
raw_clean = raw
|
||||||
|
|
||||||
|
# Filter the CLEANED signal into sub-bands for Activity features
|
||||||
|
# Only filter the channels we actually need for feature extraction
|
||||||
|
raw_theta = apply_bandpass_filter(raw_clean, 4, 8, picks=TARGET_CHANNELS)
|
||||||
|
raw_alpha = apply_bandpass_filter(raw_clean, 8, 13, picks=TARGET_CHANNELS)
|
||||||
|
raw_beta = apply_bandpass_filter(raw_clean, 13, 30, picks=TARGET_CHANNELS)
|
||||||
|
|
||||||
|
# Determine task type from run number
|
||||||
|
if run_num in EXECUTION_RUNS:
|
||||||
|
task_type = 'ME'
|
||||||
|
elif run_num in IMAGERY_RUNS:
|
||||||
|
task_type = 'MI'
|
||||||
|
else:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Extract features from active events
|
||||||
|
features_list, labels_list = _extract_run_features(
|
||||||
|
raw_clean, raw_theta, raw_alpha, raw_beta,
|
||||||
|
annotations, task_type
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(features_list) > 0:
|
||||||
|
all_features.append(np.array(features_list))
|
||||||
|
all_labels.append(np.array(labels_list))
|
||||||
|
run_numbers.append(run_num)
|
||||||
|
|
||||||
|
return all_features, all_labels, run_numbers, total_kept
|
||||||
|
|
||||||
|
|
||||||
|
def process_subject_per_subject(subject_num, target_runs=None, apply_ica=True,
|
||||||
|
verbose=True):
|
||||||
|
"""
|
||||||
|
Process one subject with ICA fitted once on all runs concatenated.
|
||||||
|
|
||||||
|
This gives ICA more data for better decomposition (~115,200 samples
|
||||||
|
for 12 runs).
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
all_features : list of np.ndarray, one per run
|
||||||
|
all_labels : list of np.ndarray, one per run
|
||||||
|
run_numbers : list of int
|
||||||
|
n_kept_components : int
|
||||||
|
"""
|
||||||
|
if target_runs is None:
|
||||||
|
target_runs = TARGET_RUNS
|
||||||
|
|
||||||
|
# Step 1: Load and concatenate all runs
|
||||||
|
raw_runs = []
|
||||||
|
annotations_runs = []
|
||||||
|
run_lengths = []
|
||||||
|
|
||||||
|
for run_num in target_runs:
|
||||||
|
try:
|
||||||
|
signal_data = load_signal_file(subject_num, run_num)
|
||||||
|
annotations = load_annotation_file(subject_num, run_num)
|
||||||
|
raw = create_mne_raw(signal_data)
|
||||||
|
raw = add_annotations_to_raw(raw, annotations)
|
||||||
|
raw_runs.append(raw)
|
||||||
|
annotations_runs.append(annotations)
|
||||||
|
run_lengths.append(raw.n_times)
|
||||||
|
except Exception as e:
|
||||||
|
if verbose:
|
||||||
|
print(f" Could not load S{subject_num:03d} R{run_num:02d}: {e}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(raw_runs) == 0:
|
||||||
|
return [], [], [], 0
|
||||||
|
|
||||||
|
raw_concat = mne.concatenate_raws(raw_runs, preload=True)
|
||||||
|
if verbose:
|
||||||
|
duration = raw_concat.n_times / raw_concat.info['sfreq']
|
||||||
|
print(f" S{subject_num:03d}: {len(raw_runs)} runs, {duration:.0f}s")
|
||||||
|
|
||||||
|
# Step 2: ICA once on broadband concatenated data
|
||||||
|
n_kept = 0
|
||||||
|
if apply_ica:
|
||||||
|
raw_clean, _, kept = apply_fastica_denoising(raw_concat, verbose=verbose)
|
||||||
|
n_kept = len(kept)
|
||||||
|
else:
|
||||||
|
raw_clean = raw_concat
|
||||||
|
|
||||||
|
# Step 3: Filter the CLEANED signal into sub-bands for Activity features
|
||||||
|
# Only filter the channels we actually need for feature extraction
|
||||||
|
raw_theta = apply_bandpass_filter(raw_clean, 4, 8, picks=TARGET_CHANNELS)
|
||||||
|
raw_alpha = apply_bandpass_filter(raw_clean, 8, 13, picks=TARGET_CHANNELS)
|
||||||
|
raw_beta = apply_bandpass_filter(raw_clean, 13, 30, picks=TARGET_CHANNELS)
|
||||||
|
|
||||||
|
# Step 4: Extract features per run using global sample offsets
|
||||||
|
all_raw_data = raw_clean.get_data()
|
||||||
|
all_theta_data = raw_theta.get_data()
|
||||||
|
all_alpha_data = raw_alpha.get_data()
|
||||||
|
all_beta_data = raw_beta.get_data()
|
||||||
|
|
||||||
|
ch_idx_clean = [raw_clean.ch_names.index(ch) for ch in TARGET_CHANNELS]
|
||||||
|
ch_idx_theta = [raw_theta.ch_names.index(ch) for ch in TARGET_CHANNELS]
|
||||||
|
ch_idx_alpha = [raw_alpha.ch_names.index(ch) for ch in TARGET_CHANNELS]
|
||||||
|
ch_idx_beta = [raw_beta.ch_names.index(ch) for ch in TARGET_CHANNELS]
|
||||||
|
|
||||||
|
all_features = []
|
||||||
|
all_labels = []
|
||||||
|
run_numbers = []
|
||||||
|
sample_offset = 0
|
||||||
|
|
||||||
|
for run_idx, run_num in enumerate(target_runs):
|
||||||
|
if run_idx >= len(raw_runs):
|
||||||
|
continue
|
||||||
|
|
||||||
|
annotations = annotations_runs[run_idx]
|
||||||
|
run_length = run_lengths[run_idx]
|
||||||
|
|
||||||
|
if run_num in EXECUTION_RUNS:
|
||||||
|
task_type = 'ME'
|
||||||
|
elif run_num in IMAGERY_RUNS:
|
||||||
|
task_type = 'MI'
|
||||||
|
else:
|
||||||
|
sample_offset += run_length
|
||||||
|
continue
|
||||||
|
|
||||||
|
features_list, labels_list = _extract_run_features_from_arrays(
|
||||||
|
all_raw_data, all_theta_data, all_alpha_data, all_beta_data,
|
||||||
|
ch_idx_clean, ch_idx_theta, ch_idx_alpha, ch_idx_beta,
|
||||||
|
annotations, task_type, sample_offset, all_raw_data.shape[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(features_list) > 0:
|
||||||
|
all_features.append(np.array(features_list))
|
||||||
|
all_labels.append(np.array(labels_list))
|
||||||
|
run_numbers.append(run_num)
|
||||||
|
|
||||||
|
sample_offset += run_length
|
||||||
|
|
||||||
|
return all_features, all_labels, run_numbers, n_kept
|
||||||
|
|
||||||
|
|
||||||
|
def process_subject_global(subject_num, global_ica_results,
|
||||||
|
target_runs=None, verbose=True):
|
||||||
|
"""
|
||||||
|
Process one subject using a pre-fitted global ICA.
|
||||||
|
|
||||||
|
The global ICA is fitted once on all training subjects' broadband data.
|
||||||
|
Each subject's data is then projected through the same unmixing matrix.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
subject_num : int
|
||||||
|
global_ica_results : dict
|
||||||
|
Keys: 'mixing', 'unmixing', 'kept_components'
|
||||||
|
From fit_global_ica().
|
||||||
|
target_runs : list of int, optional
|
||||||
|
verbose : bool
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
all_features, all_labels, run_numbers, n_kept
|
||||||
|
"""
|
||||||
|
if target_runs is None:
|
||||||
|
target_runs = TARGET_RUNS
|
||||||
|
|
||||||
|
# Load and concatenate
|
||||||
|
raw_runs = []
|
||||||
|
annotations_runs = []
|
||||||
|
run_lengths = []
|
||||||
|
|
||||||
|
for run_num in target_runs:
|
||||||
|
try:
|
||||||
|
signal_data = load_signal_file(subject_num, run_num)
|
||||||
|
annotations = load_annotation_file(subject_num, run_num)
|
||||||
|
raw = create_mne_raw(signal_data)
|
||||||
|
raw = add_annotations_to_raw(raw, annotations)
|
||||||
|
raw_runs.append(raw)
|
||||||
|
annotations_runs.append(annotations)
|
||||||
|
run_lengths.append(raw.n_times)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(raw_runs) == 0:
|
||||||
|
return [], [], [], 0
|
||||||
|
|
||||||
|
raw_concat = mne.concatenate_raws(raw_runs, preload=True)
|
||||||
|
|
||||||
|
# Apply pre-fitted global ICA on broadband data
|
||||||
|
ica_tuple = (global_ica_results['mixing'],
|
||||||
|
global_ica_results['unmixing'],
|
||||||
|
global_ica_results['kept_components'])
|
||||||
|
raw_clean = _apply_precomputed_ica(raw_concat, ica_tuple)
|
||||||
|
n_kept = len(global_ica_results['kept_components'])
|
||||||
|
|
||||||
|
# Filter the CLEANED signal into sub-bands for Activity features
|
||||||
|
# Only filter the channels we actually need for feature extraction
|
||||||
|
raw_theta = apply_bandpass_filter(raw_clean, 4, 8, picks=TARGET_CHANNELS)
|
||||||
|
raw_alpha = apply_bandpass_filter(raw_clean, 8, 13, picks=TARGET_CHANNELS)
|
||||||
|
raw_beta = apply_bandpass_filter(raw_clean, 13, 30, picks=TARGET_CHANNELS)
|
||||||
|
|
||||||
|
# Extract features per run
|
||||||
|
all_raw_data = raw_clean.get_data()
|
||||||
|
all_theta_data = raw_theta.get_data()
|
||||||
|
all_alpha_data = raw_alpha.get_data()
|
||||||
|
all_beta_data = raw_beta.get_data()
|
||||||
|
|
||||||
|
ch_idx_clean = [raw_clean.ch_names.index(ch) for ch in TARGET_CHANNELS]
|
||||||
|
ch_idx_theta = [raw_theta.ch_names.index(ch) for ch in TARGET_CHANNELS]
|
||||||
|
ch_idx_alpha = [raw_alpha.ch_names.index(ch) for ch in TARGET_CHANNELS]
|
||||||
|
ch_idx_beta = [raw_beta.ch_names.index(ch) for ch in TARGET_CHANNELS]
|
||||||
|
|
||||||
|
all_features = []
|
||||||
|
all_labels = []
|
||||||
|
run_numbers = []
|
||||||
|
sample_offset = 0
|
||||||
|
|
||||||
|
for run_idx, run_num in enumerate(target_runs):
|
||||||
|
if run_idx >= len(raw_runs):
|
||||||
|
continue
|
||||||
|
|
||||||
|
annotations = annotations_runs[run_idx]
|
||||||
|
run_length = run_lengths[run_idx]
|
||||||
|
|
||||||
|
if run_num in EXECUTION_RUNS:
|
||||||
|
task_type = 'ME'
|
||||||
|
elif run_num in IMAGERY_RUNS:
|
||||||
|
task_type = 'MI'
|
||||||
|
else:
|
||||||
|
sample_offset += run_length
|
||||||
|
continue
|
||||||
|
|
||||||
|
features_list, labels_list = _extract_run_features_from_arrays(
|
||||||
|
all_raw_data, all_theta_data, all_alpha_data, all_beta_data,
|
||||||
|
ch_idx_clean, ch_idx_theta, ch_idx_alpha, ch_idx_beta,
|
||||||
|
annotations, task_type, sample_offset, all_raw_data.shape[1]
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(features_list) > 0:
|
||||||
|
all_features.append(np.array(features_list))
|
||||||
|
all_labels.append(np.array(labels_list))
|
||||||
|
run_numbers.append(run_num)
|
||||||
|
|
||||||
|
sample_offset += run_length
|
||||||
|
|
||||||
|
return all_features, all_labels, run_numbers, n_kept
|
||||||
|
|
||||||
|
|
||||||
|
def fit_global_ica(subject_list, target_runs=None, verbose=True):
|
||||||
|
"""
|
||||||
|
Fit ICA once on all subjects' broadband data concatenated.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
global_ica_results : dict with keys 'mixing', 'unmixing', 'kept_components'
|
||||||
|
"""
|
||||||
|
if target_runs is None:
|
||||||
|
target_runs = TARGET_RUNS
|
||||||
|
|
||||||
|
all_raws = []
|
||||||
|
|
||||||
|
for subject_num in subject_list:
|
||||||
|
for run_num in target_runs:
|
||||||
|
try:
|
||||||
|
signal_data = load_signal_file(subject_num, run_num)
|
||||||
|
raw = create_mne_raw(signal_data)
|
||||||
|
all_raws.append(raw)
|
||||||
|
except Exception:
|
||||||
|
continue
|
||||||
|
|
||||||
|
if len(all_raws) == 0:
|
||||||
|
raise ValueError("No data loaded for global ICA")
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f" Global ICA: concatenating {len(all_raws)} runs...")
|
||||||
|
|
||||||
|
global_raw = mne.concatenate_raws(all_raws, preload=True)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
duration = global_raw.n_times / global_raw.info['sfreq']
|
||||||
|
print(f" Global ICA: {duration:.0f}s total data")
|
||||||
|
|
||||||
|
# Fit ICA once on broadband data
|
||||||
|
if verbose:
|
||||||
|
print(f" Fitting global ICA on broadband data...")
|
||||||
|
_, ica, kept = apply_fastica_denoising(global_raw, verbose=verbose)
|
||||||
|
|
||||||
|
return {
|
||||||
|
'mixing': ica.mixing_matrix_,
|
||||||
|
'unmixing': ica.unmixing_matrix_,
|
||||||
|
'kept_components': kept,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_precomputed_ica(raw, ica_result):
|
||||||
|
"""
|
||||||
|
Apply a pre-fitted ICA decomposition to reconstruct cleaned signals.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
raw : mne.io.Raw
|
||||||
|
ica_result : tuple (mixing, unmixing, kept_components)
|
||||||
|
"""
|
||||||
|
mixing, unmixing, kept_components = ica_result
|
||||||
|
|
||||||
|
if len(kept_components) == 0:
|
||||||
|
return raw.copy()
|
||||||
|
|
||||||
|
raw_ica = raw.copy().pick(ICA_CHANNELS)
|
||||||
|
raw_data = raw_ica.get_data()
|
||||||
|
icasigs = unmixing @ raw_data
|
||||||
|
|
||||||
|
A_p = mixing[:, kept_components]
|
||||||
|
S_p = icasigs[kept_components, :]
|
||||||
|
reconstructed = A_p @ S_p
|
||||||
|
|
||||||
|
raw_clean = mne.io.RawArray(reconstructed, raw_ica.info, verbose=False)
|
||||||
|
raw_clean.set_annotations(raw.annotations)
|
||||||
|
return raw_clean
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Feature extraction helpers
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def _extract_run_features(raw_clean, raw_theta, raw_alpha, raw_beta,
|
||||||
|
annotations, task_type):
|
||||||
|
"""Extract Hjorth features from a single (non-concatenated) run."""
|
||||||
|
features_list = []
|
||||||
|
labels_list = []
|
||||||
|
|
||||||
|
ch_idx_clean = [raw_clean.ch_names.index(ch) for ch in TARGET_CHANNELS]
|
||||||
|
ch_idx_theta = [raw_theta.ch_names.index(ch) for ch in TARGET_CHANNELS]
|
||||||
|
ch_idx_alpha = [raw_alpha.ch_names.index(ch) for ch in TARGET_CHANNELS]
|
||||||
|
ch_idx_beta = [raw_beta.ch_names.index(ch) for ch in TARGET_CHANNELS]
|
||||||
|
|
||||||
|
# Pre-fetch full arrays ONCE (avoid repeated get_data() in loop)
|
||||||
|
all_raw = raw_clean.get_data()
|
||||||
|
all_theta = raw_theta.get_data()
|
||||||
|
all_alpha = raw_alpha.get_data()
|
||||||
|
all_beta = raw_beta.get_data()
|
||||||
|
|
||||||
|
for _, annot in annotations.iterrows():
|
||||||
|
if annot['label'] not in ACTIVE_EVENT_LABELS:
|
||||||
|
continue
|
||||||
|
|
||||||
|
start = int(annot['start_row'])
|
||||||
|
end = int(annot['end_row'])
|
||||||
|
end = min(end, all_raw.shape[1])
|
||||||
|
if start >= end:
|
||||||
|
continue
|
||||||
|
|
||||||
|
data_raw = all_raw[ch_idx_clean, start:end]
|
||||||
|
data_theta = all_theta[ch_idx_theta, start:end]
|
||||||
|
data_alpha = all_alpha[ch_idx_alpha, start:end]
|
||||||
|
data_beta = all_beta[ch_idx_beta, start:end]
|
||||||
|
|
||||||
|
features = extract_epoch_features(data_raw, data_theta,
|
||||||
|
data_alpha, data_beta)
|
||||||
|
features_list.append(features)
|
||||||
|
labels_list.append(task_type)
|
||||||
|
|
||||||
|
return features_list, labels_list
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_run_features_from_arrays(all_raw, all_theta, all_alpha, all_beta,
|
||||||
|
ch_idx_clean, ch_idx_theta, ch_idx_alpha,
|
||||||
|
ch_idx_beta,
|
||||||
|
annotations, task_type, sample_offset,
|
||||||
|
max_samples):
|
||||||
|
"""Extract Hjorth features from one run within pre-fetched concatenated arrays.
|
||||||
|
|
||||||
|
This avoids repeated get_data() calls by working directly on numpy arrays.
|
||||||
|
"""
|
||||||
|
features_list = []
|
||||||
|
labels_list = []
|
||||||
|
|
||||||
|
for _, annot in annotations.iterrows():
|
||||||
|
if annot['label'] not in ACTIVE_EVENT_LABELS:
|
||||||
|
continue
|
||||||
|
|
||||||
|
start = sample_offset + int(annot['start_row'])
|
||||||
|
end = sample_offset + int(annot['end_row'])
|
||||||
|
end = min(end, max_samples)
|
||||||
|
if start >= end:
|
||||||
|
continue
|
||||||
|
|
||||||
|
data_raw = all_raw[ch_idx_clean, start:end]
|
||||||
|
data_theta = all_theta[ch_idx_theta, start:end]
|
||||||
|
data_alpha = all_alpha[ch_idx_alpha, start:end]
|
||||||
|
data_beta = all_beta[ch_idx_beta, start:end]
|
||||||
|
|
||||||
|
features = extract_epoch_features(data_raw, data_theta,
|
||||||
|
data_alpha, data_beta)
|
||||||
|
features_list.append(features)
|
||||||
|
labels_list.append(task_type)
|
||||||
|
|
||||||
|
return features_list, labels_list
|
||||||
|
|
||||||
|
|
||||||
|
# =============================================================================
|
||||||
|
# Classification (Section 3.6)
|
||||||
|
# =============================================================================
|
||||||
|
|
||||||
|
def classify_with_averaging(X_train, y_train, X_test, y_test,
|
||||||
|
X_val=None, y_val=None, verbose=True):
|
||||||
|
"""
|
||||||
|
Train SVM N_RUNS times with different random seeds and report mean ± std.
|
||||||
|
|
||||||
|
Paper: "results were averaged by running the model five times"
|
||||||
|
|
||||||
|
Uses MinMaxScaler [0,1] per Equation 8, fitted on training data only.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
results : dict with keys like 'test_acc', 'val_acc', etc.
|
||||||
|
Each value is a list of N_RUNS floats.
|
||||||
|
"""
|
||||||
|
# Normalize (Equation 8)
|
||||||
|
scaler = MinMaxScaler(feature_range=(0, 1))
|
||||||
|
X_train_s = scaler.fit_transform(X_train)
|
||||||
|
X_test_s = scaler.transform(X_test)
|
||||||
|
X_val_s = scaler.transform(X_val) if X_val is not None else None
|
||||||
|
|
||||||
|
results = {
|
||||||
|
'train_acc': [], 'test_acc': [], 'val_acc': [],
|
||||||
|
'test_me_recall': [], 'test_mi_recall': [],
|
||||||
|
'val_me_recall': [], 'val_mi_recall': [],
|
||||||
|
}
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f"\n Training SVM {N_RUNS} times (C={SVM_C}, gamma={SVM_GAMMA}, "
|
||||||
|
f"kernel={SVM_KERNEL}, deflation ICA)")
|
||||||
|
print(f" {'Run':>5} | {'Seed':>6} | {'Train':>7} | {'Test':>7} | "
|
||||||
|
f"{'Val':>7} | {'Test ME':>8} | {'Test MI':>8}")
|
||||||
|
print(f" {'-'*65}")
|
||||||
|
|
||||||
|
for run_i, seed in enumerate(RANDOM_SEEDS):
|
||||||
|
svm = SVC(kernel=SVM_KERNEL, C=SVM_C, gamma=SVM_GAMMA,
|
||||||
|
random_state=seed)
|
||||||
|
svm.fit(X_train_s, y_train)
|
||||||
|
|
||||||
|
train_acc = accuracy_score(y_train, svm.predict(X_train_s))
|
||||||
|
test_acc = accuracy_score(y_test, svm.predict(X_test_s))
|
||||||
|
results['train_acc'].append(train_acc)
|
||||||
|
results['test_acc'].append(test_acc)
|
||||||
|
|
||||||
|
# Per-class recall
|
||||||
|
test_cm = confusion_matrix(y_test, svm.predict(X_test_s),
|
||||||
|
labels=['ME', 'MI'])
|
||||||
|
test_me = test_cm[0, 0] / test_cm[0].sum() if test_cm[0].sum() > 0 else 0
|
||||||
|
test_mi = test_cm[1, 1] / test_cm[1].sum() if test_cm[1].sum() > 0 else 0
|
||||||
|
results['test_me_recall'].append(test_me)
|
||||||
|
results['test_mi_recall'].append(test_mi)
|
||||||
|
|
||||||
|
val_acc = 0
|
||||||
|
val_me = 0
|
||||||
|
val_mi = 0
|
||||||
|
if X_val is not None:
|
||||||
|
val_acc = accuracy_score(y_val, svm.predict(X_val_s))
|
||||||
|
val_cm = confusion_matrix(y_val, svm.predict(X_val_s),
|
||||||
|
labels=['ME', 'MI'])
|
||||||
|
val_me = val_cm[0, 0] / val_cm[0].sum() if val_cm[0].sum() > 0 else 0
|
||||||
|
val_mi = val_cm[1, 1] / val_cm[1].sum() if val_cm[1].sum() > 0 else 0
|
||||||
|
|
||||||
|
results['val_acc'].append(val_acc)
|
||||||
|
results['val_me_recall'].append(val_me)
|
||||||
|
results['val_mi_recall'].append(val_mi)
|
||||||
|
|
||||||
|
if verbose:
|
||||||
|
print(f" {run_i+1:>5} | {seed:>6} | {train_acc:>7.2%} | "
|
||||||
|
f"{test_acc:>7.2%} | {val_acc:>7.2%} | "
|
||||||
|
f"{test_me:>8.2%} | {test_mi:>8.2%}")
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def print_results_summary(results):
|
||||||
|
"""Print mean ± std summary matching the paper's Table 6/7 format."""
|
||||||
|
print(f"\n {'Metric':<28s} {'Mean ± Std':>14s}")
|
||||||
|
print(f" {'-'*44}")
|
||||||
|
|
||||||
|
labels = {
|
||||||
|
'train_acc': 'Training Accuracy',
|
||||||
|
'test_acc': 'Testing Accuracy',
|
||||||
|
'val_acc': 'Validation Accuracy',
|
||||||
|
'test_me_recall': 'Test ME Recall (TP rate)',
|
||||||
|
'test_mi_recall': 'Test MI Recall (TN rate)',
|
||||||
|
'val_me_recall': 'Val ME Recall (TP rate)',
|
||||||
|
'val_mi_recall': 'Val MI Recall (TN rate)',
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, label in labels.items():
|
||||||
|
vals = np.array(results[key])
|
||||||
|
print(f" {label:<28s} {vals.mean():>6.2%} ± {vals.std():.2%}")
|
||||||
|
|
||||||
|
print(f"\n Paper's Method 2 targets:")
|
||||||
|
print(f" Overall accuracy: 68.8 ± 0.71%")
|
||||||
|
print(f" ME recall: 68.17%")
|
||||||
|
print(f" MI recall: 68.41%")
|
||||||
245
reproduction_notebook.ipynb
Normal file
245
reproduction_notebook.ipynb
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
{
|
||||||
|
"cells": [
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"# Reproducing: Mwata-Velu et al. (2023)\n",
|
||||||
|
"\n",
|
||||||
|
"**Paper**: \"EEG-BCI Features Discrimination between Executed and Imagined Movements Based on FastICA, Hjorth Parameters, and SVM\" \n",
|
||||||
|
"**Journal**: Mathematics 2023, 11, 4409 \n",
|
||||||
|
"**DOI**: https://doi.org/10.3390/math11214409\n",
|
||||||
|
"\n",
|
||||||
|
"## ICA Strategy\n",
|
||||||
|
"\n",
|
||||||
|
"The paper does not specify whether ICA is fitted per-run, per-subject, or globally.\n",
|
||||||
|
"This notebook supports all three strategies via `config.ICA_STRATEGY`.\n",
|
||||||
|
"Change it in `config.py` before running.\n",
|
||||||
|
"\n",
|
||||||
|
"## Key Implementation Decisions\n",
|
||||||
|
"\n",
|
||||||
|
"| Decision | Paper says | We do | Rationale |\n",
|
||||||
|
"|----------|-----------|-------|----------|\n",
|
||||||
|
"| Pipeline order | Figure 1: filter→ICA; Algorithm 1: ICA with internal sub-band eval | ICA then sub-band eval | Energy criterion is meaningless on pre-filtered data |\n",
|
||||||
|
"| ICA algorithm | Gram-Schmidt (Algorithm 1, Step 3) | `algorithm='deflation'` | Deflation uses Gram-Schmidt |\n",
|
||||||
|
"| Energy criterion | ∀χ ∈ {α, β, **γ**} | ∀χ ∈ {**θ**, α, β} | γ never defined; likely typo for θ |\n",
|
||||||
|
"| ICA scope | Not specified | Configurable | Reproducibility variable |\n",
|
||||||
|
"| Classification Method | Methods 1 and 2 | Method 2 only (cross-subject) | Method 1 split is contradictory |"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 1. Setup"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"import numpy as np\n",
|
||||||
|
"import mne\n",
|
||||||
|
"import warnings\n",
|
||||||
|
"warnings.filterwarnings('ignore', category=RuntimeWarning) # MNE verbosity\n",
|
||||||
|
"mne.set_log_level('WARNING')\n",
|
||||||
|
"\n",
|
||||||
|
"from config import (\n",
|
||||||
|
" ICA_STRATEGY, DATA_DIR, TARGET_CHANNELS,\n",
|
||||||
|
" TRAIN_SUBJECTS, TEST_SUBJECTS, VAL_SUBJECTS, TARGET_RUNS,\n",
|
||||||
|
")\n",
|
||||||
|
"from pipeline import (\n",
|
||||||
|
" process_subject_per_run, process_subject_per_subject,\n",
|
||||||
|
" process_subject_global, fit_global_ica,\n",
|
||||||
|
" classify_with_averaging, print_results_summary,\n",
|
||||||
|
")\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Data directory: {DATA_DIR}\")\n",
|
||||||
|
"print(f\"ICA strategy: {ICA_STRATEGY}\")\n",
|
||||||
|
"print(f\"Target channels: {TARGET_CHANNELS}\")\n",
|
||||||
|
"print(f\"Subjects: {len(TRAIN_SUBJECTS)} train, {len(TEST_SUBJECTS)} test, {len(VAL_SUBJECTS)} val\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 2. Process All Subjects (Method 2: Cross-Subject Split)\n",
|
||||||
|
"\n",
|
||||||
|
"Table 4 from the paper:\n",
|
||||||
|
"- Training: Subjects 1-83 (80%)\n",
|
||||||
|
"- Testing: Subjects 84-93 (10%)\n",
|
||||||
|
"- Validation: Subjects 94-103 (10%)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# If using global ICA, fit it first on training subjects\n",
|
||||||
|
"global_ica_results = None\n",
|
||||||
|
"if ICA_STRATEGY == 'global':\n",
|
||||||
|
" print(\"Fitting global ICA on all training subjects...\")\n",
|
||||||
|
" global_ica_results = fit_global_ica(TRAIN_SUBJECTS)\n",
|
||||||
|
" print(\"Global ICA fitting complete.\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"def process_subject_group(subject_list, group_name):\n",
|
||||||
|
" \"\"\"Process a group of subjects and collect features/labels.\"\"\"\n",
|
||||||
|
" all_features = []\n",
|
||||||
|
" all_labels = []\n",
|
||||||
|
" total_kept = 0\n",
|
||||||
|
" n_processed = 0\n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"\\n{'='*60}\")\n",
|
||||||
|
" print(f\"PROCESSING {group_name} ({len(subject_list)} subjects)\")\n",
|
||||||
|
" print(f\"{'='*60}\")\n",
|
||||||
|
" \n",
|
||||||
|
" for subject_num in subject_list:\n",
|
||||||
|
" try:\n",
|
||||||
|
" if ICA_STRATEGY == 'global':\n",
|
||||||
|
" feats, labels, runs, n_kept = process_subject_global(\n",
|
||||||
|
" subject_num, global_ica_results, verbose=True)\n",
|
||||||
|
" elif ICA_STRATEGY == 'per_subject':\n",
|
||||||
|
" feats, labels, runs, n_kept = process_subject_per_subject(\n",
|
||||||
|
" subject_num, verbose=True)\n",
|
||||||
|
" elif ICA_STRATEGY == 'per_run':\n",
|
||||||
|
" feats, labels, runs, n_kept = process_subject_per_run(\n",
|
||||||
|
" subject_num, verbose=True)\n",
|
||||||
|
" \n",
|
||||||
|
" if len(feats) > 0:\n",
|
||||||
|
" combined_feats = np.vstack(feats)\n",
|
||||||
|
" combined_labels = np.concatenate(labels)\n",
|
||||||
|
" all_features.append(combined_feats)\n",
|
||||||
|
" all_labels.append(combined_labels)\n",
|
||||||
|
" total_kept += n_kept\n",
|
||||||
|
" n_processed += 1\n",
|
||||||
|
" \n",
|
||||||
|
" n_me = np.sum(combined_labels == 'ME')\n",
|
||||||
|
" n_mi = np.sum(combined_labels == 'MI')\n",
|
||||||
|
" print(f\" S{subject_num:03d}: {len(combined_labels)} epochs \"\n",
|
||||||
|
" f\"(ME={n_me}, MI={n_mi}), {n_kept} ICA components kept\")\n",
|
||||||
|
" except Exception as e:\n",
|
||||||
|
" print(f\" S{subject_num:03d}: ERROR - {e}\")\n",
|
||||||
|
" continue\n",
|
||||||
|
" \n",
|
||||||
|
" if len(all_features) > 0:\n",
|
||||||
|
" X = np.vstack(all_features)\n",
|
||||||
|
" y = np.concatenate(all_labels)\n",
|
||||||
|
" else:\n",
|
||||||
|
" X = np.array([])\n",
|
||||||
|
" y = np.array([])\n",
|
||||||
|
" \n",
|
||||||
|
" print(f\"\\n {group_name} total: {len(y)} samples \"\n",
|
||||||
|
" f\"(ME={np.sum(y=='ME')}, MI={np.sum(y=='MI')})\")\n",
|
||||||
|
" if n_processed > 0:\n",
|
||||||
|
" print(f\" Avg ICA components kept: {total_kept / n_processed:.1f}\")\n",
|
||||||
|
" \n",
|
||||||
|
" return X, y"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"X_train, y_train = process_subject_group(TRAIN_SUBJECTS, \"TRAINING\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"X_test, y_test = process_subject_group(TEST_SUBJECTS, \"TESTING\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"X_val, y_val = process_subject_group(VAL_SUBJECTS, \"VALIDATION\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"# Summary\n",
|
||||||
|
"print(f\"\\n{'='*60}\")\n",
|
||||||
|
"print(f\"DATASET SUMMARY\")\n",
|
||||||
|
"print(f\"{'='*60}\")\n",
|
||||||
|
"print(f\" Training: {X_train.shape} — ME={np.sum(y_train=='ME')}, MI={np.sum(y_train=='MI')}\")\n",
|
||||||
|
"print(f\" Testing: {X_test.shape} — ME={np.sum(y_test=='ME')}, MI={np.sum(y_test=='MI')}\")\n",
|
||||||
|
"print(f\" Validation: {X_val.shape} — ME={np.sum(y_val=='ME')}, MI={np.sum(y_val=='MI')}\")\n",
|
||||||
|
"print(f\"\\n Paper expects: ~6972 train, ~840 test, ~840 val\")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"## 4. Classification with 5-Run Averaging"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"results = classify_with_averaging(\n",
|
||||||
|
" X_train, y_train,\n",
|
||||||
|
" X_test, y_test,\n",
|
||||||
|
" X_val, y_val,\n",
|
||||||
|
")"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": null,
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [],
|
||||||
|
"source": [
|
||||||
|
"print_results_summary(results)"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"metadata": {
|
||||||
|
"kernelspec": {
|
||||||
|
"display_name": "THESIS",
|
||||||
|
"language": "python",
|
||||||
|
"name": "thesis"
|
||||||
|
},
|
||||||
|
"language_info": {
|
||||||
|
"codemirror_mode": {
|
||||||
|
"name": "ipython",
|
||||||
|
"version": 3
|
||||||
|
},
|
||||||
|
"file_extension": ".py",
|
||||||
|
"mimetype": "text/x-python",
|
||||||
|
"name": "python",
|
||||||
|
"nbconvert_exporter": "python",
|
||||||
|
"pygments_lexer": "ipython3",
|
||||||
|
"version": "3.9.23"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 4
|
||||||
|
}
|
||||||
7
requirements.txt
Normal file
7
requirements.txt
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
mne>=1.6
|
||||||
|
numpy>=1.24
|
||||||
|
pandas>=2.0
|
||||||
|
scipy>=1.11
|
||||||
|
scikit-learn>=1.3
|
||||||
|
matplotlib>=3.7
|
||||||
|
seaborn>=0.12
|
||||||
Reference in New Issue
Block a user