- Paper: Semi-Supervised Learning with Deep Generative Models
- Authors: Diederik P. Kingma, Danilo J. Rezende, Shakir Mohamed, Max Welling
- Original Implementation: github
Implements the latent-feature discriminative model (M1) and the generative semi-supervised model (M2) from the paper in TensorFlow (Python).
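As a rough illustration of what M2 optimizes, the NumPy sketch below combines per-example bounds into the paper's extended objective: the unlabelled bound U(x) = sum_y q(y|x) L(x, y) - H(q(y|x)), plus a classification term weighted by alpha. The function names, argument layout, and the `eps` constant are assumptions made for this example and are not taken from this repository's code.

```python
import numpy as np


def unlabelled_bound(labelled_bounds, class_probs, eps=1e-8):
    """Per-example U(x) = sum_y q(y|x) * L(x, y) - H(q(y|x)).

    labelled_bounds: array (batch, classes), L(x, y) for every candidate label y.
    class_probs:     array (batch, classes), q(y|x), each row summing to 1.
    eps:             small constant to avoid log(0) (an assumption of this sketch).
    """
    entropy = -np.sum(class_probs * np.log(class_probs + eps), axis=1)
    expected_bound = np.sum(class_probs * labelled_bounds, axis=1)
    return expected_bound - entropy


def m2_objective(labelled_bounds, unlabelled_bounds, log_q_y_labelled, alpha):
    """J^alpha = sum L(x, y) + sum U(x) + alpha * mean(-log q(y|x)).

    labelled_bounds:   array (n_labelled,), L(x, y) at the true labels.
    unlabelled_bounds: array (n_unlabelled,), U(x) from unlabelled_bound().
    log_q_y_labelled:  array (n_labelled,), log q(y|x) at the true labels.
    """
    generative_term = np.sum(labelled_bounds) + np.sum(unlabelled_bounds)
    classification_term = np.mean(-log_q_y_labelled)
    return generative_term + alpha * classification_term
```

The objective is minimized during training. How alpha is scaled relative to the batch or dataset size is a design choice, so the scaling here should be treated as a placeholder.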
- TensorFlow >= 0.8.0 (required by prettytensor; an older TensorFlow might work with an older version of prettytensor, but this is untested)
- prettytensor
- numpy
- matplotlib and seaborn (optional, for plotting VAE images)
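A quick way to check that the dependencies above are importable before training is sketched below. It assumes Python 3.4 or newer (for `importlib.util.find_spec`) and that each package's import name matches the name listed above; the helper name `missing_modules` is made up for this example. It does not check version numbers.

```python
import importlib.util

REQUIRED = ["tensorflow", "prettytensor", "numpy"]
OPTIONAL = ["matplotlib", "seaborn"]


def missing_modules(names):
    """Return the entries of `names` that cannot be located for import."""
    return [name for name in names if importlib.util.find_spec(name) is None]


if __name__ == "__main__":
    print("Missing required packages:", missing_modules(REQUIRED))
    print("Missing optional packages:", missing_modules(OPTIONAL))
```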
- To train the latent-feature model (M1), run train_vae.py. Parameters are set in the same file.
- To train the M1+M2 classifier, run train_classifier.py. Parameters are set in the same file, and the location of the saved M1 (VAE) model must be specified there.
- With the provided VAE model and the given parameters, the classifier should reach an accuracy of about 95.4% on the test set using 100 labelled examples.
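For readers who want to build their own small labelled set, the sketch below draws a class-balanced subset (for example, 100 labelled examples across 10 classes) and treats the rest as unlabelled. This is only an illustration: the function name, the integer label encoding, and the seed are assumptions, and the repository's own scripts may split the data differently.

```python
import numpy as np


def split_labelled(labels, n_labelled, n_classes, seed=0):
    """Pick n_labelled // n_classes indices per class; the rest are unlabelled.

    labels: 1-D integer array with class ids in [0, n_classes).
    Returns (labelled_idx, unlabelled_idx) as index arrays into `labels`.
    """
    rng = np.random.RandomState(seed)
    per_class = n_labelled // n_classes
    chosen = []
    for c in range(n_classes):
        class_idx = np.flatnonzero(labels == c)
        # Sample without replacement so no example is picked twice.
        chosen.append(rng.choice(class_idx, size=per_class, replace=False))
    labelled_idx = np.concatenate(chosen)
    unlabelled_idx = np.setdiff1d(np.arange(len(labels)), labelled_idx)
    return labelled_idx, unlabelled_idx


# Toy usage: 500 examples, 10 classes, 100 of them kept as labelled.
toy_labels = np.repeat(np.arange(10), 50)
labelled_idx, unlabelled_idx = split_labelled(toy_labels, n_labelled=100, n_classes=10)
```

Balancing the classes keeps every class represented when very few labels are available. With so few labelled points, a different seed can noticeably change the resulting accuracy.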