Code to replicate the experimental results from "Simple data balancing baselines achieve competitive worst-group-accuracy".
The easiest way to get a working environment for this repo is to create a conda environment with the following commands:
conda env create -f environment.yaml
conda activate balancinggroups
If conda is not available, please install the dependencies listed in the requirements.txt file.
The following script downloads, extracts, and formats the dataset metadata so that it works with the rest of the code out of the box:
python setup_datasets.py --download --data_path data
To reproduce the experiments in the paper on a SLURM cluster:
# Launching 1400 combo seeds = 50 hparams for 4 datasets for 7 algorithms
# Each combo seed is run 5 times to compute error bars, totalling 7000 jobs
python train.py --data_path data --output_dir main_sweep --num_hparams_seeds 1400 --num_init_seeds 5 --partition <slurm_partition>
If you want to run the jobs locally, omit the --partition argument.
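The job counts quoted for the sweep follow from a simple product, which can be checked with the short sketch below. The dataset and algorithm names are hypothetical placeholders, and enumerating the combinations as a Cartesian product is an assumption made for illustration; how train.py actually maps a seed to a configuration may differ.

```python
from itertools import product

# Placeholder names (assumptions): the real datasets and algorithms
# are defined by the repository, not by this sketch.
num_hparams = 50
datasets = [f"dataset_{i}" for i in range(4)]
algorithms = [f"algorithm_{i}" for i in range(7)]
num_init_seeds = 5

# One "combo seed" per (hparam draw, dataset, algorithm) triple.
combos = list(product(range(num_hparams), datasets, algorithms))
num_hparams_seeds = len(combos)

# Each combo is repeated with several initialization seeds for error bars.
total_jobs = num_hparams_seeds * num_init_seeds

print(num_hparams_seeds, total_jobs)  # prints "1400 7000"
```

Scaling down either --num_hparams_seeds or --num_init_seeds reduces the total job count proportionally, which may be useful for a quick local trial.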
The parse.py script can generate all of the plots and tables from the paper. By default, it generates the best test worst-group-accuracy table for each dataset/method. This script can be called while the experiments are still running.
python parse.py main_sweep
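For readers unfamiliar with the metric reported in these tables, worst-group accuracy is the minimum of the per-group accuracies. The sketch below illustrates the definition only; the function name and the toy inputs are hypothetical, and it is not the computation performed by parse.py.

```python
from collections import defaultdict


def worst_group_accuracy(predictions, labels, groups):
    """Return (worst-group accuracy, per-group accuracy dict)."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for pred, label, group in zip(predictions, labels, groups):
        total[group] += 1
        correct[group] += int(pred == label)
    # Accuracy is computed separately inside each group.
    per_group = {g: correct[g] / total[g] for g in total}
    # The metric is the accuracy of the group that performs worst.
    return min(per_group.values()), per_group


# Toy example with two groups of unequal size.
predictions = [1, 1, 0, 1, 0, 0]
labels      = [1, 0, 0, 1, 0, 1]
groups      = ["a", "a", "b", "b", "b", "b"]

wga, per_group = worst_group_accuracy(predictions, labels, groups)
print(wga)        # 0.5
print(per_group)  # {'a': 0.5, 'b': 0.75}
```

In this toy example the average accuracy is 4/6, yet the worst-group accuracy is 0.5, which shows how a small group can be masked by the overall average.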
This source code is released under the CC-BY-NC license, included here.