This repository is the official implementation of "Local Disentanglement in Variational Auto-Encoders Using Jacobian L1 Regularization".
To install requirements for a CUDA-enabled workstation (strongly recommended):
conda env create -f environment.yml
conda init bash
To install requirements without GPU support:
conda env create -f environment_nogpu.yml
conda init bash
An example script showing how to train a single JL1-VAE model on a small cache of three-dots data (20,000 images) for only 30,000 training batches of 64 images can be run with
./exampleScripts/train_jlonevae_threeDots.bash
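For reference, the JL1-VAE objective adds an L1 penalty on the Jacobian of the decoder with respect to the latent code. The snippet below is only a minimal sketch of how such a penalty can be computed in PyTorch; the function name, the `gamma` weight, and the use of torch.autograd.functional.jacobian are assumptions for illustration, not this repository's actual implementation (which uses its own, more efficient code paths).

```python
import torch

def jacobian_l1_penalty(decoder, z):
    """Sketch: mean L1 norm of the decoder Jacobian dx/dz at latent codes z.

    decoder: callable mapping a latent batch (B, latent_dim) to images (B, C, H, W).
    z:       latent codes at which to penalize the Jacobian, shape (B, latent_dim).
    """
    penalty = 0.0
    for zi in z:  # one Jacobian per sample in the batch
        # jac has shape (C, H, W, latent_dim): one image-shaped column per latent dimension
        jac = torch.autograd.functional.jacobian(
            lambda latent: decoder(latent.unsqueeze(0)).squeeze(0),
            zi,
            create_graph=True,  # keep the graph so the penalty can be backpropagated
        )
        penalty = penalty + jac.abs().sum()
    return penalty / z.shape[0]

# Usage sketch (hypothetical weights beta, gamma):
# loss = reconstruction_loss + beta * kl_divergence + gamma * jacobian_l1_penalty(decoder, z)
```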
To train the full three-dots models in the paper, run:
./experimentScripts/train_jlonevae/train_threeDots.bash
The first time this script is run, it takes a few minutes to create a cache of
500,000 training images in the data/
folder, and then trains for 300,000 batches of 64 images.
Subsequent runs will re-use the same cache of images.
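For context, the three-dots images contain three small dots whose positions vary across the dataset. The snippet below is only a hypothetical illustration of that kind of image; it is not the repository's data-generation code, and the image size, dot radius, and function name are all assumptions.

```python
import numpy as np

def make_three_dots_image(size=64, radius=2, rng=None):
    # Hypothetical illustration only: three filled dots at random positions on a blank canvas.
    rng = np.random.default_rng() if rng is None else rng
    img = np.zeros((size, size), dtype=np.float32)
    ys, xs = np.mgrid[0:size, 0:size]
    for _ in range(3):
        cy, cx = rng.integers(radius, size - radius, size=2)
        img[(ys - cy) ** 2 + (xs - cx) ** 2 <= radius ** 2] = 1.0
    return img
```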
Training logs are written to the ./logs directory, and the trained model is written to a subdirectory of ./trainedModels (both as a PyTorch JIT module for use with disentanglement_lib, and also using torch.save(model.state_dict(), ...)).
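As a hedged illustration of the two saved formats (the file names and the model class below are placeholders, not this repository's actual API), a trained model can later be reloaded either way:

```python
import torch

# Option 1: load the exported JIT module (no Python class definition required).
jit_model = torch.jit.load("trainedModels/<run_name>/model.pt")  # placeholder path
jit_model.eval()

# Option 2: restore weights into a model instance from the saved state dict.
# `MyVAE(...)` is a hypothetical stand-in for the repository's model class.
# model = MyVAE(latent_dim=10)
# model.load_state_dict(torch.load("trainedModels/<run_name>/state_dict.pt"))
# model.eval()
```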
Baseline models, for comparison, are trained by calling
./experimentScripts/train_baseline/train_standard_tf_models.bash
Those models and logs are stored in a subdirectory of the trainedStandardModels folder; the model numbers identify which disentanglement_lib configuration was used, here trained on the three-dots dataset.
Based on reviewer feedback, we also include (in the appendix) results of regularizing by the L2 norm of the Jacobian. These models can be trained by running
./experimentScripts/train_jlonevae/train_threeDots_l2.bash
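The L2 variant swaps the L1 norm on the decoder Jacobian for a squared Frobenius norm. The sketch below makes the same assumptions as the penalty sketch above and is not the repository's code; see the paper's appendix for the exact form used.

```python
import torch

def jacobian_l2_penalty(decoder, z):
    """Sketch: mean squared Frobenius norm of the decoder Jacobian at latent codes z."""
    penalty = 0.0
    for zi in z:
        jac = torch.autograd.functional.jacobian(
            lambda latent: decoder(latent.unsqueeze(0)).squeeze(0), zi, create_graph=True
        )
        penalty = penalty + jac.pow(2).sum()
    return penalty / z.shape[0]
```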
To train the mpi3d-multi models in the paper,
download mpi3d_real
(12 gigabytes, so it takes a while to download) by running
cd data
./download_mpi3d_real.sh
cd ..
and then run
./experimentScripts/train_jlonevae/train_mpi3d_multi.bash
Training logs are written to the ./logs directory, and the trained model is written to a subdirectory of ./trainedModels (both as a PyTorch JIT module for use with disentanglement_lib, and also using torch.save(model.state_dict(), ...)).
To train the natural image models in the paper, first download the data from Bruno Olshausen's website by running
cd data
conda activate jlonevae
./download_natural_image_data.sh
./sampleNatualImagePatches.py
cd ..
Then, run
./experimentScripts/train_jlonevae/train_naturalImages.bash
To evaluate the models qualitatively, from the base directory start a jupyter notebook by running
conda activate jlonevae
jupyter notebook
Then, within the folder experimentScripts/visualizations, open any of the following Jupyter notebooks to view the associated Jacobian columns:
Figure2-ExampleJacobianValues_ThreeDots.ipynb
Figure3-ExampleJacobianValues-Mpi3d-multi.ipynb
For natural image data, you can create Jacobian embeddings for a sequence of nearby image crops by running
./experimentScripts/visualizations/createLatentJacobianImages_naturalImages.bash
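These visualizations show columns of the decoder Jacobian, i.e. how the decoded image changes as each latent coordinate is perturbed. Below is a minimal sketch of computing such columns for any decoder callable; the helper name and shapes are assumptions, not the notebooks' actual code.

```python
import torch

def decoder_jacobian_columns(decode, z):
    """Sketch: one image-shaped Jacobian column per latent dimension at latent code z.

    decode: callable mapping a latent batch (1, latent_dim) to images (1, C, H, W).
    z:      a single latent code of shape (latent_dim,).
    """
    jac = torch.autograd.functional.jacobian(
        lambda latent: decode(latent.unsqueeze(0)).squeeze(0), z
    )  # shape (C, H, W, latent_dim)
    return [jac[..., i] for i in range(jac.shape[-1])]
```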
To evaluate the three-dots models quantitatively, run
./experimentScripts/evaluate_jlonevae/evaluate_threeDots.bash
You can safely ignore the error ERROR:root:Path not found: local_mig_base.gin.
To evaluate the baseline models quantitatively, run
./experimentScripts/evaluate_baseline/postprocess_baseline_threeDots.bash
./experimentScripts/evaluate_baseline/evaluate_baseline.bash
You can safely ignore the error ERROR:root:Path not found: local_mig_base.gin.
To visualize the quantitative evaluations, from the base directory run
conda activate jlonevae
jupyter notebook
Then, within the folder experimentScripts/visualizations, open any of the following Jupyter notebooks to generate comparison plots:
Figure4-LinePlotMIGandModularity.ipynb
Figure5-LocalDisentanglementComparedToBaseline.ipynb
Pretrained JL1-VAE and β-VAE models for natural images, three-dots, and MPI3D-Multi are available here (external link).
Baseline models (β-VAE, FactorVAE, DIP-VAE-I, DIP-VAE-II, β-TCVAE, and AnnealedVAE) trained on three-dots are available here (external link).
Our model achieves better local disentanglement than baseline methods, both qualitatively (see the image of Jacobian columns above) and quantitatively (see the plot of local disentanglement scores below). More details can be found in our paper.
This repository is licensed under the Apache License, Version 2.0. To contribute, please create a pull request.