-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c9d847e
commit 1dbc8b4
Showing
16 changed files
with
2,870 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
name: pre-commit-codestyle | ||
|
||
concurrency: | ||
group: ${{ github.workflow }}-${{ github.event.number }}-${{ github.event.ref }} | ||
cancel-in-progress: true | ||
|
||
on: [pull_request] | ||
|
||
jobs: | ||
|
||
pre-commit: | ||
runs-on: ubuntu-latest | ||
steps: | ||
- uses: actions/checkout@v4 | ||
- uses: actions/setup-python@v5 | ||
- uses: pre-commit/[email protected] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
name: Prettier code formatter | ||
|
||
on: | ||
pull_request: | ||
branches: | ||
- master | ||
- main | ||
push: | ||
branches: | ||
- master | ||
- main | ||
|
||
jobs: | ||
check: | ||
# available images: https://github.com/actions/runner-images#available-images | ||
runs-on: ubuntu-latest | ||
steps: | ||
- name: Checkout 🛎️ | ||
uses: actions/checkout@v4 | ||
- name: Setup Node.js ⚙️ | ||
uses: actions/setup-node@v4 | ||
- name: Install Prettier 💾 | ||
run: npm install --save-dev --save-exact prettier @shopify/prettier-plugin-liquid | ||
- name: Prettier Check 🔎 | ||
id: prettier | ||
run: npx prettier . --check | ||
- name: Create diff 📝 | ||
# https://docs.github.com/en/actions/learn-github-actions/expressions#failure | ||
if: ${{ failure() }} | ||
run: | | ||
npx prettier . --write | ||
git diff -- . ':(exclude)package-lock.json' ':(exclude)package.json' > diff.txt | ||
npm install -g diff2html-cli | ||
diff2html -i file -s side -F diff.html -- diff.txt | ||
- name: Upload html diff ⬆️ | ||
id: artifact-upload | ||
if: ${{ failure() && steps.prettier.conclusion == 'failure' }} | ||
uses: actions/upload-artifact@v4 | ||
with: | ||
name: HTML Diff | ||
path: diff.html | ||
retention-days: 7 | ||
- name: Dispatch information to repository 🗣️ | ||
if: ${{ failure() && steps.prettier.conclusion == 'failure' && github.event_name == 'pull_request' }} | ||
uses: peter-evans/repository-dispatch@v2 | ||
with: | ||
event-type: prettier-failed-on-pr | ||
client-payload: '{"pr_number": "${{ github.event.number }}", "artifact_url": "${{ steps.artifact-upload.outputs.artifact-url }}", "run_id": "${{ github.run_id }}"}' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,64 @@ | ||
default_language_version: | ||
python: python3 | ||
|
||
ci: | ||
autofix_commit_msg: | | ||
[pre-commit.ci] auto fixes from pre-commit.com hooks | ||
autofix_prs: true | ||
autoupdate_branch: "master" | ||
autoupdate_commit_msg: "[pre-commit.ci] pre-commit autoupdate" | ||
autoupdate_schedule: quarterly | ||
skip: [ ] | ||
submodules: false | ||
|
||
repos: | ||
- repo: https://github.com/pre-commit/pre-commit-hooks | ||
rev: v4.6.0 | ||
hooks: | ||
- id: check-yaml | ||
- id: check-json | ||
- id: check-executables-have-shebangs | ||
- id: check-toml | ||
- id: check-docstring-first | ||
- id: check-added-large-files | ||
- id: end-of-file-fixer | ||
- id: trailing-whitespace | ||
- id: check-case-conflict | ||
- id: mixed-line-ending | ||
- id: end-of-file-fixer | ||
- id: check-case-conflict | ||
- id: forbid-new-submodules | ||
- id: pretty-format-json | ||
args: [ "--autofix", "--no-sort-keys", "--indent=4" ] | ||
|
||
- repo: https://github.com/charliermarsh/ruff-pre-commit | ||
rev: v0.4.8 | ||
hooks: | ||
- id: ruff | ||
args: [ "--ignore=E402,E501,F401", "--fix" ] #, --exit-non-zero-on-fix, ] | ||
- id: ruff | ||
name: ruff lint data notebooks | ||
args: [ "--fix", "--preview", "--select=NPY201" ] | ||
- id: ruff | ||
args: [ "check", "--select", "I", "--fix"] | ||
- id: ruff-format | ||
types_or: [ python, pyi, jupyter ] | ||
|
||
|
||
|
||
- repo: https://github.com/codespell-project/codespell | ||
rev: v2.3.0 | ||
hooks: | ||
- id: codespell | ||
args: | ||
- --ignore-words-list=metadat,splitted,meaned,wil,whats,additionals,alle,alot,bund,currenty,datas,farenheit,falsy,fo,haa,hass,iif,incomfort,ines,ist,nam,nd,pres,pullrequests,resset,rime,ser,serie,te,technik,ue,unsecure,withing,zar,MAPE,mape | ||
- --skip="./.*,*.csv,*.json,*.ambr" | ||
- --quiet-level=2 | ||
exclude_types: [ csv, json ] | ||
exclude: ^tests/|generated/^.github | ||
|
||
- repo: https://github.com/asottile/blacken-docs | ||
rev: 1.16.0 | ||
hooks: | ||
- id: blacken-docs | ||
exclude: ^.github |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
# About | ||
Code for ["General-Purpose Brain Foundation Models for Time-Series Neuroimaging Data"](https://openreview.net/forum?id=HwDQH0r37I) | ||
|
||
# Getting Started | ||
## 0. Install the requirements | ||
To install the requirements, run the following command: | ||
```bash | ||
pip install -r requirements.txt | ||
``` | ||
## 1. Download and preprocess the data | ||
Download the NMT data from [here](https://ilabel.ai/datasets/Nust-Millitary-Hospital-TUKl-NMT-EEG-Dataset/NMT-Scalp-EEG.zip) and extract it to the `data` folder. or you can use the following command: | ||
```bash | ||
wget https://ilabel.ai/datasets/Nust-Millitary-Hospital-TUKl-NMT-EEG-Dataset/NMT-Scalp-EEG.zip | ||
|
||
unzip NMT-Scalp-EEG.zip -d data | ||
``` | ||
or you can use the following command: | ||
```bash | ||
gdown 'https://drive.google.com/uc?id=1jD_AcmfoaIfkOiO5lSU4J6IxHZtalnTk' | ||
|
||
unzip NMT.zip -d data/NMT/ | ||
``` | ||
|
||
## 2. Preprocess the data | ||
To preprocess the data, run the following command: | ||
```bash | ||
python ./data/preprocess.py \ | ||
--dataset nmt \ | ||
--start_range 0 \ | ||
--end_range 500 \ | ||
--exp_path ./data/NMT/NMT_dl/ \ | ||
--nmt_raw_path ./data/NMT/nmt_scalp_eeg_dataset/ | ||
``` | ||
It will preprocess the data and save it as .arrow files in the `data/NMT/nmt_dl/` folder. | ||
|
||
## 3. Train the model | ||
To train the model, run the following command: | ||
```bash | ||
accelerate launch bfm/train/train.py \ | ||
--config bfm/configs/bfm-t5-base-nmt.yaml \ | ||
--experiment-name "bfm-base" \ | ||
--wandb-mode online \ | ||
--wandb-entity <your_wandb_entity> \ | ||
--model-id google/t5-efficient-base \ | ||
--seed 6 \ | ||
--learning-rate 0.001 \ | ||
--per-device-train-batch-size 32 \ | ||
--no-random-init \ | ||
--n-gpus 4 \ | ||
--max-steps 2000 | ||
``` | ||
This will train the model on the NMT dataset using the T5-base model. You can modify the config file to use a different model or dataset. | ||
|
||
## 4. Evaluate the model | ||
To evaluate the model, run the following command: | ||
```bash | ||
CUDA_VISIBLE_DEVICES=0 python bfm/evaluate/evaluate.py \ | ||
--config_path "bfm/configs/bfm-inference.yaml" \ | ||
--directory_path "./bfm/Experiments/bfm-base_nmt" \ | ||
--seed 2024 \ | ||
--device "cuda" | ||
``` | ||
[Note:] You can also use 'data/download_moabb_datasets.py' to download the MOABB datasets. Then you can use 'data/preprocess_moabb.py' to preprocess the MOABB datasets and evaluate the model on them. | ||
|
||
# Citation | ||
If you find this code useful, please consider citing our paper: | ||
``` | ||
@inproceedings{ | ||
bayazi2024generalpurpose, | ||
title={General-Purpose Brain Foundation Models for Time-Series Neuroimaging Data}, | ||
author={Mohammad Javad Darvishi Bayazi and Hena Ghonia and Roland Riachi and Bruno Aristimunha and Arian Khorasani and Md Rifat Arefin and Amin Darabi and Guillaume Dumas and Irina Rish}, | ||
booktitle={NeurIPS Workshop on Time Series in the Age of Large Models}, | ||
year={2024}, | ||
url={https://openreview.net/forum?id=HwDQH0r37I} | ||
} | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
- name: BNCI2014_001 | ||
test_path: ./data/moabb/BNCI2014_001_dl/train/ | ||
num_rolls: 1 | ||
- name: BNCI2014_004 | ||
test_path: ./data/moabb/BNCI2014_004_dl/train/ | ||
num_rolls: 1 | ||
- name: BNCI2015_001 | ||
test_path: ./data/moabb/BNCI2015_001_dl/train/ | ||
num_rolls: 1 | ||
- name: Weibo2014 | ||
test_path: ./data/moabb/Weibo2014_dl/train/ | ||
num_rolls: 1 | ||
- name: Cho2017 | ||
test_path: ./data/moabb/Cho2017_dl/train/ | ||
num_rolls: 1 | ||
- name: Liu2024 | ||
test_path: ./data/moabb/Liu2024_dl/train/ | ||
num_rolls: 1 | ||
- name: PhysionetMI | ||
test_path: ./data/moabb/PhysionetMI_dl/train/ | ||
num_rolls: 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
- name: nmt_oodomain | ||
test_path: ./data/NMT/NMT_dl/nmt_dl/test/ | ||
num_rolls: 1 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
training_data_paths: | ||
- ./data/NMT/NMT_dl/nmt_dl/train/ | ||
validation_data_paths: | ||
- ./data/NMT/NMT_dl/nmt_dl/val/ | ||
dataset: "nmt" | ||
project_dir: "./bfm/" | ||
experiment_name: "test_experiment_v0" | ||
wandb_mode: "offline" | ||
n_gpus: 8 | ||
wandb_entity: "brain_fomo" | ||
wandb_project: | ||
context_length: 512 | ||
prediction_length: 64 | ||
min_past: 60 | ||
max_steps: 200_000 | ||
save_steps: 50 | ||
log_steps: 50 | ||
per_device_train_batch_size: 32 | ||
learning_rate: 0.001 | ||
optim: adamw_torch | ||
num_samples: 20 | ||
shuffle_buffer_length: 100 | ||
gradient_accumulation_steps: 1 | ||
model_id: google/t5-efficient-base | ||
model_type: seq2seq | ||
random_init: true | ||
tie_embeddings: true | ||
output_dir: ./output/ | ||
tf32: true | ||
torch_compile: true | ||
tokenizer_class: "MeanScaleUniformBins" | ||
tokenizer_kwargs: | ||
low_limit: -15.0 | ||
high_limit: 15.0 | ||
n_tokens: 4096 | ||
lr_scheduler_type: linear | ||
warmup_ratio: 0.0 | ||
dataloader_num_workers: 1 | ||
max_missing_prop: 0.9 | ||
use_eos_token: true |
Oops, something went wrong.