Created by Yongheng Zhao, Tolga Birdal, Haowen Deng, Federico Tombari from TUM.
See this link for the original README documentation.
Custom functions:
- Generate a capsule dataset for transfer learning.
- Train a beta-VAE with capsules.
- Decode and visualize capsules using the default CapsNet checkpoint.
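The beta-VAE training listed above is not described further in this README. For reference, the standard beta-VAE objective (Higgins et al.) is a reconstruction term plus a KL term weighted by beta. The sketch below assumes a diagonal Gaussian posterior and a standard normal prior. The function names are hypothetical, and this is not necessarily how this fork computes its loss.

```python
import math
from typing import Sequence


def gaussian_kl(mu: Sequence[float], logvar: Sequence[float]) -> float:
    """KL(N(mu, exp(logvar)) || N(0, I)) for a diagonal Gaussian, summed over dimensions."""
    return -0.5 * sum(1.0 + lv - m * m - math.exp(lv) for m, lv in zip(mu, logvar))


def beta_vae_loss(recon_loss: float, mu: Sequence[float], logvar: Sequence[float],
                  beta: float) -> float:
    """Reconstruction term plus the beta-weighted KL term (hypothetical helper)."""
    return recon_loss + beta * gaussian_kl(mu, logvar)


if __name__ == "__main__":
    # Posterior equal to the prior: the KL term vanishes, leaving only the reconstruction term.
    print(beta_vae_loss(0.5, mu=[0.0, 0.0], logvar=[0.0, 0.0], beta=4.0))
```

Setting beta to 1 recovers the ordinary VAE loss. Values above 1 push the latent code toward the prior, trading reconstruction quality for more disentangled factors.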
Because the default Chamfer Distance (CD) package is extremely buggy, we switched to a new CD package provided by chrdiller: https://github.com/chrdiller/pyTorchChamferDistance
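The Chamfer distance (CD) compares two point clouds by matching every point to its nearest neighbour in the other cloud, in both directions. The sketch below is brute-force, dependency-free, and O(N·M), so it is only useful for checking small examples. It assumes one common convention: mean squared nearest-neighbour distance in each direction, summed. The GPU package linked above may differ in normalization, and the function names here are hypothetical.

```python
from typing import List, Sequence, Tuple

Point = Tuple[float, float, float]


def _sq_dist(p: Point, q: Point) -> float:
    # Squared Euclidean distance between two 3D points.
    return sum((a - b) ** 2 for a, b in zip(p, q))


def nearest_sq_dists(src: Sequence[Point], dst: Sequence[Point]) -> List[float]:
    # For each point in src, squared distance to its nearest neighbour in dst (brute force).
    return [min(_sq_dist(p, q) for q in dst) for p in src]


def chamfer_distance(a: Sequence[Point], b: Sequence[Point]) -> float:
    # Symmetric CD: mean nearest-neighbour squared distance a->b plus b->a.
    dist1 = nearest_sq_dists(a, b)
    dist2 = nearest_sq_dists(b, a)
    return sum(dist1) / len(dist1) + sum(dist2) / len(dist2)


if __name__ == "__main__":
    a = [(0.0, 0.0, 0.0), (1.0, 0.0, 0.0)]
    b = [(0.0, 0.0, 0.0)]
    print(chamfer_distance(a, b))  # 0.5
```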
The code is based on PyTorch. It has been tested with Python 3.8, PyTorch 1.6.0, and CUDA 11.0 (or higher) on Ubuntu 20.04.
Install h5py for Python:
sudo apt-get install libhdf5-dev
sudo pip install h5py
To visualize the training process in PyTorch, consider installing TensorBoard.
If you have a GUI available and want to visualize the reconstructed point clouds, consider installing Open3D:
pip3 install open3d
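Before training, you can quickly confirm that the packages above can be found. This minimal sketch uses only the standard library. The module list is an assumption: it uses the import names for PyTorch, h5py, Open3D, and TensorBoard, and the last two are optional. It only checks that each package can be located. It does not check that versions match the tested ones.

```python
import importlib.util
import sys
from typing import Dict, Iterable

# Assumed import names for the packages mentioned in this README (open3d and tensorboard are optional).
DEFAULT_MODULES = ("torch", "h5py", "open3d", "tensorboard")


def check_modules(names: Iterable[str] = DEFAULT_MODULES) -> Dict[str, bool]:
    # Report whether each top-level module can be found, without importing it.
    return {name: importlib.util.find_spec(name) is not None for name in names}


if __name__ == "__main__":
    print("Python", sys.version.split()[0])
    for name, ok in check_modules().items():
        print(f"{name:12s} {'found' if ok else 'MISSING'}")
```

Once PyTorch is found, `torch.cuda.is_available()` tells you whether the installed build can use a GPU.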
ShapeNet Part with 16 categories:
cd dataset
bash download_shapenet_part16_catagories.sh
ShapeNet Core with 13 categories (from AtlasNet):
cd dataset
bash download_shapenet_core13_catagories.sh
ShapeNet Core with 55 categories (from FoldingNet):
cd dataset
bash download_shapenet_core55_catagories.sh
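To fetch datasets from a script instead of typing the commands, you could use a hypothetical helper like the one below. It maps each `--dataset` value used for training (shapenet_part, shapenet_core13, shapenet_core55) to the matching download script above and runs it from the dataset folder. The mapping is inferred from the file names and is not documented by the authors. The script names keep their original spelling ("catagories").

```python
import subprocess
from pathlib import Path
from typing import List

# Inferred mapping from --dataset values to the download scripts in dataset/ (names as shipped).
DOWNLOAD_SCRIPTS = {
    "shapenet_part": "download_shapenet_part16_catagories.sh",
    "shapenet_core13": "download_shapenet_core13_catagories.sh",
    "shapenet_core55": "download_shapenet_core55_catagories.sh",
}


def download_command(dataset: str) -> List[str]:
    # Build the argv for the download script; raises KeyError for an unknown dataset name.
    return ["bash", DOWNLOAD_SCRIPTS[dataset]]


def download(dataset: str, dataset_dir: Path = Path("dataset")) -> None:
    # Equivalent to `cd dataset && bash <script>`; check=True raises if the script fails.
    subprocess.run(download_command(dataset), cwd=dataset_dir, check=True)
```

For example, calling `download("shapenet_core13")` from the repository root runs the 13-category script.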
You can download the pre-trained models here.
We provide an example demonstrating the basic usage in the folder 'mini_example'.
To visualize the reconstruction from latent capsules with our pre-trained model:
cd mini_example/AE
python viz_reconstruction.py --model ../../checkpoints/shapenet_part_dataset_ae_200.pth
To train a point capsule autoencoder on the ShapeNet Part dataset yourself:
cd mini_example/AE
python train_ae.py
To train a point capsule autoencoder on another dataset:
cd apps/AE
python train_ae.py --dataset <shapenet_part | shapenet_core13 | shapenet_core55>
To monitor the training process, use TensorBoard by specifying the log directory:
tensorboard --logdir log
To test the reconstruction accuracy:
python test_ae.py --dataset <dataset> --model <checkpoint>
For example:
python test_ae.py --dataset shapenet_core13 --model ../../checkpoints/shapenet_core13_dataset_ae_230.pth
To visualize the reconstructed points:
python viz_reconstruction.py --dataset <dataset> --model <checkpoint>
For example:
python viz_reconstruction.py --dataset shapenet_core13 --model ../../checkpoints/shapenet_core13_dataset_ae_230.pth
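The example checkpoints follow the pattern `<dataset>_dataset_ae_<epoch>.pth`. If that pattern holds for other checkpoints (an inference from the file names, not a documented convention), a small hypothetical helper can derive `--dataset` from the checkpoint. It then builds the `test_ae.py` or `viz_reconstruction.py` command, which is handy when evaluating several checkpoints in a loop.

```python
import re
from pathlib import Path
from typing import List, Tuple

# Inferred from the example checkpoints, e.g. shapenet_core13_dataset_ae_230.pth.
CKPT_PATTERN = re.compile(r"^(?P<dataset>.+)_dataset_ae_(?P<epoch>\d+)\.pth$")


def parse_checkpoint(path: str) -> Tuple[str, int]:
    # Return (dataset, epoch) parsed from the checkpoint file name.
    match = CKPT_PATTERN.match(Path(path).name)
    if match is None:
        raise ValueError(f"unrecognised checkpoint name: {path}")
    return match["dataset"], int(match["epoch"])


def build_command(script: str, checkpoint: str) -> List[str]:
    # Build e.g. `python test_ae.py --dataset <dataset> --model <checkpoint>`.
    dataset, _ = parse_checkpoint(checkpoint)
    return ["python", script, "--dataset", dataset, "--model", checkpoint]


if __name__ == "__main__":
    ckpt = "../../checkpoints/shapenet_core13_dataset_ae_230.pth"
    print(" ".join(build_command("test_ae.py", ckpt)))
```

To execute a command, pass the list to `subprocess.run(cmd, cwd="apps/AE", check=True)`. The working directory assumes the scripts live in apps/AE, as in the training step.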