travers-rhodes/jlonevae

Local Disentanglement in Variational Auto-Encoders Using Jacobian L1 Regularization

This repository is the official implementation of the paper Local Disentanglement in Variational Auto-Encoders Using Jacobian L1 Regularization.

[Figure: local disentanglement Jacobians for three-dots data]
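The core idea of the method, penalizing the L1 norm of the decoder's Jacobian with respect to the latent code at sampled points, can be sketched with NumPy finite differences. This is an illustrative toy, not the repository's implementation; the decoder, step size, and latent dimension here are all made up:

```python
import numpy as np

def decoder(z):
    # Toy linear "decoder": maps a 2-D latent to a 4-D output. Illustrative only.
    W = np.array([[1.0, 0.0],
                  [0.0, 2.0],
                  [0.5, 0.0],
                  [0.0, 0.5]])
    return W @ z

def jacobian_l1(decode, z, eps=1e-5):
    """Approximate ||J||_1 where J[i, j] = d decode(z)[i] / d z[j],
    using central finite differences over each latent coordinate."""
    z = np.asarray(z, dtype=float)
    cols = []
    for j in range(z.size):
        dz = np.zeros_like(z)
        dz[j] = eps
        cols.append((decode(z + dz) - decode(z - dz)) / (2 * eps))
    J = np.stack(cols, axis=1)          # columns are per-latent derivatives
    return np.abs(J).sum()

penalty = jacobian_l1(decoder, np.zeros(2))
# For this linear decoder J = W, so ||J||_1 = 1 + 2 + 0.5 + 0.5 = 4
```

In training, a weighted version of this penalty would be added to the usual VAE objective; the paper's actual loss uses autograd rather than finite differences.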

Requirements

To install requirements for a CUDA-enabled workstation (strongly recommended):

conda env create -f environment.yml
conda init bash

To install requirements without GPU support:

conda env create -f environment_nogpu.yml
conda init bash

Training

Small example

An example script showing how to train a single JL1-VAE model on a small cache of three-dots data (20,000 images) for only 30,000 training batches of 64 images can be run with

./exampleScripts/train_jlonevae_threeDots.bash

Full experiments

Three-dots Data

To train the full three-dots models in the paper, run:

./experimentScripts/train_jlonevae/train_threeDots.bash

The first time this script is run, it takes a few minutes to create a cache of 500,000 training images in the data/ folder; training then proceeds for 300,000 batches of 64 images. Subsequent runs reuse the same image cache.
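The create-once, reuse-later caching behavior described above can be sketched as follows. The file name, array shapes, and random "generator" are hypothetical stand-ins, not the script's actual logic:

```python
import tempfile
from pathlib import Path

import numpy as np

def load_or_create_cache(cache_path, n_images=100, size=8, seed=0):
    """Generate a dataset once, save it to disk, and reuse it on later calls."""
    cache_path = Path(cache_path)
    if cache_path.exists():
        return np.load(cache_path)           # subsequent runs: reuse the cache
    rng = np.random.default_rng(seed)
    images = rng.random((n_images, size, size)).astype(np.float32)
    np.save(cache_path, images)              # first run: create the cache
    return images

cache_file = Path(tempfile.mkdtemp()) / "threeDots.npy"
first = load_or_create_cache(cache_file)     # generates and saves
second = load_or_create_cache(cache_file)    # loads the identical cached array
```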

Training logs are written to the ./logs directory and the trained model is written to a subdirectory of ./trainedModels (both as a PyTorch JIT module for use with disentanglement_lib and also using torch.save(model.state_dict(), ...)).

Three-dots Data baseline models

Baseline models, for comparison, are trained by calling

./experimentScripts/train_baseline/train_standard_tf_models.bash

Those models and logs are stored in a subdirectory of the trainedStandardModels folder; the model numbers identify which configuration from disentanglement_lib was used, now trained on the three-dots dataset.

Three-dots Data with JL2 Regularization (appendix)

Based on reviewer feedback, we also include (in the appendix) results of regularizing by the L2 norm of the Jacobian. These models can be trained by running

./experimentScripts/train_jlonevae/train_threeDots_l2.bash
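The difference between the L1 penalty and this appendix's L2 variant can be illustrated on toy Jacobians. Both matrices below are hypothetical; the point is that among Jacobians of equal L2 (Frobenius) norm, the L1 penalty is smallest for the sparse one, which is why an L1 penalty encourages sparse Jacobian columns:

```python
import numpy as np

# Two toy decoder Jacobians with the same Frobenius (L2) norm:
J_sparse = np.array([[2.0, 0.0],
                     [0.0, 0.0]])
J_dense  = np.array([[1.0, 1.0],
                     [1.0, 1.0]])

l1 = lambda J: np.abs(J).sum()           # entrywise L1 norm
l2 = lambda J: np.sqrt((J ** 2).sum())   # Frobenius norm

# Both have L2 norm 2.0, but the L1 penalty is 2.0 for the sparse
# Jacobian versus 4.0 for the dense one.
```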

MPI3D-Multi

To train the mpi3d-multi models in the paper, first download mpi3d_real (12 GB, so the download takes a while) by running

cd data
./download_mpi3d_real.sh
cd ..

and then run

./experimentScripts/train_jlonevae/train_mpi3d_multi.bash

As with the three-dots models, training logs are written to the ./logs directory and the trained model to a subdirectory of ./trainedModels.

Natural Images

To train the natural image models in the paper, first download the data from Bruno Olshausen's website by running

cd data
conda activate jlonevae
./download_natural_image_data.sh
./sampleNatualImagePatches.py
cd ..

Then, run

./experimentScripts/train_jlonevae/train_naturalImages.bash

Evaluation

Qualitative

To evaluate the models qualitatively, start a Jupyter notebook from the base directory by running

conda activate jlonevae
jupyter notebook

Then, within the folder experimentScripts/visualizations, open any of the following Jupyter notebooks to view the associated Jacobian columns:

Figure2-ExampleJacobianValues_ThreeDots.ipynb
Figure3-ExampleJacobianValues-Mpi3d-multi.ipynb

For natural image data, you can create Jacobian embeddings for a sequence of nearby image crops by running

./experimentScripts/visualizations/createLatentJacobianImages_naturalImages.bash
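The "sequence of nearby image crops" above means a window slid a few pixels at a time across one image, so that successive crops differ only slightly. A minimal sketch of that cropping step (the function name and parameters are made up for illustration; the script's real crop sizes differ):

```python
import numpy as np

def nearby_crops(image, size, start, step, n):
    """Slide a size-by-size window horizontally from `start`,
    producing n crops spaced `step` pixels apart."""
    r, c = start
    return np.stack([image[r:r + size, c + i * step:c + i * step + size]
                     for i in range(n)])

img = np.arange(100, dtype=float).reshape(10, 10)   # stand-in "natural image"
crops = nearby_crops(img, size=4, start=(2, 0), step=1, n=3)
# crops.shape == (3, 4, 4); each crop is shifted one pixel to the right
```

Each crop would then be encoded and its decoder Jacobian computed, giving one embedding per position in the sequence.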

Quantitative

To evaluate the three-dots models quantitatively, run

./experimentScripts/evaluate_jlonevae/evaluate_threeDots.bash

You can safely ignore the error ERROR:root:Path not found: local_mig_base.gin.

To evaluate the baseline models quantitatively, run

./experimentScripts/evaluate_baseline/postprocess_baseline_threeDots.bash
./experimentScripts/evaluate_baseline/evaluate_baseline.bash

You can safely ignore the error ERROR:root:Path not found: local_mig_base.gin.

To visualize the quantitative evaluations, from the base directory run

conda activate jlonevae
jupyter notebook

Then, within the folder experimentScripts/visualizations, open any of the following Jupyter notebooks to generate comparison plots:

Figure4-LinePlotMIGandModularity.ipynb
Figure5-LocalDisentanglementComparedToBaseline.ipynb  

Pre-trained Models

Pretrained JL1-VAE and β-VAE models for natural images, three-dots, and MPI3D-Multi are available here (external link). Baseline models (β-VAE, FactorVAE, DIP-VAE-I, DIP-VAE-II, β-TCVAE, and AnnealedVAE) trained on three-dots are available here (external link).

Results

Our model achieves qualitatively (see image of Jacobian columns above) and quantitatively (see plot of local disentanglement scores below) better local disentanglement compared to baseline methods. More details can be found in our paper.

[Figure: local disentanglement scores compared to baseline methods]

Contributing

This repository is licensed under the Apache License, Version 2.0. To contribute, please create a pull request.
