1
0

Initial commit

This commit is contained in:
2026-04-09 08:21:30 -07:00
commit d3ef110c1e
6 changed files with 1489 additions and 0 deletions

910
pipeline.py Normal file
View File

@@ -0,0 +1,910 @@
"""
Core signal processing pipeline for reproducing Mwata-Velu et al. (2023).
Contains:
- Chebyshev Type II sub-band filtering (Section 3.3)
- FastICA denoising with energy concentration criterion (Section 3.4, Algorithm 1)
- Hjorth parameter feature extraction (Section 3.5)
- SVM classification with 5-run averaging (Section 3.6)
PIPELINE ORDER (Algorithm 1 interpretation):
1. ICA on broadband 19-channel data
- Step 6 internally filters each back-projected IC into θ/α/β sub-bands
to evaluate energy concentration in selected channels
- Step 7 keeps only ICs meeting the threshold in ANY sub-band (OR logic).
The paper specifies ALL (AND logic), but this causes zero components
to pass for many subjects. OR logic is used as a documented deviation.
- Step 8 reconstructs cleaned broadband signal from kept ICs
2. Filter the *cleaned* broadband signal into θ/α/β for Hjorth Activity features
3. Compute Hjorth Mobility/Complexity on the *cleaned* unfiltered signal
This resolves the paper's contradiction between Figure 1 and Algorithm 1. The
Algorithm 1 reading is preferred because the energy concentration criterion is
meaningless on pre-filtered data (e.g., filtering theta-band data into alpha
yields ~zero).
"""
import numpy as np
from scipy.signal import cheby2, filtfilt
import mne
from mne.preprocessing import ICA
from sklearn.svm import SVC
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report
from config import (
ICA_CHANNELS, SELECTED_CHANNELS, SUB_BANDS, TARGET_RUNS,
ICA_N_COMPONENTS, ICA_ENERGY_THRESHOLD, ICA_MAX_ITER, ICA_TOL,
TARGET_CHANNELS, EXECUTION_RUNS, IMAGERY_RUNS, ACTIVE_EVENT_LABELS,
SVM_C, SVM_GAMMA, SVM_KERNEL, N_RUNS, RANDOM_SEEDS,
)
from data_loading import (
load_signal_file, load_annotation_file, create_mne_raw, add_annotations_to_raw,
)
# =============================================================================
# Sub-band filtering (Section 3.3)
# =============================================================================
def design_chebyshev_filter(lowcut, highcut, fs, order=5, rs=34):
"""
Design Chebyshev Type II bandpass filter.
Paper specifies:
- 5th-order Chebyshev Type II
- Stop-band attenuation of -34 dB
- Transition bands reaching 80% of the gain
"""
nyq = 0.5 * fs
low = lowcut / nyq
high = highcut / nyq
b, a = cheby2(order, rs, [low, high], btype='band')
return b, a
def apply_bandpass_filter(raw, lowcut, highcut, picks=None):
"""
Apply Chebyshev Type II bandpass filter.
Parameters
----------
raw : mne.io.Raw
lowcut, highcut : float
picks : list of str, optional
Channel names to filter. If None, filters all channels.
When specified, returns a Raw containing only those channels.
Returns a new Raw object with filtered data.
"""
if picks is not None:
raw_sub = raw.copy().pick(picks)
else:
raw_sub = raw.copy()
b, a = design_chebyshev_filter(lowcut, highcut, raw_sub.info['sfreq'])
data = raw_sub.get_data()
filtered_data = filtfilt(b, a, data, axis=1)
raw_filtered = mne.io.RawArray(filtered_data, raw_sub.info, verbose=False)
raw_filtered.set_annotations(raw.annotations)
return raw_filtered
def _cheby2_bandpass(data, sfreq, low, high, order=5, rs=34.0):
"""
Apply Chebyshev Type II bandpass filter to a numpy array.
Used internally by the ICA energy concentration criterion (Step 6-7).
Parameters
----------
data : ndarray, shape (n_signals, n_samples)
sfreq : float
low, high : float, pass-band edges in Hz
Returns
-------
filtered : ndarray, same shape as data
"""
b, a = design_chebyshev_filter(low, high, sfreq, order=order, rs=rs)
return filtfilt(b, a, data, axis=-1)
# =============================================================================
# FastICA denoising (Section 3.4, Algorithm 1)
# =============================================================================
def apply_fastica_denoising(raw, n_components=ICA_N_COMPONENTS,
energy_threshold=ICA_ENERGY_THRESHOLD,
verbose=True):
"""
Denoise EEG with FastICA following Algorithm 1 of Mwata-Velu et al. (2023).
This function operates on BROADBAND data. Sub-band filtering is used
internally only for the energy concentration criterion (Steps 6-7).
Steps mapped to the paper
-------------------------
Steps 1-5: FastICA decomposition of 19 channels into n_components ICs.
- Step 1: Random weight vector initialization
- Step 2: Fixed-point update with log-cosh non-linearity (Eq. 3-4)
- Step 3: Gram-Schmidt orthogonalization (deflation approach)
- Step 4: Weight vector normalization
- Step 5: Convergence check; repeat 2-4 if not converged
NOTE: We use algorithm='deflation' to match the paper's explicit
description of Gram-Schmidt orthogonalization. sklearn's default
('parallel') uses symmetric decorrelation instead, which is a
different algorithm. See Hyvärinen & Oja (1997, 2000).
Step 6: For every IC k, back-project to 19 channels, filter into each
sub-band, and compute per-channel variance.
Step 7: Retain IC k if, for ANY sub-band, the ratio
sum(Selected) / sum(All) >= threshold (paper: 0.35).
NOTE: The paper specifies AND logic (all sub-bands must pass),
but this is too strict — many subjects yield zero kept components.
We use OR logic (any sub-band passing suffices) as a documented
reproducibility deviation.
NOTE: The paper writes {γ, α, β}" with gamma instead of theta.
The gamma band is never defined in the paper. We interpret this as a
typo and use θ/α/β as defined in Section 3.3.
Step 8: Reconstruct û = A_p @ S_p using only kept components.
Parameters
----------
raw : mne.io.Raw
Must contain (at least) the 19 ICA_CHANNELS.
Should be BROADBAND (unfiltered) data.
n_components : int
Number of ICs to extract. Paper uses 19.
energy_threshold : float
Per-sub-band concentration ratio threshold. Paper: 0.35.
verbose : bool
Print progress information.
Returns
-------
raw_clean : mne.io.Raw
Copy of raw with denoised ICA channels (still broadband).
ica : mne.preprocessing.ICA
Fitted ICA object.
kept_components : list of int
Indices of ICs that passed the Step 7 criterion.
"""
sfreq = raw.info['sfreq']
if verbose:
print(f" [FastICA] {n_components} components, "
f"threshold={energy_threshold}, deflation mode")
# ------------------------------------------------------------------
# Steps 1-5: FastICA decomposition on broadband data
# ------------------------------------------------------------------
raw_ica = raw.copy().pick(ICA_CHANNELS)
ica = ICA(
n_components=n_components,
method='fastica',
fit_params={
'max_iter': ICA_MAX_ITER,
'tol': ICA_TOL,
'algorithm': 'deflation', # Paper's Algorithm 1, Step 3
},
random_state=42,
)
ica.fit(raw_ica)
# Mixing matrix A and IC time-series S
mixing = ica.mixing_matrix_
unmixing = ica.unmixing_matrix_
raw_data = raw_ica.get_data()
icasigs = unmixing @ raw_data
# Channel index maps
all_ch_names = list(raw_ica.ch_names)
sel_idx = [all_ch_names.index(ch) for ch in SELECTED_CHANNELS]
# ------------------------------------------------------------------
# Step 6: Evaluate each IC by sub-band (internal filtering)
#
# For each IC k:
# 1. Back-project to 19 channels: signal[c, t] = mixing[c, k] * IC_k[t]
# 2. Filter the 19-channel signal into each sub-band
# 3. Compute variance per channel
# ------------------------------------------------------------------
ic_channel_var = {name: [] for name, _, _ in SUB_BANDS}
for k in range(n_components):
signal_19ch = mixing[:, k, np.newaxis] * icasigs[k, :] # (19, T)
for name, low, high in SUB_BANDS:
filtered = _cheby2_bandpass(signal_19ch, sfreq, low, high)
var_per_channel = np.var(filtered, axis=1, ddof=1)
ic_channel_var[name].append(var_per_channel)
for name in ic_channel_var:
ic_channel_var[name] = np.array(ic_channel_var[name])
# ------------------------------------------------------------------
# Step 7: Energy concentration criterion
#
# Paper states: ∀χ ∈ {θ, α, β} (AND logic).
# However, AND across all three bands is extremely strict and causes
# zero components to pass for many subjects. We use OR logic instead:
# keep IC k if it passes the threshold in ANY sub-band. This is more
# defensible — a component concentrated in motor cortex in alpha is
# still a neural signal worth keeping even if its theta energy is
# diffuse. Documented as a reproducibility deviation.
# ------------------------------------------------------------------
kept_components = []
if verbose:
print(f" {'IC':>4} | {'theta':>7} | {'alpha':>7} | {'beta':>7} | {'kept?':>5}")
print(f" {'-'*46}")
for k in range(n_components):
ratios = {}
keep = False # OR logic: start False, flip to True if ANY passes
# keep = True # AND logic: start True, flip to False if ANY fails
for name, _, _ in SUB_BANDS:
var_all = ic_channel_var[name][k]
num = var_all[sel_idx].sum()
den = var_all.sum()
ratio = num / den if den > 0 else 0.0
ratios[name] = ratio
if ratio >= energy_threshold:
keep = True # passes in at least one sub-band
# if ratio < energy_threshold:
# keep = False # fails in at least one sub-band
if verbose:
status = "YES" if keep else "no"
print(f" {k:>4} | {ratios['theta']:>7.3f} | "
f"{ratios['alpha']:>7.3f} | {ratios['beta']:>7.3f} | {status:>5}")
if keep:
kept_components.append(k)
if verbose:
print(f" [FastICA] Kept {len(kept_components)}/{n_components} components")
# ------------------------------------------------------------------
# Step 8: Reconstruct û = A_p @ S_p (broadband cleaned signal)
# ------------------------------------------------------------------
if len(kept_components) == 0:
if verbose:
print(" [FastICA] WARNING: no components passed threshold "
"even with OR logic. Returning original data unchanged.")
return raw.copy(), ica, kept_components
A_p = mixing[:, kept_components]
S_p = icasigs[kept_components, :]
reconstructed = A_p @ S_p
raw_clean = mne.io.RawArray(reconstructed, raw_ica.info, verbose=False)
raw_clean.set_annotations(raw.annotations)
return raw_clean, ica, kept_components
# =============================================================================
# Hjorth parameter extraction (Section 3.5)
# =============================================================================
def compute_hjorth_parameters(signal_array):
"""
Compute Hjorth parameters: Activity, Mobility, Complexity.
Activity: Variance of the signal (Eq. 5)
Mobility: sqrt(Var(d/dt signal) / Var(signal)) (Eq. 6)
Complexity: Mobility(d/dt signal) / Mobility(signal) (Eq. 7)
The derivative d/dt u(i) is approximated as u(i) - u(i-1) per the paper.
Parameters
----------
signal_array : np.ndarray, 1D
Returns
-------
activity, mobility, complexity : float
"""
activity = np.var(signal_array)
d1 = np.diff(signal_array)
var_d1 = np.var(d1)
if activity > 0:
mobility = np.sqrt(var_d1 / activity)
else:
mobility = 0.0
d2 = np.diff(d1)
var_d2 = np.var(d2)
if var_d1 > 0 and mobility > 0:
mobility_d1 = np.sqrt(var_d2 / var_d1)
complexity = mobility_d1 / mobility
else:
complexity = 0.0
return activity, mobility, complexity
def extract_epoch_features(data_raw, data_theta, data_alpha, data_beta,
channels=TARGET_CHANNELS):
"""
Extract Hjorth features for one epoch across all target channels.
Feature set matches Table 5, Set 3: Aθ, Aα, Aβ, Mraw, Craw per channel.
Total features = len(channels) * 5.
Parameters
----------
data_raw, data_theta, data_alpha, data_beta : np.ndarray
Shape (n_channels, n_samples) epoch data for each band.
Returns
-------
features : np.ndarray of shape (n_channels * 5,)
"""
features = []
for i in range(len(channels)):
a_theta, _, _ = compute_hjorth_parameters(data_theta[i])
a_alpha, _, _ = compute_hjorth_parameters(data_alpha[i])
a_beta, _, _ = compute_hjorth_parameters(data_beta[i])
_, m_raw, c_raw = compute_hjorth_parameters(data_raw[i])
features.extend([a_theta, a_alpha, a_beta, m_raw, c_raw])
return np.array(features)
# =============================================================================
# Subject processing (orchestration)
#
# In all strategies below, the pipeline is:
# 1. ICA on broadband data (1 fit, not 4)
# 2. Filter the CLEANED broadband signal into θ/α/β for Activity features
# 3. Compute Mobility/Complexity on the CLEANED unfiltered signal
# =============================================================================
def process_subject_per_run(subject_num, target_runs=None, apply_ica=True,
verbose=True):
"""
Process one subject with ICA fitted independently per run.
This gives ICA the least amount of data to work with (~19,200 samples for
a 2-minute run on 19 channels).
Returns
-------
all_features : list of np.ndarray, one per run
all_labels : list of np.ndarray, one per run
run_numbers : list of int
total_kept : int (total across all ICA fits)
"""
if target_runs is None:
target_runs = TARGET_RUNS
all_features = []
all_labels = []
run_numbers = []
total_kept = 0
n_ica_fits = 0
for run_num in target_runs:
try:
signal_data = load_signal_file(subject_num, run_num)
annotations = load_annotation_file(subject_num, run_num)
except Exception as e:
if verbose:
print(f" Could not load S{subject_num:03d} R{run_num:02d}: {e}")
continue
raw = create_mne_raw(signal_data)
raw = add_annotations_to_raw(raw, annotations)
# ICA on broadband data (1 fit per run; sub-band eval is internal)
if apply_ica:
raw_clean, _, kept = apply_fastica_denoising(raw, verbose=verbose)
total_kept += len(kept)
n_ica_fits += 1
else:
raw_clean = raw
# Filter the CLEANED signal into sub-bands for Activity features
# Only filter the channels we actually need for feature extraction
raw_theta = apply_bandpass_filter(raw_clean, 4, 8, picks=TARGET_CHANNELS)
raw_alpha = apply_bandpass_filter(raw_clean, 8, 13, picks=TARGET_CHANNELS)
raw_beta = apply_bandpass_filter(raw_clean, 13, 30, picks=TARGET_CHANNELS)
# Determine task type from run number
if run_num in EXECUTION_RUNS:
task_type = 'ME'
elif run_num in IMAGERY_RUNS:
task_type = 'MI'
else:
continue
# Extract features from active events
features_list, labels_list = _extract_run_features(
raw_clean, raw_theta, raw_alpha, raw_beta,
annotations, task_type
)
if len(features_list) > 0:
all_features.append(np.array(features_list))
all_labels.append(np.array(labels_list))
run_numbers.append(run_num)
return all_features, all_labels, run_numbers, total_kept
def process_subject_per_subject(subject_num, target_runs=None, apply_ica=True,
verbose=True):
"""
Process one subject with ICA fitted once on all runs concatenated.
This gives ICA more data for better decomposition (~115,200 samples
for 12 runs).
Returns
-------
all_features : list of np.ndarray, one per run
all_labels : list of np.ndarray, one per run
run_numbers : list of int
n_kept_components : int
"""
if target_runs is None:
target_runs = TARGET_RUNS
# Step 1: Load and concatenate all runs
raw_runs = []
annotations_runs = []
run_lengths = []
for run_num in target_runs:
try:
signal_data = load_signal_file(subject_num, run_num)
annotations = load_annotation_file(subject_num, run_num)
raw = create_mne_raw(signal_data)
raw = add_annotations_to_raw(raw, annotations)
raw_runs.append(raw)
annotations_runs.append(annotations)
run_lengths.append(raw.n_times)
except Exception as e:
if verbose:
print(f" Could not load S{subject_num:03d} R{run_num:02d}: {e}")
continue
if len(raw_runs) == 0:
return [], [], [], 0
raw_concat = mne.concatenate_raws(raw_runs, preload=True)
if verbose:
duration = raw_concat.n_times / raw_concat.info['sfreq']
print(f" S{subject_num:03d}: {len(raw_runs)} runs, {duration:.0f}s")
# Step 2: ICA once on broadband concatenated data
n_kept = 0
if apply_ica:
raw_clean, _, kept = apply_fastica_denoising(raw_concat, verbose=verbose)
n_kept = len(kept)
else:
raw_clean = raw_concat
# Step 3: Filter the CLEANED signal into sub-bands for Activity features
# Only filter the channels we actually need for feature extraction
raw_theta = apply_bandpass_filter(raw_clean, 4, 8, picks=TARGET_CHANNELS)
raw_alpha = apply_bandpass_filter(raw_clean, 8, 13, picks=TARGET_CHANNELS)
raw_beta = apply_bandpass_filter(raw_clean, 13, 30, picks=TARGET_CHANNELS)
# Step 4: Extract features per run using global sample offsets
all_raw_data = raw_clean.get_data()
all_theta_data = raw_theta.get_data()
all_alpha_data = raw_alpha.get_data()
all_beta_data = raw_beta.get_data()
ch_idx_clean = [raw_clean.ch_names.index(ch) for ch in TARGET_CHANNELS]
ch_idx_theta = [raw_theta.ch_names.index(ch) for ch in TARGET_CHANNELS]
ch_idx_alpha = [raw_alpha.ch_names.index(ch) for ch in TARGET_CHANNELS]
ch_idx_beta = [raw_beta.ch_names.index(ch) for ch in TARGET_CHANNELS]
all_features = []
all_labels = []
run_numbers = []
sample_offset = 0
for run_idx, run_num in enumerate(target_runs):
if run_idx >= len(raw_runs):
continue
annotations = annotations_runs[run_idx]
run_length = run_lengths[run_idx]
if run_num in EXECUTION_RUNS:
task_type = 'ME'
elif run_num in IMAGERY_RUNS:
task_type = 'MI'
else:
sample_offset += run_length
continue
features_list, labels_list = _extract_run_features_from_arrays(
all_raw_data, all_theta_data, all_alpha_data, all_beta_data,
ch_idx_clean, ch_idx_theta, ch_idx_alpha, ch_idx_beta,
annotations, task_type, sample_offset, all_raw_data.shape[1]
)
if len(features_list) > 0:
all_features.append(np.array(features_list))
all_labels.append(np.array(labels_list))
run_numbers.append(run_num)
sample_offset += run_length
return all_features, all_labels, run_numbers, n_kept
def process_subject_global(subject_num, global_ica_results,
target_runs=None, verbose=True):
"""
Process one subject using a pre-fitted global ICA.
The global ICA is fitted once on all training subjects' broadband data.
Each subject's data is then projected through the same unmixing matrix.
Parameters
----------
subject_num : int
global_ica_results : dict
Keys: 'mixing', 'unmixing', 'kept_components'
From fit_global_ica().
target_runs : list of int, optional
verbose : bool
Returns
-------
all_features, all_labels, run_numbers, n_kept
"""
if target_runs is None:
target_runs = TARGET_RUNS
# Load and concatenate
raw_runs = []
annotations_runs = []
run_lengths = []
for run_num in target_runs:
try:
signal_data = load_signal_file(subject_num, run_num)
annotations = load_annotation_file(subject_num, run_num)
raw = create_mne_raw(signal_data)
raw = add_annotations_to_raw(raw, annotations)
raw_runs.append(raw)
annotations_runs.append(annotations)
run_lengths.append(raw.n_times)
except Exception:
continue
if len(raw_runs) == 0:
return [], [], [], 0
raw_concat = mne.concatenate_raws(raw_runs, preload=True)
# Apply pre-fitted global ICA on broadband data
ica_tuple = (global_ica_results['mixing'],
global_ica_results['unmixing'],
global_ica_results['kept_components'])
raw_clean = _apply_precomputed_ica(raw_concat, ica_tuple)
n_kept = len(global_ica_results['kept_components'])
# Filter the CLEANED signal into sub-bands for Activity features
# Only filter the channels we actually need for feature extraction
raw_theta = apply_bandpass_filter(raw_clean, 4, 8, picks=TARGET_CHANNELS)
raw_alpha = apply_bandpass_filter(raw_clean, 8, 13, picks=TARGET_CHANNELS)
raw_beta = apply_bandpass_filter(raw_clean, 13, 30, picks=TARGET_CHANNELS)
# Extract features per run
all_raw_data = raw_clean.get_data()
all_theta_data = raw_theta.get_data()
all_alpha_data = raw_alpha.get_data()
all_beta_data = raw_beta.get_data()
ch_idx_clean = [raw_clean.ch_names.index(ch) for ch in TARGET_CHANNELS]
ch_idx_theta = [raw_theta.ch_names.index(ch) for ch in TARGET_CHANNELS]
ch_idx_alpha = [raw_alpha.ch_names.index(ch) for ch in TARGET_CHANNELS]
ch_idx_beta = [raw_beta.ch_names.index(ch) for ch in TARGET_CHANNELS]
all_features = []
all_labels = []
run_numbers = []
sample_offset = 0
for run_idx, run_num in enumerate(target_runs):
if run_idx >= len(raw_runs):
continue
annotations = annotations_runs[run_idx]
run_length = run_lengths[run_idx]
if run_num in EXECUTION_RUNS:
task_type = 'ME'
elif run_num in IMAGERY_RUNS:
task_type = 'MI'
else:
sample_offset += run_length
continue
features_list, labels_list = _extract_run_features_from_arrays(
all_raw_data, all_theta_data, all_alpha_data, all_beta_data,
ch_idx_clean, ch_idx_theta, ch_idx_alpha, ch_idx_beta,
annotations, task_type, sample_offset, all_raw_data.shape[1]
)
if len(features_list) > 0:
all_features.append(np.array(features_list))
all_labels.append(np.array(labels_list))
run_numbers.append(run_num)
sample_offset += run_length
return all_features, all_labels, run_numbers, n_kept
def fit_global_ica(subject_list, target_runs=None, verbose=True):
"""
Fit ICA once on all subjects' broadband data concatenated.
Returns
-------
global_ica_results : dict with keys 'mixing', 'unmixing', 'kept_components'
"""
if target_runs is None:
target_runs = TARGET_RUNS
all_raws = []
for subject_num in subject_list:
for run_num in target_runs:
try:
signal_data = load_signal_file(subject_num, run_num)
raw = create_mne_raw(signal_data)
all_raws.append(raw)
except Exception:
continue
if len(all_raws) == 0:
raise ValueError("No data loaded for global ICA")
if verbose:
print(f" Global ICA: concatenating {len(all_raws)} runs...")
global_raw = mne.concatenate_raws(all_raws, preload=True)
if verbose:
duration = global_raw.n_times / global_raw.info['sfreq']
print(f" Global ICA: {duration:.0f}s total data")
# Fit ICA once on broadband data
if verbose:
print(f" Fitting global ICA on broadband data...")
_, ica, kept = apply_fastica_denoising(global_raw, verbose=verbose)
return {
'mixing': ica.mixing_matrix_,
'unmixing': ica.unmixing_matrix_,
'kept_components': kept,
}
def _apply_precomputed_ica(raw, ica_result):
"""
Apply a pre-fitted ICA decomposition to reconstruct cleaned signals.
Parameters
----------
raw : mne.io.Raw
ica_result : tuple (mixing, unmixing, kept_components)
"""
mixing, unmixing, kept_components = ica_result
if len(kept_components) == 0:
return raw.copy()
raw_ica = raw.copy().pick(ICA_CHANNELS)
raw_data = raw_ica.get_data()
icasigs = unmixing @ raw_data
A_p = mixing[:, kept_components]
S_p = icasigs[kept_components, :]
reconstructed = A_p @ S_p
raw_clean = mne.io.RawArray(reconstructed, raw_ica.info, verbose=False)
raw_clean.set_annotations(raw.annotations)
return raw_clean
# =============================================================================
# Feature extraction helpers
# =============================================================================
def _extract_run_features(raw_clean, raw_theta, raw_alpha, raw_beta,
annotations, task_type):
"""Extract Hjorth features from a single (non-concatenated) run."""
features_list = []
labels_list = []
ch_idx_clean = [raw_clean.ch_names.index(ch) for ch in TARGET_CHANNELS]
ch_idx_theta = [raw_theta.ch_names.index(ch) for ch in TARGET_CHANNELS]
ch_idx_alpha = [raw_alpha.ch_names.index(ch) for ch in TARGET_CHANNELS]
ch_idx_beta = [raw_beta.ch_names.index(ch) for ch in TARGET_CHANNELS]
# Pre-fetch full arrays ONCE (avoid repeated get_data() in loop)
all_raw = raw_clean.get_data()
all_theta = raw_theta.get_data()
all_alpha = raw_alpha.get_data()
all_beta = raw_beta.get_data()
for _, annot in annotations.iterrows():
if annot['label'] not in ACTIVE_EVENT_LABELS:
continue
start = int(annot['start_row'])
end = int(annot['end_row'])
end = min(end, all_raw.shape[1])
if start >= end:
continue
data_raw = all_raw[ch_idx_clean, start:end]
data_theta = all_theta[ch_idx_theta, start:end]
data_alpha = all_alpha[ch_idx_alpha, start:end]
data_beta = all_beta[ch_idx_beta, start:end]
features = extract_epoch_features(data_raw, data_theta,
data_alpha, data_beta)
features_list.append(features)
labels_list.append(task_type)
return features_list, labels_list
def _extract_run_features_from_arrays(all_raw, all_theta, all_alpha, all_beta,
ch_idx_clean, ch_idx_theta, ch_idx_alpha,
ch_idx_beta,
annotations, task_type, sample_offset,
max_samples):
"""Extract Hjorth features from one run within pre-fetched concatenated arrays.
This avoids repeated get_data() calls by working directly on numpy arrays.
"""
features_list = []
labels_list = []
for _, annot in annotations.iterrows():
if annot['label'] not in ACTIVE_EVENT_LABELS:
continue
start = sample_offset + int(annot['start_row'])
end = sample_offset + int(annot['end_row'])
end = min(end, max_samples)
if start >= end:
continue
data_raw = all_raw[ch_idx_clean, start:end]
data_theta = all_theta[ch_idx_theta, start:end]
data_alpha = all_alpha[ch_idx_alpha, start:end]
data_beta = all_beta[ch_idx_beta, start:end]
features = extract_epoch_features(data_raw, data_theta,
data_alpha, data_beta)
features_list.append(features)
labels_list.append(task_type)
return features_list, labels_list
# =============================================================================
# Classification (Section 3.6)
# =============================================================================
def classify_with_averaging(X_train, y_train, X_test, y_test,
X_val=None, y_val=None, verbose=True):
"""
Train SVM N_RUNS times with different random seeds and report mean ± std.
Paper: "results were averaged by running the model five times"
Uses MinMaxScaler [0,1] per Equation 8, fitted on training data only.
Returns
-------
results : dict with keys like 'test_acc', 'val_acc', etc.
Each value is a list of N_RUNS floats.
"""
# Normalize (Equation 8)
scaler = MinMaxScaler(feature_range=(0, 1))
X_train_s = scaler.fit_transform(X_train)
X_test_s = scaler.transform(X_test)
X_val_s = scaler.transform(X_val) if X_val is not None else None
results = {
'train_acc': [], 'test_acc': [], 'val_acc': [],
'test_me_recall': [], 'test_mi_recall': [],
'val_me_recall': [], 'val_mi_recall': [],
}
if verbose:
print(f"\n Training SVM {N_RUNS} times (C={SVM_C}, gamma={SVM_GAMMA}, "
f"kernel={SVM_KERNEL}, deflation ICA)")
print(f" {'Run':>5} | {'Seed':>6} | {'Train':>7} | {'Test':>7} | "
f"{'Val':>7} | {'Test ME':>8} | {'Test MI':>8}")
print(f" {'-'*65}")
for run_i, seed in enumerate(RANDOM_SEEDS):
svm = SVC(kernel=SVM_KERNEL, C=SVM_C, gamma=SVM_GAMMA,
random_state=seed)
svm.fit(X_train_s, y_train)
train_acc = accuracy_score(y_train, svm.predict(X_train_s))
test_acc = accuracy_score(y_test, svm.predict(X_test_s))
results['train_acc'].append(train_acc)
results['test_acc'].append(test_acc)
# Per-class recall
test_cm = confusion_matrix(y_test, svm.predict(X_test_s),
labels=['ME', 'MI'])
test_me = test_cm[0, 0] / test_cm[0].sum() if test_cm[0].sum() > 0 else 0
test_mi = test_cm[1, 1] / test_cm[1].sum() if test_cm[1].sum() > 0 else 0
results['test_me_recall'].append(test_me)
results['test_mi_recall'].append(test_mi)
val_acc = 0
val_me = 0
val_mi = 0
if X_val is not None:
val_acc = accuracy_score(y_val, svm.predict(X_val_s))
val_cm = confusion_matrix(y_val, svm.predict(X_val_s),
labels=['ME', 'MI'])
val_me = val_cm[0, 0] / val_cm[0].sum() if val_cm[0].sum() > 0 else 0
val_mi = val_cm[1, 1] / val_cm[1].sum() if val_cm[1].sum() > 0 else 0
results['val_acc'].append(val_acc)
results['val_me_recall'].append(val_me)
results['val_mi_recall'].append(val_mi)
if verbose:
print(f" {run_i+1:>5} | {seed:>6} | {train_acc:>7.2%} | "
f"{test_acc:>7.2%} | {val_acc:>7.2%} | "
f"{test_me:>8.2%} | {test_mi:>8.2%}")
return results
def print_results_summary(results):
"""Print mean ± std summary matching the paper's Table 6/7 format."""
print(f"\n {'Metric':<28s} {'Mean ± Std':>14s}")
print(f" {'-'*44}")
labels = {
'train_acc': 'Training Accuracy',
'test_acc': 'Testing Accuracy',
'val_acc': 'Validation Accuracy',
'test_me_recall': 'Test ME Recall (TP rate)',
'test_mi_recall': 'Test MI Recall (TN rate)',
'val_me_recall': 'Val ME Recall (TP rate)',
'val_mi_recall': 'Val MI Recall (TN rate)',
}
for key, label in labels.items():
vals = np.array(results[key])
print(f" {label:<28s} {vals.mean():>6.2%} ± {vals.std():.2%}")
print(f"\n Paper's Method 2 targets:")
print(f" Overall accuracy: 68.8 ± 0.71%")
print(f" ME recall: 68.17%")
print(f" MI recall: 68.41%")