PyTorch implementation of the "Deep Transformation-Invariant Clustering" paper (accepted at NeurIPS 2020 as an oral)
Check out our paper and webpage for details!
If you find this code useful, don't forget to star the repo ⭐ and cite the paper:
@inproceedings{monnier2020dticlustering,
title={{Deep Transformation-Invariant Clustering}},
author={Monnier, Tom and Groueix, Thibault and Aubry, Mathieu},
booktitle={NeurIPS},
year={2020},
}
conda env create -f environment.yml
conda activate dti-clustering
Optional: some monitoring routines are implemented; to use them, specify the visdom port in the config file. You will need to install visdom from source beforehand:
git clone https://github.com/facebookresearch/visdom
cd visdom && pip install -e .
The following script will download the affNIST-test and FRGC datasets, as well as our unfiltered Instagram collections associated with #santaphoto and #weddingkiss:
./download_data.sh
NB: gdown may occasionally hang. If so, you can download the datasets by hand with the following gdrive links, then unzip and move them to the datasets folder:
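For the manual route, the unzip-and-move step could be sketched as below. This is only an illustration: the function name `extract_to_datasets` and the archive name in the usage line are hypothetical, and it assumes each downloaded archive already contains the folder layout expected under `datasets`.

```python
import zipfile
from pathlib import Path


def extract_to_datasets(archive, datasets_dir="datasets"):
    """Extract a manually downloaded zip archive into the datasets folder.

    Assumption: the archive already holds the expected top-level folder,
    so extracting it in place is enough.
    """
    target = Path(datasets_dir)
    target.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(target)
    # Return what now sits in the datasets folder, for a quick visual check.
    return sorted(p.name for p in target.iterdir())


if __name__ == "__main__":
    import sys

    # Usage (archive name is a placeholder): python extract.py some_dataset.zip
    for archive_path in sys.argv[1:]:
        print(extract_to_datasets(archive_path))
```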
cuda=gpu_id config=filename.yml tag=run_tag ./pipeline.sh
where:
- gpu_id is a target cuda device id,
- filename.yml is a YAML config located in the configs folder,
- run_tag is a tag for the experiment.
Results are saved at runs/${DATASET}/${DATE}_${run_tag}, where DATASET is the dataset name specified in filename.yml and DATE is the current date in mmdd format. Some training visual results, such as prototype evolutions and transformation prediction examples, will be saved. Here is an example of learned MNIST prototypes and transformation predictions for a given query image:
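To make the output directory convention concrete, here is a minimal sketch that rebuilds the `runs/${DATASET}/${DATE}_${run_tag}` path in Python. The helper `run_dir` is hypothetical and not part of the repo, and the dataset name `mnist` in the usage line is an assumed example value.

```python
from datetime import date
from pathlib import Path
from typing import Optional


def run_dir(dataset: str, run_tag: str, today: Optional[date] = None,
            root: str = "runs") -> Path:
    """Build runs/<DATASET>/<mmdd>_<run_tag> following the naming convention."""
    today = today or date.today()
    # %m%d gives the zero-padded month and day, i.e. the mmdd format.
    return Path(root) / dataset / f"{today:%m%d}_{run_tag}"


if __name__ == "__main__":
    # "mnist" is an assumed dataset name; the real one comes from the config.
    print(run_dir("mnist", "dtikmeans"))
```

This can be handy for locating the latest run from a notebook or a post-processing script.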
cuda=gpu_id config=mnist_test.yml tag=dtikmeans ./multi_pipeline.sh
Switch the model name to dtigmm in the config file to reproduce results for DTI GMM.
Available configs are:
- affnist_test.yml
- fashion_mnist.yml
- frgc.yml
- mnist.yml
- mnist_1k.yml
- mnist_color.yml
- mnist_test.yml
- svhn.yml
- usps.yml
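If you want to go through several of these configs in a row, a small helper can generate the corresponding command lines. The sketch below only prints commands and does not launch anything; `pipeline_command` is a hypothetical helper, and deriving the tag from the config name is an assumption made for illustration.

```python
import shlex

# Config names as listed above.
CONFIGS = [
    "affnist_test.yml", "fashion_mnist.yml", "frgc.yml", "mnist.yml",
    "mnist_1k.yml", "mnist_color.yml", "mnist_test.yml", "svhn.yml", "usps.yml",
]


def pipeline_command(config, tag, gpu_id=0, script="./multi_pipeline.sh"):
    """Format the shell command line for one config."""
    return (f"cuda={gpu_id} config={shlex.quote(config)} "
            f"tag={shlex.quote(tag)} {script}")


if __name__ == "__main__":
    for cfg in CONFIGS:
        # Assumption: reuse the config stem as the run tag.
        print(pipeline_command(cfg, tag=cfg.rsplit(".", 1)[0]))
```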
- (Skip if you already downloaded the data using the script above.) Create a santaphoto dataset by running the process_insta_santa.sh script. It can take a while to scrape the 10k posts from Instagram.
- Launch training with
cuda=gpu_id config=instagram.yml tag=santaphoto ./pipeline.sh
That's it! You can apply the process to other IG hashtags like #trevifountain or #weddingkiss and discover prototypes similar to:
- Download the desired landmarks from the original MegaDepth project webpage, e.g. Florence Cathedral
- Move the images to a datasets/megadepth/firenze/train folder
- Launch training with
cuda=gpu_id config=megadepth.yml tag=firenze ./pipeline.sh
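Before launching training, it can be worth checking that the images actually landed in the expected folder. The following is a minimal sketch, not part of the repo: `count_images` is a hypothetical helper and the set of accepted extensions is an assumption.

```python
from pathlib import Path

# Assumed set of image extensions; adjust to the files you downloaded.
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png"}


def count_images(folder):
    """Count image files directly inside `folder` (0 if it does not exist)."""
    folder = Path(folder)
    if not folder.is_dir():
        return 0
    return sum(
        1 for p in folder.iterdir()
        if p.is_file() and p.suffix.lower() in IMAGE_EXTENSIONS
    )


if __name__ == "__main__":
    n = count_images("datasets/megadepth/firenze/train")
    print(f"{n} images found in datasets/megadepth/firenze/train")
```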
You should end up with 20 learned prototypes and random sample examples in each cluster. To assess the quality of the clustering, you can visualize, for each cluster, the prototype, random samples and transformed prototypes like:
If you like this project, please check out related works from our group: