1
0
Files
Mwata-Velu_et_al_2023/pipeline.py
2026-04-09 08:21:30 -07:00

911 lines
32 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""
Core signal processing pipeline for reproducing Mwata-Velu et al. (2023).
Contains:
- Chebyshev Type II sub-band filtering (Section 3.3)
- FastICA denoising with energy concentration criterion (Section 3.4, Algorithm 1)
- Hjorth parameter feature extraction (Section 3.5)
- SVM classification with 5-run averaging (Section 3.6)
PIPELINE ORDER (Algorithm 1 interpretation):
1. ICA on broadband 19-channel data
- Step 6 internally filters each back-projected IC into θ/α/β sub-bands
to evaluate energy concentration in selected channels
- Step 7 keeps only ICs meeting the threshold in ANY sub-band (OR logic).
The paper specifies ALL (AND logic), but this causes zero components
to pass for many subjects. OR logic is used as a documented deviation.
- Step 8 reconstructs cleaned broadband signal from kept ICs
2. Filter the *cleaned* broadband signal into θ/α/β for Hjorth Activity features
3. Compute Hjorth Mobility/Complexity on the *cleaned* unfiltered signal
This resolves the paper's contradiction between Figure 1 and Algorithm 1. The
Algorithm 1 reading is preferred because the energy concentration criterion is
meaningless on pre-filtered data (e.g., filtering theta-band data into alpha
yields ~zero).
"""
import numpy as np
from scipy.signal import cheby2, filtfilt
import mne
from mne.preprocessing import ICA
from sklearn.svm import SVC
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from config import (
ICA_CHANNELS, SELECTED_CHANNELS, SUB_BANDS, TARGET_RUNS,
ICA_N_COMPONENTS, ICA_ENERGY_THRESHOLD, ICA_MAX_ITER, ICA_TOL,
TARGET_CHANNELS, EXECUTION_RUNS, IMAGERY_RUNS, ACTIVE_EVENT_LABELS,
SVM_C, SVM_GAMMA, SVM_KERNEL, N_RUNS, RANDOM_SEEDS,
)
from data_loading import (
load_signal_file, load_annotation_file, create_mne_raw, add_annotations_to_raw,
)
# =============================================================================
# Sub-band filtering (Section 3.3)
# =============================================================================
def design_chebyshev_filter(lowcut, highcut, fs, order=5, rs=34):
    """
    Return (b, a) coefficients of a Chebyshev Type II band-pass filter.

    Defaults follow the paper: 5th-order design with 34 dB of stop-band
    attenuation.

    Parameters
    ----------
    lowcut, highcut : float
        Band edges in Hz. They are divided by the Nyquist frequency and
        passed to ``scipy.signal.cheby2`` as ``Wn``.
        NOTE(review): per the SciPy docs, for Type II filters ``Wn`` marks
        where the gain first reaches ``-rs`` dB (stop-band edges), not the
        pass-band corners -- confirm this matches the paper's intent.
    fs : float
        Sampling frequency in Hz.
    order : int
        Filter order handed to ``cheby2``.
    rs : float
        Minimum stop-band attenuation in dB.

    Returns
    -------
    b, a : ndarray
        Numerator and denominator polynomials of the filter.
    """
    nyquist = 0.5 * fs
    # cheby2 expects the edges normalised to the Nyquist frequency.
    band_edges = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = cheby2(order, rs, band_edges, btype='band')
    return numerator, denominator
def apply_bandpass_filter(raw, lowcut, highcut, picks=None):
    """
    Band-pass an MNE Raw object with the zero-phase Chebyshev Type II filter.

    The input is never modified: the data are copied, optionally reduced to
    ``picks``, filtered forward and backward with ``filtfilt`` and wrapped in
    a fresh ``RawArray``.

    Parameters
    ----------
    raw : mne.io.Raw
        Source recording.
    lowcut, highcut : float
        Band edges in Hz, forwarded to ``design_chebyshev_filter``.
    picks : list of str, optional
        Channel names to keep and filter. When None every channel is
        filtered; otherwise the result contains only these channels.

    Returns
    -------
    mne.io.RawArray
        Filtered copy carrying the annotations of ``raw``.
    """
    subset = raw.copy()
    if picks is not None:
        subset = subset.pick(picks)

    info = subset.info
    b, a = design_chebyshev_filter(lowcut, highcut, info['sfreq'])
    # axis=1 is the time axis of the (n_channels, n_samples) array.
    smoothed = filtfilt(b, a, subset.get_data(), axis=1)

    raw_filtered = mne.io.RawArray(smoothed, info, verbose=False)
    raw_filtered.set_annotations(raw.annotations)
    return raw_filtered
def _cheby2_bandpass(data, sfreq, low, high, order=5, rs=34.0):
    """
    Zero-phase Chebyshev Type II band-pass for plain numpy arrays.

    Used by the ICA energy concentration criterion (Steps 6-7), which works
    on arrays rather than MNE objects.

    Parameters
    ----------
    data : ndarray, shape (n_signals, n_samples)
        Signals to filter; filtering runs along the last axis.
    sfreq : float
        Sampling frequency in Hz.
    low, high : float
        Band edges in Hz, forwarded to ``design_chebyshev_filter``.
    order : int
        Filter order.
    rs : float
        Stop-band attenuation in dB.

    Returns
    -------
    filtered : ndarray
        Same shape as ``data``.
    """
    coefficients = design_chebyshev_filter(low, high, sfreq,
                                           order=order, rs=rs)
    return filtfilt(*coefficients, data, axis=-1)
# =============================================================================
# FastICA denoising (Section 3.4, Algorithm 1)
# =============================================================================
def apply_fastica_denoising(raw, n_components=ICA_N_COMPONENTS,
                            energy_threshold=ICA_ENERGY_THRESHOLD,
                            verbose=True):
    """
    Denoise EEG with FastICA following Algorithm 1 of Mwata-Velu et al. (2023).

    This function operates on BROADBAND data. Sub-band filtering is used
    internally only for the energy concentration criterion (Steps 6-7).

    Steps mapped to the paper
    -------------------------
    Steps 1-5: FastICA decomposition of the ICA channels into ICs.
        - Step 1: Random weight vector initialization
        - Step 2: Fixed-point update with log-cosh non-linearity (Eq. 3-4)
        - Step 3: Gram-Schmidt orthogonalization (deflation approach)
        - Step 4: Weight vector normalization
        - Step 5: Convergence check; repeat 2-4 if not converged
        NOTE: algorithm='deflation' is used to match the paper's explicit
        description of Gram-Schmidt orthogonalization. sklearn's default
        ('parallel') uses symmetric decorrelation instead, which is a
        different algorithm. See Hyvärinen & Oja (1997, 2000).
    Step 6: For every IC k, back-project to the ICA channels, filter into
        each sub-band, and compute per-channel variance.
    Step 7: Retain IC k if, for ANY sub-band, the ratio
        sum(Selected) / sum(All) >= threshold (paper: 0.35).
        NOTE: The paper specifies AND logic (all sub-bands must pass),
        but this is too strict -- many subjects yield zero kept components.
        OR logic (any sub-band passing suffices) is used as a documented
        reproducibility deviation.
        NOTE: The paper writes "{γ, α, β}" with gamma instead of theta.
        The gamma band is never defined in the paper. This is interpreted
        as a typo and θ/α/β are used as defined in Section 3.3.
        NOTE(review): a single back-projected IC is a rank-1 signal
        (topography x time course) and filtfilt is linear, so the ratio
        should be the same in every sub-band up to rounding; ANY vs ALL
        logic would then select the same ICs. Worth confirming on data.
    Step 8: Reconstruct û = A_p @ S_p using only kept components.

    MNE specifics
    -------------
    ``ICA.unmixing_matrix_`` and ``ICA.mixing_matrix_`` live in MNE's
    whitened PCA space, so they must not be multiplied with channel data
    directly. IC time courses are therefore taken from ``ica.get_sources``
    and channel-space topographies from ``ica.get_components`` (rescaled by
    ``ica.pre_whitener_`` to undo MNE's standardisation).

    Parameters
    ----------
    raw : mne.io.Raw
        Must contain (at least) the ICA_CHANNELS.
        Should be BROADBAND (unfiltered) data.
    n_components : int
        Number of ICs to extract. Paper uses 19.
    energy_threshold : float
        Per-sub-band concentration ratio threshold. Paper: 0.35.
    verbose : bool
        Print progress information.

    Returns
    -------
    raw_clean : mne.io.Raw
        Denoised broadband signal on the ICA channels. If no component
        passes the criterion, an unmodified copy of ``raw`` is returned.
    ica : mne.preprocessing.ICA
        Fitted ICA object.
    kept_components : list of int
        Indices of ICs that passed the Step 7 criterion.
    """
    sfreq = raw.info['sfreq']
    if verbose:
        print(f" [FastICA] {n_components} components, "
              f"threshold={energy_threshold}, deflation mode")

    # ------------------------------------------------------------------
    # Steps 1-5: FastICA decomposition on broadband data
    # ------------------------------------------------------------------
    raw_ica = raw.copy().pick(ICA_CHANNELS)
    ica = ICA(
        n_components=n_components,
        method='fastica',
        fit_params={
            'max_iter': ICA_MAX_ITER,
            'tol': ICA_TOL,
            'algorithm': 'deflation',  # Paper's Algorithm 1, Step 3
        },
        random_state=42,
    )
    ica.fit(raw_ica)

    # Number of ICs actually estimated by MNE.
    n_ics = ica.n_components_

    # IC time courses S, shape (n_ics, T). get_sources applies MNE's
    # pre-whitening and PCA projection before the ICA unmixing.
    icasigs = ica.get_sources(raw_ica).get_data()

    # Channel-space mixing matrix A, shape (n_channels, n_ics).
    # get_components() returns topographies of the standardised data; the
    # pre-whitener (shape (n_channels, 1) when no noise covariance is
    # supplied, as here) restores the original channel scaling.
    mixing = ica.get_components() * ica.pre_whitener_

    # Row indices of the selected (motor) channels within the ICA channels.
    all_ch_names = list(raw_ica.ch_names)
    sel_idx = [all_ch_names.index(ch) for ch in SELECTED_CHANNELS]
    band_names = [name for name, _, _ in SUB_BANDS]

    # ------------------------------------------------------------------
    # Step 6: Evaluate each IC by sub-band (internal filtering)
    #
    # For each IC k:
    #   1. Back-project: signal[c, t] = mixing[c, k] * IC_k[t]
    #   2. Filter the multi-channel signal into each sub-band
    #   3. Compute variance per channel
    # ------------------------------------------------------------------
    ic_channel_var = {
        name: np.empty((n_ics, mixing.shape[0])) for name in band_names
    }
    for k in range(n_ics):
        signal_all_ch = mixing[:, k, np.newaxis] * icasigs[k, :]
        for name, low, high in SUB_BANDS:
            filtered = _cheby2_bandpass(signal_all_ch, sfreq, low, high)
            ic_channel_var[name][k] = np.var(filtered, axis=1, ddof=1)

    # ------------------------------------------------------------------
    # Step 7: Energy concentration criterion (OR logic across sub-bands;
    # the paper's AND logic would replace any() with all()).
    # ------------------------------------------------------------------
    kept_components = []
    if verbose:
        header = " | ".join(f"{name:>7}" for name in band_names)
        print(f" {'IC':>4} | {header} | {'kept?':>5}")
        print(f" {'-'*46}")
    for k in range(n_ics):
        ratios = {}
        for name in band_names:
            var_all = ic_channel_var[name][k]
            den = var_all.sum()
            # Guard against an all-zero component in this band.
            ratios[name] = var_all[sel_idx].sum() / den if den > 0 else 0.0
        keep = any(ratio >= energy_threshold for ratio in ratios.values())
        if verbose:
            status = "YES" if keep else "no"
            cells = " | ".join(f"{ratios[name]:>7.3f}" for name in band_names)
            print(f" {k:>4} | {cells} | {status:>5}")
        if keep:
            kept_components.append(k)
    if verbose:
        print(f" [FastICA] Kept {len(kept_components)}/{n_ics} components")

    # ------------------------------------------------------------------
    # Step 8: Reconstruct û = A_p @ S_p (broadband cleaned signal)
    # ------------------------------------------------------------------
    if len(kept_components) == 0:
        if verbose:
            print(" [FastICA] WARNING: no components passed threshold "
                  "even with OR logic. Returning original data unchanged.")
        return raw.copy(), ica, kept_components

    A_p = mixing[:, kept_components]
    S_p = icasigs[kept_components, :]
    reconstructed = A_p @ S_p
    raw_clean = mne.io.RawArray(reconstructed, raw_ica.info, verbose=False)
    raw_clean.set_annotations(raw.annotations)
    return raw_clean, ica, kept_components
# =============================================================================
# Hjorth parameter extraction (Section 3.5)
# =============================================================================
def compute_hjorth_parameters(signal_array):
    """
    Return the three Hjorth descriptors of a 1-D signal.

    Activity   : variance of the signal (Eq. 5)
    Mobility   : sqrt(Var(u') / Var(u)) (Eq. 6)
    Complexity : Mobility(u') / Mobility(u) (Eq. 7)

    Derivatives are first differences, u'(i) = u(i) - u(i-1), as in the
    paper. Whenever a ratio would have a non-positive denominator the
    corresponding parameter is reported as 0.0 instead.

    Parameters
    ----------
    signal_array : np.ndarray, 1D

    Returns
    -------
    activity, mobility, complexity : float
    """
    first_diff = np.diff(signal_array)
    second_diff = np.diff(first_diff)

    activity = np.var(signal_array)
    var_first = np.var(first_diff)
    var_second = np.var(second_diff)

    mobility = np.sqrt(var_first / activity) if activity > 0 else 0.0

    # Complexity needs both a varying derivative and a non-zero mobility.
    if not (var_first > 0 and mobility > 0):
        return activity, mobility, 0.0

    derivative_mobility = np.sqrt(var_second / var_first)
    return activity, mobility, derivative_mobility / mobility
def extract_epoch_features(data_raw, data_theta, data_alpha, data_beta,
                           channels=TARGET_CHANNELS):
    """
    Build the Hjorth feature vector of one epoch.

    For every channel five values are emitted, in this order: Activity of
    the theta, alpha and beta signals, then Mobility and Complexity of the
    unfiltered signal (Table 5, Set 3: Aθ, Aα, Aβ, Mraw, Craw).

    Parameters
    ----------
    data_raw, data_theta, data_alpha, data_beta : np.ndarray
        Epoch data per band, indexed as ``array[channel]`` -> 1-D signal.
    channels : sequence
        Only its length is used: it sets how many leading rows of each
        array are processed.

    Returns
    -------
    features : np.ndarray of shape (len(channels) * 5,)
    """
    band_arrays = (data_theta, data_alpha, data_beta)
    collected = []
    for row in range(len(channels)):
        # Activity is the first element of the Hjorth triple.
        activities = [compute_hjorth_parameters(band[row])[0]
                      for band in band_arrays]
        _, mobility, complexity = compute_hjorth_parameters(data_raw[row])
        collected.extend(activities + [mobility, complexity])
    return np.array(collected)
# =============================================================================
# Subject processing (orchestration)
#
# In all strategies below, the pipeline is:
# 1. ICA on broadband data (1 fit, not 4)
# 2. Filter the CLEANED broadband signal into θ/α/β for Activity features
# 3. Compute Mobility/Complexity on the CLEANED unfiltered signal
# =============================================================================
def process_subject_per_run(subject_num, target_runs=None, apply_ica=True,
                            verbose=True):
    """
    Process one subject with ICA fitted independently per run.

    This gives ICA the least amount of data to work with (~19,200 samples for
    a 2-minute run on 19 channels).

    Runs that are neither execution nor imagery runs produce no features, so
    they are skipped before any loading, ICA fitting or filtering is done.

    Parameters
    ----------
    subject_num : int
        Subject identifier passed to the loaders.
    target_runs : list of int, optional
        Runs to process; defaults to TARGET_RUNS.
    apply_ica : bool
        When False the raw signal is used without denoising.
    verbose : bool
        Print load failures and ICA progress.

    Returns
    -------
    all_features : list of np.ndarray, one per run
    all_labels : list of np.ndarray, one per run
    run_numbers : list of int
    total_kept : int (total across all ICA fits)
    """
    if target_runs is None:
        target_runs = TARGET_RUNS

    all_features = []
    all_labels = []
    run_numbers = []
    total_kept = 0

    for run_num in target_runs:
        # Decide the label first: unlabeled runs would be discarded anyway,
        # so avoid paying for an ICA fit and three filter passes on them.
        if run_num in EXECUTION_RUNS:
            task_type = 'ME'
        elif run_num in IMAGERY_RUNS:
            task_type = 'MI'
        else:
            continue

        try:
            signal_data = load_signal_file(subject_num, run_num)
            annotations = load_annotation_file(subject_num, run_num)
        except Exception as e:
            # Best effort: a missing/corrupt run must not abort the subject.
            if verbose:
                print(f" Could not load S{subject_num:03d} R{run_num:02d}: {e}")
            continue

        raw = create_mne_raw(signal_data)
        raw = add_annotations_to_raw(raw, annotations)

        # ICA on broadband data (1 fit per run; sub-band eval is internal)
        if apply_ica:
            raw_clean, _, kept = apply_fastica_denoising(raw, verbose=verbose)
            total_kept += len(kept)
        else:
            raw_clean = raw

        # Filter the CLEANED signal into sub-bands for Activity features.
        # Only the channels needed for feature extraction are filtered.
        raw_theta = apply_bandpass_filter(raw_clean, 4, 8, picks=TARGET_CHANNELS)
        raw_alpha = apply_bandpass_filter(raw_clean, 8, 13, picks=TARGET_CHANNELS)
        raw_beta = apply_bandpass_filter(raw_clean, 13, 30, picks=TARGET_CHANNELS)

        # Extract features from active events
        features_list, labels_list = _extract_run_features(
            raw_clean, raw_theta, raw_alpha, raw_beta,
            annotations, task_type
        )
        if len(features_list) > 0:
            all_features.append(np.array(features_list))
            all_labels.append(np.array(labels_list))
            run_numbers.append(run_num)

    return all_features, all_labels, run_numbers, total_kept
def process_subject_per_subject(subject_num, target_runs=None, apply_ica=True,
                                verbose=True):
    """
    Process one subject with ICA fitted once on all runs concatenated.

    This gives ICA more data for better decomposition (~115,200 samples
    for 12 runs).

    Each successfully loaded run is tracked together with its own run
    number, annotations and length, so a run that fails to load cannot shift
    the labels or sample offsets of the runs that follow it.

    Parameters
    ----------
    subject_num : int
        Subject identifier passed to the loaders.
    target_runs : list of int, optional
        Runs to process; defaults to TARGET_RUNS.
    apply_ica : bool
        When False the concatenated signal is used without denoising.
    verbose : bool
        Print load failures and progress.

    Returns
    -------
    all_features : list of np.ndarray, one per run
    all_labels : list of np.ndarray, one per run
    run_numbers : list of int
    n_kept_components : int
    """
    if target_runs is None:
        target_runs = TARGET_RUNS

    # Step 1: Load all runs; remember metadata only for runs that loaded.
    raw_runs = []
    loaded_runs = []  # (run_num, annotations, n_samples), in load order
    for run_num in target_runs:
        try:
            signal_data = load_signal_file(subject_num, run_num)
            annotations = load_annotation_file(subject_num, run_num)
            raw = create_mne_raw(signal_data)
            raw = add_annotations_to_raw(raw, annotations)
        except Exception as e:
            # Best effort: skip runs that cannot be loaded.
            if verbose:
                print(f" Could not load S{subject_num:03d} R{run_num:02d}: {e}")
            continue
        raw_runs.append(raw)
        loaded_runs.append((run_num, annotations, raw.n_times))

    if len(raw_runs) == 0:
        return [], [], [], 0

    raw_concat = mne.concatenate_raws(raw_runs, preload=True)
    if verbose:
        duration = raw_concat.n_times / raw_concat.info['sfreq']
        print(f" S{subject_num:03d}: {len(raw_runs)} runs, {duration:.0f}s")

    # Step 2: ICA once on broadband concatenated data
    n_kept = 0
    if apply_ica:
        raw_clean, _, kept = apply_fastica_denoising(raw_concat, verbose=verbose)
        n_kept = len(kept)
    else:
        raw_clean = raw_concat

    # Step 3: Filter the CLEANED signal into sub-bands for Activity features.
    # Only the channels needed for feature extraction are filtered.
    raw_theta = apply_bandpass_filter(raw_clean, 4, 8, picks=TARGET_CHANNELS)
    raw_alpha = apply_bandpass_filter(raw_clean, 8, 13, picks=TARGET_CHANNELS)
    raw_beta = apply_bandpass_filter(raw_clean, 13, 30, picks=TARGET_CHANNELS)

    # Step 4: Extract features per run using global sample offsets
    all_raw_data = raw_clean.get_data()
    all_theta_data = raw_theta.get_data()
    all_alpha_data = raw_alpha.get_data()
    all_beta_data = raw_beta.get_data()
    ch_idx_clean = [raw_clean.ch_names.index(ch) for ch in TARGET_CHANNELS]
    ch_idx_theta = [raw_theta.ch_names.index(ch) for ch in TARGET_CHANNELS]
    ch_idx_alpha = [raw_alpha.ch_names.index(ch) for ch in TARGET_CHANNELS]
    ch_idx_beta = [raw_beta.ch_names.index(ch) for ch in TARGET_CHANNELS]

    total_samples = all_raw_data.shape[1]
    all_features = []
    all_labels = []
    run_numbers = []
    sample_offset = 0
    for run_num, annotations, run_length in loaded_runs:
        run_start = sample_offset
        # Advance unconditionally: skipped runs still occupy samples.
        sample_offset += run_length

        if run_num in EXECUTION_RUNS:
            task_type = 'ME'
        elif run_num in IMAGERY_RUNS:
            task_type = 'MI'
        else:
            continue

        # Clamp epochs to this run's own samples so that an annotation
        # running past the end cannot pull in data from the next run.
        run_stop = min(sample_offset, total_samples)
        features_list, labels_list = _extract_run_features_from_arrays(
            all_raw_data, all_theta_data, all_alpha_data, all_beta_data,
            ch_idx_clean, ch_idx_theta, ch_idx_alpha, ch_idx_beta,
            annotations, task_type, run_start, run_stop
        )
        if len(features_list) > 0:
            all_features.append(np.array(features_list))
            all_labels.append(np.array(labels_list))
            run_numbers.append(run_num)

    return all_features, all_labels, run_numbers, n_kept
def process_subject_global(subject_num, global_ica_results,
                           target_runs=None, verbose=True):
    """
    Process one subject using a pre-fitted global ICA.

    The subject's runs are concatenated, cleaned with the supplied global
    decomposition, filtered into sub-bands and turned into per-run Hjorth
    feature matrices.

    Each successfully loaded run is tracked together with its own run
    number, annotations and length, so a run that fails to load cannot shift
    the labels or sample offsets of the runs that follow it.

    Parameters
    ----------
    subject_num : int
    global_ica_results : dict
        Keys: 'mixing', 'unmixing', 'kept_components'
        From fit_global_ica().
    target_runs : list of int, optional
        Runs to process; defaults to TARGET_RUNS.
    verbose : bool
        Print load failures.

    Returns
    -------
    all_features, all_labels, run_numbers, n_kept
    """
    if target_runs is None:
        target_runs = TARGET_RUNS

    # Load runs; remember metadata only for runs that actually loaded.
    raw_runs = []
    loaded_runs = []  # (run_num, annotations, n_samples), in load order
    for run_num in target_runs:
        try:
            signal_data = load_signal_file(subject_num, run_num)
            annotations = load_annotation_file(subject_num, run_num)
            raw = create_mne_raw(signal_data)
            raw = add_annotations_to_raw(raw, annotations)
        except Exception as e:
            # Best effort: skip runs that cannot be loaded, but say so.
            if verbose:
                print(f" Could not load S{subject_num:03d} R{run_num:02d}: {e}")
            continue
        raw_runs.append(raw)
        loaded_runs.append((run_num, annotations, raw.n_times))

    if len(raw_runs) == 0:
        return [], [], [], 0

    raw_concat = mne.concatenate_raws(raw_runs, preload=True)

    # Apply pre-fitted global ICA on broadband data
    ica_tuple = (global_ica_results['mixing'],
                 global_ica_results['unmixing'],
                 global_ica_results['kept_components'])
    raw_clean = _apply_precomputed_ica(raw_concat, ica_tuple)
    n_kept = len(global_ica_results['kept_components'])

    # Filter the CLEANED signal into sub-bands for Activity features.
    # Only the channels needed for feature extraction are filtered.
    raw_theta = apply_bandpass_filter(raw_clean, 4, 8, picks=TARGET_CHANNELS)
    raw_alpha = apply_bandpass_filter(raw_clean, 8, 13, picks=TARGET_CHANNELS)
    raw_beta = apply_bandpass_filter(raw_clean, 13, 30, picks=TARGET_CHANNELS)

    # Extract features per run
    all_raw_data = raw_clean.get_data()
    all_theta_data = raw_theta.get_data()
    all_alpha_data = raw_alpha.get_data()
    all_beta_data = raw_beta.get_data()
    ch_idx_clean = [raw_clean.ch_names.index(ch) for ch in TARGET_CHANNELS]
    ch_idx_theta = [raw_theta.ch_names.index(ch) for ch in TARGET_CHANNELS]
    ch_idx_alpha = [raw_alpha.ch_names.index(ch) for ch in TARGET_CHANNELS]
    ch_idx_beta = [raw_beta.ch_names.index(ch) for ch in TARGET_CHANNELS]

    total_samples = all_raw_data.shape[1]
    all_features = []
    all_labels = []
    run_numbers = []
    sample_offset = 0
    for run_num, annotations, run_length in loaded_runs:
        run_start = sample_offset
        # Advance unconditionally: skipped runs still occupy samples.
        sample_offset += run_length

        if run_num in EXECUTION_RUNS:
            task_type = 'ME'
        elif run_num in IMAGERY_RUNS:
            task_type = 'MI'
        else:
            continue

        # Clamp epochs to this run's own samples so that an annotation
        # running past the end cannot pull in data from the next run.
        run_stop = min(sample_offset, total_samples)
        features_list, labels_list = _extract_run_features_from_arrays(
            all_raw_data, all_theta_data, all_alpha_data, all_beta_data,
            ch_idx_clean, ch_idx_theta, ch_idx_alpha, ch_idx_beta,
            annotations, task_type, run_start, run_stop
        )
        if len(features_list) > 0:
            all_features.append(np.array(features_list))
            all_labels.append(np.array(labels_list))
            run_numbers.append(run_num)

    return all_features, all_labels, run_numbers, n_kept
def fit_global_ica(subject_list, target_runs=None, verbose=True):
    """
    Fit ICA once on all subjects' broadband data concatenated.

    The returned matrices are expressed in channel (sensor) space so that
    they can be multiplied directly with channel data:
    ``sources = unmixing @ data`` and ``data_hat = mixing @ sources``.
    MNE's own ``ICA.unmixing_matrix_`` / ``ICA.mixing_matrix_`` cannot be
    used that way because they map between MNE's whitened PCA space and ICA
    space; the PCA projection and the pre-whitening scale are folded in here.

    Parameters
    ----------
    subject_list : iterable of int
        Subjects whose runs are pooled for the fit.
    target_runs : list of int, optional
        Runs to load per subject; defaults to TARGET_RUNS.
    verbose : bool
        Print progress and load failures.

    Returns
    -------
    global_ica_results : dict with keys 'mixing', 'unmixing', 'kept_components'
        'mixing' has shape (n_channels, n_ics), 'unmixing' has shape
        (n_ics, n_channels), 'kept_components' is the list of IC indices
        selected by apply_fastica_denoising.

    Raises
    ------
    ValueError
        If no run could be loaded for any subject.
    """
    if target_runs is None:
        target_runs = TARGET_RUNS

    all_raws = []
    for subject_num in subject_list:
        for run_num in target_runs:
            try:
                signal_data = load_signal_file(subject_num, run_num)
                raw = create_mne_raw(signal_data)
            except Exception as e:
                # Best effort: skip runs that cannot be loaded, but say so.
                if verbose:
                    print(f" Could not load S{subject_num:03d} R{run_num:02d}: {e}")
                continue
            all_raws.append(raw)

    if len(all_raws) == 0:
        raise ValueError("No data loaded for global ICA")

    if verbose:
        print(f" Global ICA: concatenating {len(all_raws)} runs...")
    global_raw = mne.concatenate_raws(all_raws, preload=True)
    if verbose:
        duration = global_raw.n_times / global_raw.info['sfreq']
        print(f" Global ICA: {duration:.0f}s total data")

    # Fit ICA once on broadband data
    if verbose:
        print(f" Fitting global ICA on broadband data...")
    _, ica, kept = apply_fastica_denoising(global_raw, verbose=verbose)

    # Convert MNE's PCA-space operators to channel space.
    # pre_whitener_ has shape (n_channels, 1) when no noise covariance is
    # given to ICA (the case in apply_fastica_denoising).
    n_ics = ica.n_components_
    pre_whitener = ica.pre_whitener_
    unmixing_sensor = (
        ica.unmixing_matrix_ @ ica.pca_components_[:n_ics]
    ) / pre_whitener.T
    mixing_sensor = ica.get_components() * pre_whitener

    return {
        'mixing': mixing_sensor,
        'unmixing': unmixing_sensor,
        'kept_components': kept,
    }
def _apply_precomputed_ica(raw, ica_result):
    """
    Rebuild a cleaned signal from an already-fitted decomposition.

    The ICA channels of ``raw`` are projected to component space with
    ``unmixing``, and only the kept components are mixed back to channels.
    Both matrices are multiplied directly with channel data, so they must be
    expressed in channel space (rows of ``mixing`` / columns of ``unmixing``
    ordered like ICA_CHANNELS).

    Parameters
    ----------
    raw : mne.io.Raw
        Recording containing at least the ICA_CHANNELS.
    ica_result : tuple (mixing, unmixing, kept_components)

    Returns
    -------
    mne.io.Raw
        Cleaned ICA channels with the annotations of ``raw``; an unmodified
        copy of ``raw`` when ``kept_components`` is empty.
    """
    mixing, unmixing, kept_components = ica_result

    # Nothing to reconstruct from: hand back the input untouched.
    if len(kept_components) == 0:
        return raw.copy()

    ica_view = raw.copy().pick(ICA_CHANNELS)
    sources = unmixing @ ica_view.get_data()
    cleaned = mixing[:, kept_components] @ sources[kept_components, :]

    raw_clean = mne.io.RawArray(cleaned, ica_view.info, verbose=False)
    raw_clean.set_annotations(raw.annotations)
    return raw_clean
# =============================================================================
# Feature extraction helpers
# =============================================================================
def _extract_run_features(raw_clean, raw_theta, raw_alpha, raw_beta,
                          annotations, task_type):
    """
    Extract Hjorth features from a single (non-concatenated) run.

    The epoch loop is shared with the concatenated-data path: this function
    resolves channel indices, fetches each array once and delegates to
    ``_extract_run_features_from_arrays`` with a zero sample offset and the
    run length as the upper bound.

    Parameters
    ----------
    raw_clean, raw_theta, raw_alpha, raw_beta : mne.io.Raw
        Cleaned broadband signal and its three sub-band versions; each must
        contain every channel in TARGET_CHANNELS.
    annotations : table with ``iterrows()`` yielding rows that have
        'label', 'start_row' and 'end_row' entries.
    task_type : str
        Label assigned to every extracted epoch.

    Returns
    -------
    features_list : list of np.ndarray
    labels_list : list of str
    """
    ch_idx_clean = [raw_clean.ch_names.index(ch) for ch in TARGET_CHANNELS]
    ch_idx_theta = [raw_theta.ch_names.index(ch) for ch in TARGET_CHANNELS]
    ch_idx_alpha = [raw_alpha.ch_names.index(ch) for ch in TARGET_CHANNELS]
    ch_idx_beta = [raw_beta.ch_names.index(ch) for ch in TARGET_CHANNELS]

    # Fetch full arrays ONCE (avoid repeated get_data() per epoch).
    all_raw = raw_clean.get_data()
    all_theta = raw_theta.get_data()
    all_alpha = raw_alpha.get_data()
    all_beta = raw_beta.get_data()

    # A single run is the special case "offset 0, bounded by its length".
    return _extract_run_features_from_arrays(
        all_raw, all_theta, all_alpha, all_beta,
        ch_idx_clean, ch_idx_theta, ch_idx_alpha, ch_idx_beta,
        annotations, task_type, 0, all_raw.shape[1]
    )
def _extract_run_features_from_arrays(all_raw, all_theta, all_alpha, all_beta,
                                      ch_idx_clean, ch_idx_theta, ch_idx_alpha,
                                      ch_idx_beta,
                                      annotations, task_type, sample_offset,
                                      max_samples):
    """
    Extract Hjorth features for one run out of pre-fetched numpy arrays.

    Working on arrays avoids calling get_data() once per epoch.

    Parameters
    ----------
    all_raw, all_theta, all_alpha, all_beta : np.ndarray
        2-D arrays indexed as ``[channel, sample]``.
    ch_idx_clean, ch_idx_theta, ch_idx_alpha, ch_idx_beta : list of int
        Row indices of the feature channels in the respective array.
    annotations : table with ``iterrows()`` yielding rows that have
        'label', 'start_row' and 'end_row' entries.
    task_type : str
        Label assigned to every extracted epoch.
    sample_offset : int
        Added to each annotation's start/end row to locate it in the arrays.
    max_samples : int
        Exclusive upper bound for epoch end indices.

    Returns
    -------
    features_list : list of np.ndarray
    labels_list : list of str
    """
    features_list, labels_list = [], []

    for _, row in annotations.iterrows():
        if row['label'] in ACTIVE_EVENT_LABELS:
            first = sample_offset + int(row['start_row'])
            last = min(sample_offset + int(row['end_row']), max_samples)
            # Drop epochs that are empty after clamping.
            if first < last:
                window = slice(first, last)
                epoch_features = extract_epoch_features(
                    all_raw[ch_idx_clean, window],
                    all_theta[ch_idx_theta, window],
                    all_alpha[ch_idx_alpha, window],
                    all_beta[ch_idx_beta, window],
                )
                features_list.append(epoch_features)
                labels_list.append(task_type)

    return features_list, labels_list
# =============================================================================
# Classification (Section 3.6)
# =============================================================================
def _accuracy_and_recalls(y_true, y_pred):
    """Return (accuracy, ME recall, MI recall) for one set of predictions."""
    acc = accuracy_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=['ME', 'MI'])
    # Rows are true classes; an absent class yields recall 0.
    me_recall = cm[0, 0] / cm[0].sum() if cm[0].sum() > 0 else 0
    mi_recall = cm[1, 1] / cm[1].sum() if cm[1].sum() > 0 else 0
    return acc, me_recall, mi_recall


def classify_with_averaging(X_train, y_train, X_test, y_test,
                            X_val=None, y_val=None, verbose=True):
    """
    Train one SVM per entry of RANDOM_SEEDS and collect per-run metrics.

    Paper: "results were averaged by running the model five times"
    Uses MinMaxScaler [0,1] per Equation 8, fitted on training data only.

    Predictions for each split are computed once per run and reused for
    both the accuracy and the per-class recalls.

    NOTE(review): per the scikit-learn docs, ``SVC(random_state=...)`` only
    affects probability estimation and is ignored when ``probability`` is
    False (the default used here), so every run is expected to give the same
    numbers and a standard deviation of zero -- confirm whether the paper's
    run-to-run variation came from something else (e.g. data shuffling).

    Parameters
    ----------
    X_train, y_train : array-like
        Training features and labels ('ME' / 'MI').
    X_test, y_test : array-like
        Test features and labels.
    X_val, y_val : array-like, optional
        Validation split; when omitted the validation metrics are recorded
        as 0 for every run.
    verbose : bool
        Print a per-run table.

    Returns
    -------
    results : dict with keys like 'test_acc', 'val_acc', etc.
        Each value is a list with one entry per seed in RANDOM_SEEDS.
    """
    # Normalize (Equation 8); the scaler sees training data only.
    scaler = MinMaxScaler(feature_range=(0, 1))
    X_train_s = scaler.fit_transform(X_train)
    X_test_s = scaler.transform(X_test)
    X_val_s = scaler.transform(X_val) if X_val is not None else None

    results = {
        'train_acc': [], 'test_acc': [], 'val_acc': [],
        'test_me_recall': [], 'test_mi_recall': [],
        'val_me_recall': [], 'val_mi_recall': [],
    }

    if verbose:
        print(f"\n Training SVM {N_RUNS} times (C={SVM_C}, gamma={SVM_GAMMA}, "
              f"kernel={SVM_KERNEL}, deflation ICA)")
        print(f" {'Run':>5} | {'Seed':>6} | {'Train':>7} | {'Test':>7} | "
              f"{'Val':>7} | {'Test ME':>8} | {'Test MI':>8}")
        print(f" {'-'*65}")

    for run_i, seed in enumerate(RANDOM_SEEDS):
        svm = SVC(kernel=SVM_KERNEL, C=SVM_C, gamma=SVM_GAMMA,
                  random_state=seed)
        svm.fit(X_train_s, y_train)

        train_acc = accuracy_score(y_train, svm.predict(X_train_s))
        test_acc, test_me, test_mi = _accuracy_and_recalls(
            y_test, svm.predict(X_test_s))

        if X_val_s is not None:
            val_acc, val_me, val_mi = _accuracy_and_recalls(
                y_val, svm.predict(X_val_s))
        else:
            # No validation split: keep list lengths aligned with zeros.
            val_acc = val_me = val_mi = 0

        results['train_acc'].append(train_acc)
        results['test_acc'].append(test_acc)
        results['test_me_recall'].append(test_me)
        results['test_mi_recall'].append(test_mi)
        results['val_acc'].append(val_acc)
        results['val_me_recall'].append(val_me)
        results['val_mi_recall'].append(val_mi)

        if verbose:
            print(f" {run_i+1:>5} | {seed:>6} | {train_acc:>7.2%} | "
                  f"{test_acc:>7.2%} | {val_acc:>7.2%} | "
                  f"{test_me:>8.2%} | {test_mi:>8.2%}")

    return results
def print_results_summary(results):
    """
    Print a mean ± std table of the collected metrics (Table 6/7 format).

    Parameters
    ----------
    results : dict
        Maps each metric key ('train_acc', 'test_acc', 'val_acc',
        'test_me_recall', 'test_mi_recall', 'val_me_recall',
        'val_mi_recall') to a sequence of per-run values.

    The paper's Method 2 reference numbers are printed afterwards for
    comparison. Nothing is returned.
    """
    # (results key, row title) in display order.
    rows = (
        ('train_acc', 'Training Accuracy'),
        ('test_acc', 'Testing Accuracy'),
        ('val_acc', 'Validation Accuracy'),
        ('test_me_recall', 'Test ME Recall (TP rate)'),
        ('test_mi_recall', 'Test MI Recall (TN rate)'),
        ('val_me_recall', 'Val ME Recall (TP rate)'),
        ('val_mi_recall', 'Val MI Recall (TN rate)'),
    )

    print(f"\n {'Metric':<28s} {'Mean ± Std':>14s}")
    print(f" {'-'*44}")
    for key, title in rows:
        series = np.array(results[key])
        average = series.mean()
        spread = series.std()
        print(f" {title:<28s} {average:>6.2%} ± {spread:.2%}")

    print(f"\n Paper's Method 2 targets:")
    print(f" Overall accuracy: 68.8 ± 0.71%")
    print(f" ME recall: 68.17%")
    print(f" MI recall: 68.41%")