This repository is the official implementation of "Hyperbolic VAE via Latent Gaussian Distributions" accepted at NeurIPS 2023.
We propose a Gaussian manifold variational auto-encoder (GM-VAE) whose latent space consists of a set of Gaussian distributions. It is known that the set of univariate Gaussian distributions, equipped with the Fisher information metric, forms a hyperbolic space, which we call a Gaussian manifold. To train a VAE whose latent space is a Gaussian manifold, we propose a pseudo-Gaussian manifold normal distribution that defines a density over the latent space using the Kullback-Leibler divergence, a local approximation of the squared Fisher-Rao distance. In experiments, we demonstrate the efficacy of GM-VAE on two different tasks: density estimation on image datasets and environment modeling in model-based reinforcement learning. GM-VAE outperforms other hyperbolic and Euclidean VAE variants on density estimation tasks and shows competitive performance in model-based reinforcement learning. We also observe that our model exhibits strong numerical stability, addressing a common limitation reported for previous hyperbolic VAEs.
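The relationship between the Kullback-Leibler divergence and the Fisher-Rao distance on univariate Gaussians can be checked numerically. The sketch below is illustrative only and is not part of this repository. The helper names `fisher_rao_distance` and `kl_gaussian` are hypothetical, and distributions are parameterized by mean and standard deviation. The sketch uses the standard closed form of the Fisher-Rao distance for this family. Note that, to second order, KL ≈ d²/2, so the sketch compares 2·KL with d².

```python
import math


def fisher_rao_distance(mu1: float, sigma1: float, mu2: float, sigma2: float) -> float:
    """Fisher-Rao distance between N(mu1, sigma1^2) and N(mu2, sigma2^2).

    The Fisher metric of the univariate Gaussian family is
    ds^2 = (dmu^2 + 2 dsigma^2) / sigma^2. Substituting u = mu / sqrt(2) gives
    twice the Poincare half-plane metric in (u, sigma), so the distance is
    sqrt(2) times the half-plane (hyperbolic) distance.
    """
    num = (mu1 - mu2) ** 2 / 2.0 + (sigma1 - sigma2) ** 2
    return math.sqrt(2.0) * math.acosh(1.0 + num / (2.0 * sigma1 * sigma2))


def kl_gaussian(mu1: float, sigma1: float, mu2: float, sigma2: float) -> float:
    """KL(N(mu1, sigma1^2) || N(mu2, sigma2^2)) in closed form."""
    return (
        math.log(sigma2 / sigma1)
        + (sigma1 ** 2 + (mu1 - mu2) ** 2) / (2.0 * sigma2 ** 2)
        - 0.5
    )


# Compare 2 * KL with the squared Fisher-Rao distance: the ratio is close to 1
# for nearby distributions and drifts away from 1 as they move apart.
for mu2, sigma2 in [(0.01, 1.01), (0.5, 1.5), (3.0, 2.0)]:
    d = fisher_rao_distance(0.0, 1.0, mu2, sigma2)
    kl = kl_gaussian(0.0, 1.0, mu2, sigma2)
    print(
        f"N(0,1) vs N({mu2},{sigma2}^2): d^2 = {d ** 2:.6f}, "
        f"2*KL = {2 * kl:.6f}, ratio = {2 * kl / d ** 2:.3f}"
    )
```

The approximation is local. KL is asymmetric and only matches half the squared geodesic distance for nearby distributions.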
- Install PyTorch and torchvision. The recommended versions are PyTorch 2.0.1 and torchvision 0.15.2.
- Install the remaining dependencies by running:
  > pip install -r requirements.txt
- Install geoopt by running the following command:
  > pip install git+https://github.com/geoopt/geoopt.git
- Prepare the datasets by running the script:
  > sh scripts/download.sh
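After installation, you may want to confirm that your environment matches the recommended versions. The sketch below is a hypothetical helper that is not shipped with this repository. It uses only the standard library's `importlib.metadata` and assumes the distribution names `torch`, `torchvision`, and `geoopt`. Because geoopt is installed from GitHub, the sketch does not pin a version for it.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions recommended in this README; geoopt is installed from GitHub,
# so no specific version is pinned for it here.
RECOMMENDED = {"torch": "2.0.1", "torchvision": "0.15.2", "geoopt": None}


def check_environment(recommended=RECOMMENDED):
    """Return {package: (installed_version_or_None, recommended_version_or_None)}."""
    report = {}
    for pkg, rec in recommended.items():
        try:
            installed = version(pkg)
        except PackageNotFoundError:
            installed = None  # package is not installed in this environment
        report[pkg] = (installed, rec)
    return report


for pkg, (installed, rec) in check_environment().items():
    if installed is None:
        status = "MISSING"
    # startswith() tolerates local build suffixes such as "2.0.1+cu117"
    elif rec is not None and not installed.startswith(rec):
        status = f"differs from recommended {rec}"
    else:
        status = "ok"
    print(f"{pkg}: {installed} ({status})")
```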
You can reproduce the experiments from our paper using the following commands:
> python train_vae.py --dist=PGMNormal --exp_name=reproduce --seed 1 --c -1.0 --latent_dim=4 --task=Breakout
> python train_vae.py --dist=PGMNormal --exp_name=reproduce --seed 1 --c -1.0 --latent_dim=35 --task=CUB
> python train_vae.py --dist=PGMNormal --exp_name=reproduce --seed 1 --c -1.0 --latent_dim=35 --task=Food101
> python train_vae.py --dist=PGMNormal --exp_name=reproduce --seed 1 --c -1.0 --latent_dim=35 --task=Oxford102
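The four commands differ only in `--task` and `--latent_dim`, so they can be generated in a loop. The sketch below is a hypothetical convenience wrapper that is not part of this repository. It assumes it is run from the repository root, where `train_vae.py` lives, and it reuses only the flags shown above. By default it performs a dry run and only prints the commands.

```python
import shlex
import subprocess
import sys

# Latent dimension per task, copied from the commands above.
LATENT_DIMS = {"Breakout": 4, "CUB": 35, "Food101": 35, "Oxford102": 35}


def build_command(task, seed=1, exp_name="reproduce", python=sys.executable):
    """Build the argument list for one train_vae.py run."""
    return [
        python, "train_vae.py", "--dist=PGMNormal", f"--exp_name={exp_name}",
        "--seed", str(seed), "--c", "-1.0",
        f"--latent_dim={LATENT_DIMS[task]}", f"--task={task}",
    ]


def run_all(seed=1, dry_run=True):
    """Print every command; execute them sequentially only if dry_run is False."""
    commands = [build_command(task, seed=seed) for task in LATENT_DIMS]
    for cmd in commands:
        print(shlex.join(cmd))
        if not dry_run:
            subprocess.run(cmd, check=True)  # stop on the first failing run
    return commands
```

Calling `run_all(dry_run=False)` would launch the runs one after another. Using `sys.executable` keeps the runs on the same interpreter, and therefore the same installed PyTorch, as the calling script.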
You can also reproduce the full results table from the paper by running the Weights & Biases (wandb) sweeps defined in scripts/reproduce_breakout.yaml and scripts/reproduce_rgb.yaml.
Please cite our paper if you use the model or this code in your own work:
@inproceedings{anonymous2023hyperbolic,
  title={Hyperbolic VAE via Latent Gaussian Distributions},
  author={Seunghyuk Cho and Juyong Lee and Dongwoo Kim},
  booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
  year={2023},
  url={https://openreview.net/forum?id=FNn4zibGvw}
}