{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reproducing: Mwata-Velu et al. (2023)\n",
"\n",
"**Paper**: \"EEG-BCI Features Discrimination between Executed and Imagined Movements Based on FastICA, Hjorth Parameters, and SVM\" \n",
"**Journal**: Mathematics 2023, 11, 4409 \n",
"**DOI**: https://doi.org/10.3390/math11214409\n",
"\n",
"## ICA Strategy\n",
"\n",
"The paper does not specify whether ICA is fitted per-run, per-subject, or globally.\n",
"This notebook supports all three strategies via `config.ICA_STRATEGY`.\n",
"Change it in `config.py` before running.\n",
"\n",
"## Key Implementation Decisions\n",
"\n",
"| Decision | Paper says | We do | Rationale |\n",
"|----------|-----------|-------|----------|\n",
"| Pipeline order | Figure 1: filter→ICA; Algorithm 1: ICA with internal sub-band eval | ICA then sub-band eval | Energy criterion is meaningless on pre-filtered data |\n",
"| ICA algorithm | Gram-Schmidt (Algorithm 1, Step 3) | `algorithm='deflation'` | Deflation uses Gram-Schmidt |\n",
"| Energy criterion | ∀χ ∈ {α, β, **γ**} | ∀χ ∈ {**θ**, α, β} | γ never defined; likely typo for θ |\n",
"| ICA scope | Not specified | Configurable | Reproducibility variable |\n",
"| Classification Method | Methods 1 and 2 | Method 2 only (cross-subject) | Method 1 split is contradictory |"
]
},
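{
"cell_type": "markdown",
"metadata": {},
"source": [
"All three strategies are selected through `config.ICA_STRATEGY`. As an illustration only (the real `config.py` is the source of truth and is not reproduced here), the switch might look like:\n",
"\n",
"```python\n",
"# config.py (illustrative sketch, not the actual file)\n",
"ICA_STRATEGY = 'per_subject'  # one of: 'per_run', 'per_subject', 'global'\n",
"```"
]
},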
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import mne\n",
"import warnings\n",
"warnings.filterwarnings('ignore', category=RuntimeWarning) # MNE verbosity\n",
"mne.set_log_level('WARNING')\n",
"\n",
"from config import (\n",
"    ICA_STRATEGY, DATA_DIR, TARGET_CHANNELS,\n",
"    TRAIN_SUBJECTS, TEST_SUBJECTS, VAL_SUBJECTS, TARGET_RUNS,\n",
")\n",
"from pipeline import (\n",
"    process_subject_per_run, process_subject_per_subject,\n",
"    process_subject_global, fit_global_ica,\n",
"    classify_with_averaging, print_results_summary,\n",
")\n",
"\n",
"print(f\"Data directory: {DATA_DIR}\")\n",
"print(f\"ICA strategy: {ICA_STRATEGY}\")\n",
"print(f\"Target channels: {TARGET_CHANNELS}\")\n",
"print(f\"Subjects: {len(TRAIN_SUBJECTS)} train, {len(TEST_SUBJECTS)} test, {len(VAL_SUBJECTS)} val\")"
]
},
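{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check on the feature definitions: the classifier input consists of the three Hjorth parameters (Activity, Mobility, Complexity). The cell below computes them directly on a synthetic signal, independently of `pipeline.py`; using `np.diff` as the discrete derivative is our assumption, since the paper does not state its derivative estimator."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def hjorth_parameters(x):\n",
"    \"\"\"Activity, Mobility, Complexity of a 1-D signal (standalone sketch).\"\"\"\n",
"    dx = np.diff(x)\n",
"    ddx = np.diff(dx)\n",
"    activity = np.var(x)\n",
"    mobility = np.sqrt(np.var(dx) / activity)\n",
"    complexity = np.sqrt(np.var(ddx) / np.var(dx)) / mobility\n",
"    return activity, mobility, complexity\n",
"\n",
"rng = np.random.default_rng(0)\n",
"print(hjorth_parameters(rng.standard_normal(1000)))  # toy segment, not EEG"
]
},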
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Process All Subjects (Method 2: Cross-Subject Split)\n",
"\n",
"Table 4 from the paper:\n",
"- Training: Subjects 1-83 (80%)\n",
"- Testing: Subjects 84-93 (10%)\n",
"- Validation: Subjects 94-103 (10%)"
]
},
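{
"cell_type": "markdown",
"metadata": {},
"source": [
"The subject lists themselves come from `config.py`; for reference, the Table 4 split can be written out directly (illustrative sketch, assuming 1-based subject numbering as in the paper):\n",
"\n",
"```python\n",
"TRAIN_SUBJECTS = list(range(1, 84))    # S001-S083, 80%\n",
"TEST_SUBJECTS  = list(range(84, 94))   # S084-S093, 10%\n",
"VAL_SUBJECTS   = list(range(94, 104))  # S094-S103, 10%\n",
"```"
]
},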
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# If using global ICA, fit it first on training subjects\n",
"global_ica_results = None\n",
"if ICA_STRATEGY == 'global':\n",
"    print(\"Fitting global ICA on all training subjects...\")\n",
"    global_ica_results = fit_global_ica(TRAIN_SUBJECTS)\n",
"    print(\"Global ICA fitting complete.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def process_subject_group(subject_list, group_name):\n",
"    \"\"\"Process a group of subjects and collect features/labels.\"\"\"\n",
"    all_features = []\n",
"    all_labels = []\n",
"    total_kept = 0\n",
"    n_processed = 0\n",
"\n",
"    print(f\"\\n{'='*60}\")\n",
"    print(f\"PROCESSING {group_name} ({len(subject_list)} subjects)\")\n",
"    print(f\"{'='*60}\")\n",
"\n",
"    for subject_num in subject_list:\n",
"        try:\n",
"            if ICA_STRATEGY == 'global':\n",
"                feats, labels, runs, n_kept = process_subject_global(\n",
"                    subject_num, global_ica_results, verbose=True)\n",
"            elif ICA_STRATEGY == 'per_subject':\n",
"                feats, labels, runs, n_kept = process_subject_per_subject(\n",
"                    subject_num, verbose=True)\n",
"            elif ICA_STRATEGY == 'per_run':\n",
"                feats, labels, runs, n_kept = process_subject_per_run(\n",
"                    subject_num, verbose=True)\n",
"            else:\n",
"                raise ValueError(f\"Unknown ICA_STRATEGY: {ICA_STRATEGY!r}\")\n",
"\n",
"            if len(feats) > 0:\n",
"                combined_feats = np.vstack(feats)\n",
"                combined_labels = np.concatenate(labels)\n",
"                all_features.append(combined_feats)\n",
"                all_labels.append(combined_labels)\n",
"                total_kept += n_kept\n",
"                n_processed += 1\n",
"\n",
"                n_me = np.sum(combined_labels == 'ME')\n",
"                n_mi = np.sum(combined_labels == 'MI')\n",
"                print(f\"  S{subject_num:03d}: {len(combined_labels)} epochs \"\n",
"                      f\"(ME={n_me}, MI={n_mi}), {n_kept} ICA components kept\")\n",
"        except Exception as e:\n",
"            print(f\"  S{subject_num:03d}: ERROR - {e}\")\n",
"            continue\n",
"\n",
"    if len(all_features) > 0:\n",
"        X = np.vstack(all_features)\n",
"        y = np.concatenate(all_labels)\n",
"    else:\n",
"        X = np.array([])\n",
"        y = np.array([])\n",
"\n",
"    print(f\"\\n  {group_name} total: {len(y)} samples \"\n",
"          f\"(ME={np.sum(y=='ME')}, MI={np.sum(y=='MI')})\")\n",
"    if n_processed > 0:\n",
"        print(f\"  Avg ICA components kept: {total_kept / n_processed:.1f}\")\n",
"\n",
"    return X, y"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X_train, y_train = process_subject_group(TRAIN_SUBJECTS, \"TRAINING\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X_test, y_test = process_subject_group(TEST_SUBJECTS, \"TESTING\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X_val, y_val = process_subject_group(VAL_SUBJECTS, \"VALIDATION\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Summary\n",
"print(f\"\\n{'='*60}\")\n",
"print(f\"DATASET SUMMARY\")\n",
"print(f\"{'='*60}\")\n",
"print(f\"  Training:   {X_train.shape} — ME={np.sum(y_train=='ME')}, MI={np.sum(y_train=='MI')}\")\n",
"print(f\"  Testing:    {X_test.shape} — ME={np.sum(y_test=='ME')}, MI={np.sum(y_test=='MI')}\")\n",
"print(f\"  Validation: {X_val.shape} — ME={np.sum(y_val=='ME')}, MI={np.sum(y_val=='MI')}\")\n",
"print(f\"\\n  Paper expects: ~6972 train, ~840 test, ~840 val\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Classification with 5-Run Averaging"
]
},
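{
"cell_type": "markdown",
"metadata": {},
"source": [
"`classify_with_averaging` lives in `pipeline.py` and is the authoritative implementation. One plausible reading of the 5-run averaging step -- averaging feature vectors over non-overlapping groups of five same-class epochs before the SVM -- can be sketched as follows (the exact grouping rule is our assumption; the paper does not fully specify it):\n",
"\n",
"```python\n",
"def average_in_groups(X, y, group_size=5):\n",
"    \"\"\"Average feature rows in non-overlapping same-class groups (sketch).\"\"\"\n",
"    Xg, yg = [], []\n",
"    for label in np.unique(y):\n",
"        Xc = X[y == label]\n",
"        n = (len(Xc) // group_size) * group_size  # drop the remainder\n",
"        Xg.append(Xc[:n].reshape(-1, group_size, X.shape[1]).mean(axis=1))\n",
"        yg.extend([label] * (n // group_size))\n",
"    return np.vstack(Xg), np.array(yg)\n",
"```"
]
},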
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"results = classify_with_averaging(\n",
" X_train, y_train,\n",
" X_test, y_test,\n",
" X_val, y_val,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_results_summary(results)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "THESIS",
"language": "python",
"name": "thesis"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.23"
}
},
"nbformat": 4,
"nbformat_minor": 4
}