Python3 implementation of the paper Hierarchical Sliced Wasserstein Distance.
Details of the model architecture and experimental results can be found in our paper.
title={Hierarchical Sliced Wasserstein Distance},
author={Khai Nguyen and Tongzheng Ren and Huy Nguyen and Litu Rout and Tan Nguyen and Nhat Ho},
journal={International Conference on Learning Representations},
Please CITE our paper whenever this repository is used to help produce published results or incorporated into other software.
This implementation was written by Khai Nguyen. The README is still being updated.
The code is implemented in Python (3.8.8) and PyTorch (1.9.0).
- (Hierarchical) Sliced Wasserstein Generators
- : this file contains arguments for training.
- : this file implements dataloaders.
- : this file implements training functions.
- : this file is the main file for running SW.
- : this file is the main file for running Max-SW.
- : this file is the main file for running HSW.
- : this file is the main file for running Max-HSW.
- models : this folder contains the neural network architectures.
- utils : this folder contains implementations of the FID score and the Inception score.
- fid_stat : this folder contains statistics files for the FID score.
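The `utils` folder holds the repository's own FID code. For reference, FID is the Fréchet distance between two Gaussians fitted to Inception features: ||mu1 - mu2||^2 + tr(S1 + S2 - 2 (S1 S2)^(1/2)). The NumPy sketch below computes that quantity from precomputed statistics. It is an independent illustration, not the code in `utils`, and the function names are hypothetical.

```python
import numpy as np


def _sqrt_psd(mat: np.ndarray) -> np.ndarray:
    """Square root of a symmetric positive semi-definite matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(mat)
    vals = np.clip(vals, 0.0, None)  # drop tiny negative eigenvalues caused by round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def frechet_distance(mu1, sigma1, mu2, sigma2) -> float:
    """Fréchet distance between N(mu1, sigma1) and N(mu2, sigma2), the quantity behind FID."""
    mu1 = np.asarray(mu1, dtype=np.float64)
    mu2 = np.asarray(mu2, dtype=np.float64)
    sigma1 = np.asarray(sigma1, dtype=np.float64)
    sigma2 = np.asarray(sigma2, dtype=np.float64)
    diff = mu1 - mu2
    # tr(sqrt(S1 @ S2)) equals tr(sqrt(S1^{1/2} @ S2 @ S1^{1/2})); the latter matrix is
    # symmetric PSD, so its eigenvalues can be taken with eigvalsh instead of a general sqrtm.
    root1 = _sqrt_psd(sigma1)
    middle = root1 @ sigma2 @ root1
    middle = (middle + middle.T) / 2.0  # re-symmetrize against round-off
    tr_covmean = np.sqrt(np.clip(np.linalg.eigvalsh(middle), 0.0, None)).sum()
    return float(diff @ diff + np.trace(sigma1) + np.trace(sigma2) - 2.0 * tr_covmean)
```

With statistics loaded from a file in `fid_stat`, the call would be `frechet_distance(mu_gen, sigma_gen, mu_ref, sigma_ref)`; the array names stored in those files are not documented here, so check them before use.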
- --dataset : dataset to use, one of {"cifar10", "tinyimagenet", "celeba"}
- --img_size : size of images
- --dis_bs : size of mini-batches
- --model : "sngan_{dataset}"
- --eval_batch_size : batch size for computing FID
- --Ls : "k,L", where k is the number of bottleneck projections for (Max-)HSW and L is the number of projections for HSW
- --L : the number of projections for SW
- --s_lr : slice learning rate (for Max-SW and Max-HSW)
- --s_max_iter : maximum number of gradient-update iterations for the slice (for Max-SW and Max-HSW)
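To make the roles of `--L` and `--Ls` concrete, here is a minimal NumPy sketch of Monte Carlo SW and of HSW as we read the construction: project once onto k random bottleneck directions, then form L random linear combinations of those k projections (mixing weights drawn uniformly on the unit sphere in k dimensions). It assumes equal-size point clouds with uniform weights. This is not the repository's PyTorch code, and all function names are hypothetical; the exact sampling and normalization in the repo may differ.

```python
import numpy as np


def random_directions(dim: int, num: int, rng: np.random.Generator) -> np.ndarray:
    """Return a (dim, num) matrix whose columns are uniform on the unit sphere S^{dim-1}."""
    v = rng.standard_normal((dim, num))
    return v / np.linalg.norm(v, axis=0, keepdims=True)


def mean_1d_wasserstein_pp(proj_x: np.ndarray, proj_y: np.ndarray, p: float) -> float:
    """Average over columns of W_p^p between 1-D empirical measures of equal size.

    With uniform weights and n points on each side, the optimal 1-D coupling matches
    the i-th smallest projection of x with the i-th smallest projection of y.
    """
    sx = np.sort(proj_x, axis=0)
    sy = np.sort(proj_y, axis=0)
    return float(np.mean(np.abs(sx - sy) ** p))


def sliced_wasserstein(x, y, num_projections, p=2.0, rng=None):
    """Monte Carlo SW_p using `num_projections` random directions (the role of --L)."""
    rng = np.random.default_rng() if rng is None else rng
    theta = random_directions(x.shape[1], num_projections, rng)  # (d, L)
    return mean_1d_wasserstein_pp(x @ theta, y @ theta, p) ** (1.0 / p)


def hierarchical_sliced_wasserstein(x, y, k, num_projections, p=2.0, rng=None):
    """HSW sketch: k bottleneck projections, then L linear combinations (--Ls "k,L")."""
    rng = np.random.default_rng() if rng is None else rng
    theta = random_directions(x.shape[1], k, rng)      # (d, k) bottleneck directions
    psi = random_directions(k, num_projections, rng)   # (k, L) mixing weights on S^{k-1}
    bx, by = x @ theta, y @ theta                       # O(n d k): only pass over d-dim data
    return mean_1d_wasserstein_pp(bx @ psi, by @ psi, p) ** (1.0 / p)  # O(n k L)
```

The design point is the projection cost: SW touches the d-dimensional data once per projection, O(ndL), while HSW does so only k times, O(ndk + nkL). With `--Ls 70,2000` as in the example command, the d-dependent term shrinks by a factor of 2000/70, roughly 28.6.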
python -gen_bs 128 -dis_bs 128 --data_path ./data --dataset celeba --img_size 64 --max_iter 50000 --model sngan_celeba --latent_dim 128 --gf_dim 256 --df_dim 128 --g_spectral_norm False --d_spectral_norm True --g_lr 0.0002 --d_lr 0.0002 --beta1 0.0 --beta2 0.9 --init_type xavier_uniform --n_critic 5 --val_freq 20 --Ls 70,2000 --exp_name hsw --random_seed 1
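Max-SW and Max-HSW replace random slicing with an optimized slice, which is what `--s_lr` and `--s_max_iter` control. The NumPy sketch below shows Max-SW for p = 2, assuming projected gradient ascent over a single direction on the unit sphere; the mapping of `lr` and `max_iter` to those flags is an assumption, the function name is hypothetical, and the repository's PyTorch optimizer may differ.

```python
import numpy as np


def max_sliced_wasserstein(x, y, lr=0.01, max_iter=50, rng=None):
    """Max-SW_2 sketch: projected gradient ascent over a single unit direction.

    `lr` and `max_iter` play the roles of --s_lr and --s_max_iter (assumed mapping).
    Returns the estimated distance and the final direction.
    """
    rng = np.random.default_rng() if rng is None else rng
    theta = rng.standard_normal(x.shape[1])
    theta /= np.linalg.norm(theta)
    for _ in range(max_iter):
        # Sorting the projections gives the optimal 1-D coupling for the current slice.
        ix = np.argsort(x @ theta)
        iy = np.argsort(y @ theta)
        delta = x[ix] - y[iy]            # (n, d) differences of matched points
        residual = delta @ theta         # (n,) differences of matched projections
        # Gradient of mean(residual**2) w.r.t. theta, holding the coupling fixed.
        grad = 2.0 * (residual[:, None] * delta).mean(axis=0)
        theta = theta + lr * grad        # ascent step: we maximize over slices
        theta /= np.linalg.norm(theta)   # project back onto the unit sphere
    gap = np.sort(x @ theta) - np.sort(y @ theta)
    return float(np.sqrt(np.mean(gap ** 2))), theta
```

Each iteration re-sorts the projections, so the coupling may change between steps; within a step it is treated as fixed, which keeps the gradient a simple closed form.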
The structure of this repo is largely based on sngan.pytorch and CSW.