Implementation of the paper Controlling Directions Orthogonal to a Classifier, ICLR 2022 (Spotlight), by Yilun Xu, Hao He, Tianxiao Shen, and Tommi Jaakkola
[Slide]
Let's construct orthogonal classifiers for controlled style transfer, domain adaptation with label shifts, and fairness problems 🤠!
- Prepare Celeba-GH dataset
- Train classifiers and CycleGAN
To prepare the Celeba-GH dataset, run:
python style_transfer/celeba_dataset.py --data_dir {path}
path: path to the CelebA dataset
bash example: python style_transfer/celeba_dataset.py --data_dir ./data
One can modify the domain_fn dictionary in style_transfer/celeba_dataset.py to create new groups 💡
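To illustrate what such groups could look like, here is a hypothetical sketch. It assumes each image's CelebA attributes are available as a dict of +1 / -1 values (the format used in list_attr_celeba.txt) and that a group is a named predicate over those attributes. The group names, the chosen attributes, and the assign_group helper are invented for this example and need not match the actual domain_fn in style_transfer/celeba_dataset.py.

```python
from typing import Callable, Dict, Optional

# Hypothetical: one image's CelebA attributes, mapping attribute name -> +1 / -1.
Attrs = Dict[str, int]

# Hypothetical group definitions (illustrative only); the real domain_fn may use
# different keys, attributes, and signatures.
domain_fn: Dict[str, Callable[[Attrs], bool]] = {
    "male_black_hair": lambda a: a["Male"] == 1 and a["Black_Hair"] == 1,
    "female_blond_hair": lambda a: a["Male"] == -1 and a["Blond_Hair"] == 1,
}


def assign_group(attrs: Attrs) -> Optional[str]:
    """Return the first group whose predicate matches, or None if no group applies."""
    for name, predicate in domain_fn.items():
        if predicate(attrs):
            return name
    return None
```

For instance, assign_group({"Male": 1, "Black_Hair": 1, "Blond_Hair": -1}) returns "male_black_hair", and an image matching no predicate is left out (None).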
To train the classifiers, run:
sh style_transfer/train_classifiers.sh {gpu} {path} {dataset} {alg}
gpu: ID of the GPU to use
path: path to the dataset (Celeba or MNIST)
dataset: dataset (Celeba | CMNIST)
alg: training algorithm (ERM | Fish | TRM | MLDG)
CMNIST bash example: sh style_transfer/train_classifiers.sh 0 ./data CMNIST ERM
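If you want to compare the four algorithms, a small Python driver can build one train_classifiers.sh call per alg value, using the argument order shown above. This is a convenience sketch and not part of the repository: classifier_commands and run_all are hypothetical helpers, and with dry_run=True they only print the commands.

```python
import shlex
import subprocess
from typing import List

# The four values accepted by the alg argument of train_classifiers.sh.
ALGS = ["ERM", "Fish", "TRM", "MLDG"]


def classifier_commands(gpu: int, path: str, dataset: str) -> List[List[str]]:
    """Build one train_classifiers.sh invocation per algorithm."""
    return [
        ["sh", "style_transfer/train_classifiers.sh", str(gpu), path, dataset, alg]
        for alg in ALGS
    ]


def run_all(gpu: int, path: str, dataset: str, dry_run: bool = True) -> None:
    """Print each command; actually run it (sequentially) only if dry_run is False."""
    for cmd in classifier_commands(gpu, path, dataset):
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop at the first failing run
```

Calling run_all(0, "./data", "CMNIST") prints the four commands; pass dry_run=False to execute them one after another.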
To train CycleGAN, run:
python style_transfer/train_cyclegan.py --data_dir {path} --dataset {dataset} \
--obj {obj} --name {name}
path: path to the dataset (Celeba or MNIST)
dataset: dataset (Celeba | CMNIST)
obj: training objective (vanilla | orthogonal)
name: name of the model
CMNIST bash example: python style_transfer/train_cyclegan.py --data_dir ./data --dataset CMNIST --obj orthogonal --name cmnist
To view training results and loss plots, run python -m visdom.server and open http://localhost:8097 in a browser.
To generate images with a trained model, run:
python style_transfer/generate.py --data_dir {path} --dataset {dataset} --name {name} \
--obj {obj} --out_path {out_path} --resume_epoch {epoch} (--save)
path: path to the dataset (Celeba or MNIST)
dataset: dataset (Celeba | CMNIST)
name: name of the model
obj: training objective (vanilla | orthogonal)
out_path: output path
epoch: epoch of the checkpoint to resume from
Images will be saved to style_transfer/generated_images/{out_path}
CMNIST bash example: python style_transfer/generate.py --data_dir ./data --dataset CMNIST --name cmnist --obj orthogonal --out_path cmnist_out --resume_epoch 5
For domain adaptation with label shifts, first prepare the datasets: cd da/data and run
python {dataset}.py --r {r0} {r1}
r0: subsample ratio for the first half of the classes (default=0.7)
r1: subsample ratio for the second half of the classes (default=0.3)
dataset: mnist | mnistm | svhn | cifar | stl | signs | digits
For the SynthDigits / SynthSigns datasets, please download them from link_digits / link_signs. All other datasets will be downloaded automatically 😉
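The r0 / r1 ratios create a label shift by subsampling the classes unevenly. Below is a minimal sketch of that idea. It assumes classes 0 .. num_classes//2 - 1 form the "first half", and that each class keeps round(ratio * count) samples drawn without replacement. The function name subsample_label_shift is hypothetical, and the actual {dataset}.py scripts may split the classes or pick the samples differently.

```python
import numpy as np


def subsample_label_shift(labels: np.ndarray, num_classes: int,
                          r0: float, r1: float, seed: int = 0) -> np.ndarray:
    """Return sorted indices of a label-shifted subset of `labels`.

    Classes 0 .. num_classes//2 - 1 keep a fraction r0 of their samples;
    the remaining classes keep a fraction r1.
    """
    rng = np.random.default_rng(seed)
    keep = []
    for c in range(num_classes):
        idx = np.flatnonzero(labels == c)               # positions of class c
        ratio = r0 if c < num_classes // 2 else r1      # first vs. second half
        n_keep = int(round(ratio * len(idx)))
        keep.append(rng.choice(idx, size=n_keep, replace=False))
    return np.sort(np.concatenate(keep))
```

With 100 samples per class, 10 classes, and the default ratios, classes 0-4 keep 70 samples each and classes 5-9 keep 30 each.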
To train a domain adaptation model, run:
python da/vada_train.py --r {r0} {r1} --src {source} --tgt {target} --seed {seed} \
(--iw) (--orthogonal) (--source_only)
r0: subsample ratio for the first half of the classes (default=0.7)
r1: subsample ratio for the second half of the classes (default=0.3)
source: source domain (mnist | mnistm | svhn | cifar | stl | signs | digits)
target: target domain (mnist | mnistm | svhn | cifar | stl | signs | digits)
seed: random seed
--source_only: vanilla ERM on the source domain
--iw: use importance-weighted domain adaptation algorithm [1]
--orthogonal: use orthogonal classifier
--vada: vanilla VADA [2]
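The --iw option refers to importance-weighted domain adaptation [1], which reweights source classes by w[y] = p_T(y) / p_S(y). One common way to estimate these ratios without target labels uses the source confusion matrix. If C[i, j] = P_S(prediction = i, label = j) and mu[i] = P_T(prediction = i), then C w ≈ mu. The sketch below solves this with plain least squares and clips negative entries. [1] solves a constrained version of this problem instead, and estimate_class_weights is a hypothetical name, not the repository's code.

```python
import numpy as np


def estimate_class_weights(src_labels, src_preds, tgt_preds, num_classes):
    """Estimate w[y] = p_T(y) / p_S(y) by solving C w = mu with least squares.

    C[i, j] = P_S(pred = i, label = j), from labelled source data.
    mu[i]   = P_T(pred = i), from predictions on unlabelled target data.
    """
    src_labels = np.asarray(src_labels)
    src_preds = np.asarray(src_preds)
    tgt_preds = np.asarray(tgt_preds)

    C = np.zeros((num_classes, num_classes))
    np.add.at(C, (src_preds, src_labels), 1.0)  # joint counts of (prediction, label)
    C /= len(src_labels)

    mu = np.bincount(tgt_preds, minlength=num_classes) / len(tgt_preds)
    w, *_ = np.linalg.lstsq(C, mu, rcond=None)
    return np.clip(w, 0.0, None)  # class ratios cannot be negative
```

With a perfect classifier, C is diagonal, so the estimate reduces to the ratio of the target and source label marginals.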
For the fairness experiments, run:
python fairness/methods/train.py --data {data} --gamma {gamma} --sigma {sigma} \
(--orthogonal) (--laftr) (--mifr) (--hsic)
data: dataset (adult | german)
gamma: hyper-parameter for MIFR, HSIC, LAFTR
sigma: hyper-parameter for HSIC (kernel width)
--orthogonal: use orthogonal classifier
--mifr: use the L-MIFR algorithm [3]
--hsic: use the ReBias algorithm [4]
--laftr: use the LAFTR algorithm [5]
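HSIC-based penalties measure statistical dependence through kernels, and sigma above is the width of that kernel. As an illustration, here is the standard biased empirical HSIC estimate, tr(K H L H) / (n - 1)^2, with a Gaussian kernel exp(-||a - b||^2 / (2 sigma^2)). This is a generic sketch. Which variables the repository's HSIC baseline compares, whether both sides share one kernel width, and how gamma weights the penalty are not specified in this README, so treat it as illustrative only. Inputs are 2-D arrays with one row per sample; reshape a 1-D attribute with .reshape(-1, 1).

```python
import numpy as np


def rbf_kernel(x: np.ndarray, sigma: float) -> np.ndarray:
    """Gaussian kernel matrix k(a, b) = exp(-||a - b||^2 / (2 sigma^2)) over rows of x."""
    sq = np.sum(x ** 2, axis=1)
    d2 = sq[:, None] + sq[None, :] - 2.0 * x @ x.T  # pairwise squared distances
    return np.exp(-np.maximum(d2, 0.0) / (2.0 * sigma ** 2))


def hsic(x: np.ndarray, y: np.ndarray, sigma: float) -> float:
    """Biased empirical HSIC estimate tr(K H L H) / (n - 1)^2."""
    n = x.shape[0]
    K = rbf_kernel(x, sigma)
    L = rbf_kernel(y, sigma)
    H = np.eye(n) - np.ones((n, n)) / n  # centering matrix
    return float(np.trace(K @ H @ L @ H)) / (n - 1) ** 2
```

Strongly dependent inputs give a clearly larger value than independent ones; a larger sigma makes the kernel smoother and less sensitive to small differences.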
[1] Remi Tachet des Combes, Han Zhao, Yu-Xiang Wang, and Geoffrey J. Gordon. Domain adaptation with conditional distribution matching and generalized label shift. ArXiv, abs/2003.04475, 2020.
[2] Rui Shu, Hung H. Bui, Hirokazu Narui, and Stefano Ermon. A DIRT-T approach to unsupervised domain adaptation. ArXiv, abs/1802.08735, 2018.
[3] Jiaming Song, Pratyusha Kalluri, Aditya Grover, Shengjia Zhao, and Stefano Ermon. Learning controllable fair representations. In AISTATS, 2019.
[4] Hyojin Bahng, Sanghyuk Chun, Sangdoo Yun, Jaegul Choo, and Seong Joon Oh. Learning de-biased representations with biased representations. In ICML, 2020.
[5] David Madras, Elliot Creager, Toniann Pitassi, and Richard Zemel. Learning adversarially fair and transferable representations. In ICML, 2018.
The implementation of this repo is based on / inspired by:
- https://github.com/facebookresearch/DomainBed (code structure)
- https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix (code structure)
- https://github.com/ozanciga/dirt-t (VADA code)
- https://github.com/Britefury/self-ensemble-visual-domain-adapt (data generation)