{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Reproducing: Mwata-Velu et al. (2023)\n",
"\n",
"**Paper**: \"EEG-BCI Features Discrimination between Executed and Imagined Movements Based on FastICA, Hjorth Parameters, and SVM\" \n",
"**Journal**: Mathematics 2023, 11, 4409 \n",
"**DOI**: https://doi.org/10.3390/math11214409\n",
"\n",
"## ICA Strategy\n",
"\n",
"The paper does not specify whether ICA is fitted per-run, per-subject, or globally.\n",
"This notebook supports all three strategies via `config.ICA_STRATEGY`;\n",
"change it in `config.py` before running.\n",
"\n",
"## Key Implementation Decisions\n",
"\n",
"| Decision | Paper says | We do | Rationale |\n",
"|----------|-----------|-------|----------|\n",
"| Pipeline order | Figure 1: filter → ICA; Algorithm 1: ICA with internal sub-band evaluation | ICA, then sub-band evaluation | The energy criterion is meaningless on pre-filtered data |\n",
"| ICA algorithm | Gram-Schmidt (Algorithm 1, Step 3) | `algorithm='deflation'` | FastICA's deflation mode orthogonalizes via Gram-Schmidt |\n",
"| Energy criterion | ∀χ ∈ {α, β, **γ**} | ∀χ ∈ {**θ**, α, β} | γ is never defined; likely a typo for θ |\n",
"| ICA scope | Not specified | Configurable | Treated as a reproducibility variable |\n",
"| Classification method | Methods 1 and 2 | Method 2 only (cross-subject) | The Method 1 split is self-contradictory as described |"
]
},
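{
"cell_type": "markdown",
"metadata": {},
"source": [
"The deflation row above can be illustrated directly with scikit-learn: `FastICA(algorithm='deflation')` extracts one component at a time, orthogonalizing each new unmixing vector against those already found (the Gram-Schmidt step of Algorithm 1). The cell below is a standalone sketch on synthetic data; the mixing matrix and sizes are made up, and the real fit lives in `pipeline.py`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"from sklearn.decomposition import FastICA\n",
"\n",
"# Mix four non-Gaussian sources into four synthetic 'channels'.\n",
"rng = np.random.default_rng(0)\n",
"s = rng.uniform(-1, 1, size=(2000, 4))\n",
"x = s @ rng.standard_normal((4, 4))\n",
"\n",
"# Deflation mode = one-component-at-a-time extraction with\n",
"# Gram-Schmidt orthogonalization against previous components.\n",
"ica = FastICA(n_components=4, algorithm='deflation', random_state=0)\n",
"sources = ica.fit_transform(x)\n",
"print(sources.shape)  # (2000, 4): n_samples x n_components"
]
},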
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 1. Setup"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"import mne\n",
"import warnings\n",
"warnings.filterwarnings('ignore', category=RuntimeWarning)  # MNE verbosity\n",
"mne.set_log_level('WARNING')\n",
"\n",
"from config import (\n",
"    ICA_STRATEGY, DATA_DIR, TARGET_CHANNELS,\n",
"    TRAIN_SUBJECTS, TEST_SUBJECTS, VAL_SUBJECTS, TARGET_RUNS,\n",
")\n",
"from pipeline import (\n",
"    process_subject_per_run, process_subject_per_subject,\n",
"    process_subject_global, fit_global_ica,\n",
"    classify_with_averaging, print_results_summary,\n",
")\n",
"\n",
"print(f\"Data directory: {DATA_DIR}\")\n",
"print(f\"ICA strategy: {ICA_STRATEGY}\")\n",
"print(f\"Target channels: {TARGET_CHANNELS}\")\n",
"print(f\"Subjects: {len(TRAIN_SUBJECTS)} train, {len(TEST_SUBJECTS)} test, {len(VAL_SUBJECTS)} val\")"
]
},
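{
"cell_type": "markdown",
"metadata": {},
"source": [
"The Hjorth features named in the paper's title are computed inside `pipeline.py`. For reference, the three standard parameters reduce to variance ratios of a signal and its discrete derivatives; the cell below is a minimal standalone version, sanity-checked on a made-up sinusoid rather than dataset EEG."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def hjorth_parameters(x):\n",
"    # Standard Hjorth definitions on a 1-D signal.\n",
"    dx = np.diff(x)\n",
"    ddx = np.diff(dx)\n",
"    activity = np.var(x)\n",
"    mobility = np.sqrt(np.var(dx) / np.var(x))\n",
"    complexity = np.sqrt(np.var(ddx) / np.var(dx)) / mobility\n",
"    return activity, mobility, complexity\n",
"\n",
"# Sanity check: a pure 10 Hz sinusoid sampled at 160 Hz has\n",
"# activity = 0.5 (variance of a unit sine) and complexity near 1.\n",
"t = np.arange(160) / 160\n",
"act, mob, comp = hjorth_parameters(np.sin(2 * np.pi * 10 * t))\n",
"print(act, mob, comp)"
]
},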
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Process All Subjects (Method 2: Cross-Subject Split)\n",
"\n",
"Table 4 from the paper:\n",
"- Training: Subjects 1-83 (80%)\n",
"- Testing: Subjects 84-93 (10%)\n",
"- Validation: Subjects 94-103 (10%)"
]
},
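{
"cell_type": "markdown",
"metadata": {},
"source": [
"The split is by subject identity, so no subject contributes epochs to more than one set. As a sketch, the Table 4 ranges come out as plain Python below; the canonical lists are `TRAIN_SUBJECTS`, `TEST_SUBJECTS`, and `VAL_SUBJECTS` in `config.py`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Table 4: 103 subjects, split ~80/10/10 by subject, not by epoch.\n",
"train_subjects = list(range(1, 84))    # Subjects 1-83\n",
"test_subjects = list(range(84, 94))    # Subjects 84-93\n",
"val_subjects = list(range(94, 104))    # Subjects 94-103\n",
"\n",
"print(len(train_subjects), len(test_subjects), len(val_subjects))  # 83 10 10"
]
},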
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# If using global ICA, fit it first on training subjects\n",
"global_ica_results = None\n",
"if ICA_STRATEGY == 'global':\n",
"    print(\"Fitting global ICA on all training subjects...\")\n",
"    global_ica_results = fit_global_ica(TRAIN_SUBJECTS)\n",
"    print(\"Global ICA fitting complete.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def process_subject_group(subject_list, group_name):\n",
"    \"\"\"Process a group of subjects and collect features/labels.\"\"\"\n",
"    if ICA_STRATEGY not in ('global', 'per_subject', 'per_run'):\n",
"        raise ValueError(f\"Unknown ICA_STRATEGY: {ICA_STRATEGY!r}\")\n",
"\n",
"    all_features = []\n",
"    all_labels = []\n",
"    total_kept = 0\n",
"    n_processed = 0\n",
"\n",
"    print(f\"\\n{'='*60}\")\n",
"    print(f\"PROCESSING {group_name} ({len(subject_list)} subjects)\")\n",
"    print(f\"{'='*60}\")\n",
"\n",
"    for subject_num in subject_list:\n",
"        try:\n",
"            if ICA_STRATEGY == 'global':\n",
"                feats, labels, runs, n_kept = process_subject_global(\n",
"                    subject_num, global_ica_results, verbose=True)\n",
"            elif ICA_STRATEGY == 'per_subject':\n",
"                feats, labels, runs, n_kept = process_subject_per_subject(\n",
"                    subject_num, verbose=True)\n",
"            elif ICA_STRATEGY == 'per_run':\n",
"                feats, labels, runs, n_kept = process_subject_per_run(\n",
"                    subject_num, verbose=True)\n",
"\n",
"            if len(feats) > 0:\n",
"                combined_feats = np.vstack(feats)\n",
"                combined_labels = np.concatenate(labels)\n",
"                all_features.append(combined_feats)\n",
"                all_labels.append(combined_labels)\n",
"                total_kept += n_kept\n",
"                n_processed += 1\n",
"\n",
"                n_me = np.sum(combined_labels == 'ME')\n",
"                n_mi = np.sum(combined_labels == 'MI')\n",
"                print(f\"  S{subject_num:03d}: {len(combined_labels)} epochs \"\n",
"                      f\"(ME={n_me}, MI={n_mi}), {n_kept} ICA components kept\")\n",
"        except Exception as e:\n",
"            print(f\"  S{subject_num:03d}: ERROR - {e}\")\n",
"            continue\n",
"\n",
"    if len(all_features) > 0:\n",
"        X = np.vstack(all_features)\n",
"        y = np.concatenate(all_labels)\n",
"    else:\n",
"        X = np.array([])\n",
"        y = np.array([])\n",
"\n",
"    print(f\"\\n  {group_name} total: {len(y)} samples \"\n",
"          f\"(ME={np.sum(y=='ME')}, MI={np.sum(y=='MI')})\")\n",
"    if n_processed > 0:\n",
"        print(f\"  Avg ICA components kept: {total_kept / n_processed:.1f}\")\n",
"\n",
"    return X, y"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X_train, y_train = process_subject_group(TRAIN_SUBJECTS, \"TRAINING\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X_test, y_test = process_subject_group(TEST_SUBJECTS, \"TESTING\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"X_val, y_val = process_subject_group(VAL_SUBJECTS, \"VALIDATION\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Dataset Summary"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print(f\"\\n{'='*60}\")\n",
"print(\"DATASET SUMMARY\")\n",
"print(f\"{'='*60}\")\n",
"print(f\"  Training:   {X_train.shape} — ME={np.sum(y_train=='ME')}, MI={np.sum(y_train=='MI')}\")\n",
"print(f\"  Testing:    {X_test.shape} — ME={np.sum(y_test=='ME')}, MI={np.sum(y_test=='MI')}\")\n",
"print(f\"  Validation: {X_val.shape} — ME={np.sum(y_val=='ME')}, MI={np.sum(y_val=='MI')}\")\n",
"print(f\"\\n  Paper expects: ~6972 train, ~840 test, ~840 val\")"
]
},
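{
"cell_type": "markdown",
"metadata": {},
"source": [
"The expected totals are mutually consistent at 84 epochs per subject; this per-subject figure is an inference from the totals above, not a number stated in this notebook."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"epochs_per_subject = 84  # inferred: 6972 / 83 == 840 / 10 == 84\n",
"print(83 * epochs_per_subject, 10 * epochs_per_subject)  # 6972 840"
]
},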
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 4. Classification with 5-Run Averaging"
]
},
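{
"cell_type": "markdown",
"metadata": {},
"source": [
"The exact grouping behind the 5-run averaging lives in `pipeline.classify_with_averaging`. One plausible reading, averaging feature vectors in non-overlapping groups of five within each class before classification, is sketched below; `average_in_groups` is a hypothetical helper for illustration, not the pipeline's implementation."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def average_in_groups(X, y, group_size=5):\n",
"    # Average feature vectors in non-overlapping groups of `group_size`\n",
"    # within each class; leftovers that do not fill a group are dropped.\n",
"    # Hypothetical helper -- the pipeline's grouping may differ.\n",
"    Xg, yg = [], []\n",
"    for label in np.unique(y):\n",
"        Xc = X[y == label]\n",
"        n = (len(Xc) // group_size) * group_size\n",
"        if n == 0:\n",
"            continue\n",
"        Xg.append(Xc[:n].reshape(-1, group_size, X.shape[1]).mean(axis=1))\n",
"        yg.extend([label] * (n // group_size))\n",
"    return np.vstack(Xg), np.array(yg)\n",
"\n",
"X_demo = np.arange(40, dtype=float).reshape(20, 2)\n",
"y_demo = np.array(['ME'] * 10 + ['MI'] * 10)\n",
"X_avg, y_avg = average_in_groups(X_demo, y_demo)\n",
"print(X_avg.shape)  # (4, 2): two averaged groups per class"
]
},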
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"results = classify_with_averaging(\n",
"    X_train, y_train,\n",
"    X_test, y_test,\n",
"    X_val, y_val,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"print_results_summary(results)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "THESIS",
"language": "python",
"name": "thesis"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.23"
}
},
"nbformat": 4,
"nbformat_minor": 4
}