The code is based on SIN.
We tested the implementation in Python 3.8.
requirements.txt is an automatically generated file with all dependencies.
Essential packages include:
rdkit
numpy
networkx
scikit-learn
torch
torch-geometric
wandb
The TCGA simulation requires the TCGA and QM9 datasets. The code automatically downloads and unzips these datasets if
they do not exist. Alternatively, the TCGA dataset can be downloaded from here and the QM9 dataset from here. Both
datasets should be located in data/tcga/.
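The download-if-missing behaviour described above can be sketched roughly as follows. This is only an illustration under stated assumptions, not the repository's code: the function name ensure_dataset, the temporary archive name, and the URL argument are hypothetical, and the real download locations are not reproduced here.

```python
import urllib.request
import zipfile
from pathlib import Path


def ensure_dataset(target: Path, url: str) -> bool:
    """Download and unzip `url` next to `target` unless `target` already exists.

    Returns True if a download happened, False if the data was already present.
    `url` is a placeholder argument; the actual dataset URLs are not known here.
    """
    if target.exists():
        # Data is already in place, so nothing is downloaded.
        return False

    target.parent.mkdir(parents=True, exist_ok=True)
    archive = target.parent / "download.zip"  # hypothetical temporary file name

    # Fetch the archive, extract it into the data folder, then clean up.
    urllib.request.urlretrieve(url, str(archive))
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(target.parent)
    archive.unlink()
    return True
```

Checking for the extracted data first keeps repeated runs cheap and lets manually downloaded files in data/tcga/ take precedence.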
There are four runnable Python scripts:
- generate_data.py: Generates and saves a dataset given the configuration in configs/generate_data/.
  - Stores generated data in data_path with folder structure {data_path}/{task}/seed-{seed}/bias-{bias}/
  - For each task, seed, and bias combination, generates and stores a new dataset
- run_model_training.py: Trains and evaluates a CATE estimation model given the configuration in configs/run_model/.
  - Evaluation results will be logged and can be saved to results_path and/or synced to a wandb.ai account
- run_hyperparameter_sweeping.py: Sweeps hyper-parameters with wandb as specified in configs/sweeps/
- run_unseen_treatment_update.py: Runs the GNN baseline on a specified dataset and updates one-hot encodings of previously unseen treatments in the test set to the closest ones seen during training, based on their Euclidean distance in the hidden embedding space.
  - Before running the CAT baseline, run this script. Otherwise, unseen treatment one-hot encodings will be fed into the network.
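The nearest-treatment update performed by run_unseen_treatment_update.py can be illustrated with a small sketch. It is not the repository's implementation: the function name, the dictionary layout, and the toy embeddings are assumptions, and plain Python lists stand in for the GNN's hidden embeddings.

```python
import math
from typing import Dict, Hashable, Sequence


def closest_seen_treatment(
    unseen_embedding: Sequence[float],
    seen_embeddings: Dict[Hashable, Sequence[float]],
) -> Hashable:
    """Return the id of the seen treatment nearest in Euclidean distance."""
    return min(
        seen_embeddings,
        key=lambda tid: math.dist(unseen_embedding, seen_embeddings[tid]),
    )


# Toy hidden embeddings (made up for illustration).
seen = {0: [0.0, 0.0], 1: [1.0, 1.0], 2: [4.0, 0.0]}
unseen = {7: [0.9, 1.2], 8: [3.0, 0.5]}

# Map every unseen test treatment to its closest training treatment.
remap = {tid: closest_seen_treatment(emb, seen) for tid, emb in unseen.items()}
print(remap)  # {7: 1, 8: 2}
```

After such a remapping, a model that consumes one-hot treatment encodings only ever receives indices it saw during training.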
Data generation parameters:
- task: Simulation, sw or tcga
- bias: Treatment selection bias coefficient
- seed: Random seed
- data_path: Path to save/load generated datasets
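To make the {data_path}/{task}/seed-{seed}/bias-{bias}/ layout concrete, here is a minimal sketch that builds the folder for each parameter combination. The helper name dataset_dir and the example values are hypothetical, and the exact string formatting of bias used by the repository is an assumption.

```python
from itertools import product
from pathlib import Path


def dataset_dir(data_path: str, task: str, seed: int, bias: float) -> Path:
    """Build {data_path}/{task}/seed-{seed}/bias-{bias} (bias formatting assumed)."""
    return Path(data_path) / task / f"seed-{seed}" / f"bias-{bias}"


# One folder per (task, seed, bias) combination; the values are illustrative only.
tasks = ["sw", "tcga"]
seeds = [0, 1]
biases = [0.1, 0.5]

for task, seed, bias in product(tasks, seeds, biases):
    print(dataset_dir("generated_data", task, seed, bias).as_posix())
```

With these example values the loop prints eight paths, the first being generated_data/sw/seed-0/bias-0.1.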
Model training parameters:
- task: Simulation, sw or tcga
- model: gin, gnn, cat, graphite, zero, TransTEE
- bias: Treatment selection bias coefficient
- seed: Random seed
When parsing SMILES strings from the QM9 dataset to simulate a TCGA experiment, there may be bad input warnings for
certain molecules. The data generator ignores these molecules. When subsampling 10k molecules, we observed that
roughly 1% of the molecules are faulty.
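Skipping unparsable molecules can be sketched as below. This is an assumption about the general approach rather than the data generator's actual code: the helper names are hypothetical. The sketch relies only on the fact that RDKit's Chem.MolFromSmiles returns None for a SMILES string it cannot parse.

```python
from typing import Callable, Iterable, List, Tuple


def rdkit_parse(smiles: str):
    """Parse a SMILES string with RDKit; the result is None if parsing fails."""
    from rdkit import Chem  # imported lazily so keep_parsable works with any parser

    return Chem.MolFromSmiles(smiles)


def keep_parsable(
    smiles_list: Iterable[str],
    parse: Callable[[str], object] = rdkit_parse,
) -> Tuple[List[str], int]:
    """Return the SMILES strings that parse, plus the number that were skipped."""
    kept: List[str] = []
    n_skipped = 0
    for smiles in smiles_list:
        if parse(smiles) is None:
            n_skipped += 1  # faulty molecule: ignore it and keep going
            continue
        kept.append(smiles)
    return kept, n_skipped
```

Calling keep_parsable(smiles_list) on a subsample returns both the usable molecules and the skip count, which makes it easy to check the fraction of faulty inputs.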