This repo contains code for our paper: Generalized Category Discovery
Given a dataset, some of which is labelled, Generalized Category Discovery is the task of assigning a category to all the unlabelled instances. Unlabelled instances could come from labelled or 'New' classes.
Updates to paper since pre-print (updated PDF available here, ArXiv updating soon)
- We introduced a more rigorous evaluation metric: when computing ACC, we run the Hungarian algorithm only once, across all unlabelled data.
- This single set of linear assignments is then used to compute ACC on 'Old' and 'New' class subsets (see Appendix E)
- Practically, this involves switching from 'v1' to 'v2' evaluation in ./project_utils/cluster_and_log_utils.py
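The 'v2' protocol described above can be illustrated with a minimal sketch. This is not the code in ./project_utils/cluster_and_log_utils.py: the function names (`best_assignment`, `split_acc_v2`) and the toy labels are hypothetical, and a brute-force search over permutations stands in for the Hungarian algorithm, which is only feasible for a handful of classes. In practice a solver such as `scipy.optimize.linear_sum_assignment` would likely take its place.

```python
from itertools import permutations


def best_assignment(y_true, y_pred, n_classes):
    """Return {cluster_id: class_id} maximising the number of matched instances."""
    # w[p][t] counts instances placed in cluster p whose true class is t.
    w = [[0] * n_classes for _ in range(n_classes)]
    for t, p in zip(y_true, y_pred):
        w[p][t] += 1
    # Brute-force stand-in for the Hungarian algorithm (assumption: tiny n_classes).
    best = max(
        permutations(range(n_classes)),
        key=lambda perm: sum(w[p][perm[p]] for p in range(n_classes)),
    )
    return dict(enumerate(best))


def split_acc_v2(y_true, y_pred, old_mask, n_classes):
    """Solve the assignment once on all data, then score 'All', 'Old' and 'New'."""
    mapping = best_assignment(y_true, y_pred, n_classes)
    correct = [mapping[p] == t for t, p in zip(y_true, y_pred)]

    def acc(flags):
        return sum(flags) / len(flags)

    # The same assignment is reused for both subsets; nothing is re-solved.
    old = [c for c, m in zip(correct, old_mask) if m]
    new = [c for c, m in zip(correct, old_mask) if not m]
    return acc(correct), acc(old), acc(new)


# Toy example: classes 0 and 1 are 'Old', class 2 is 'New'.
y_true = [0, 0, 1, 1, 2, 2]
y_pred = [1, 1, 0, 2, 2, 2]
old_mask = [t in (0, 1) for t in y_true]
all_acc, old_acc, new_acc = split_acc_v2(y_true, y_pred, old_mask, n_classes=3)
print(all_acc, old_acc, new_acc)
```

The point of the design is that `best_assignment` is called exactly once; solving a separate assignment per subset would let each subset pick its own most favourable mapping.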
pip install -r requirements.txt
Set paths to datasets, pre-trained models and desired log directories in config.py
Set SAVE_DIR (logfile destination) and PYTHON (path to python interpreter) in bash_scripts scripts.
We use fine-grained benchmarks in this paper, including:
We also use generic object recognition datasets, including:
- CIFAR-10/100 and ImageNet
Train representation:
bash bash_scripts/contrastive_train.sh
Extract features: Extract features to prepare for semi-supervised k-means.
This requires setting warmup_model_dir to the path of the model used for feature extraction.
bash bash_scripts/extract_features.sh
Fit semi-supervised k-means:
bash bash_scripts/k_means.sh
Under the old evaluation metric ('v1') we found that semi-supervised k-means consistently boosted performance over standard k-means, on 'Old' and 'New' data subsets. When we changed to 'v2' evaluation, we re-evaluated models in Tables {2,3,5} (including the ablation) and updated the figures.
However, we have recently found that SS-k-means can be sensitive to bad initialisation under 'v2' and can lower performance on some datasets. Increasing the number of inits for SS-k-means can help. We are investigating this further; suggestions and PRs are welcome!
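To make the effect of the number of inits concrete, here is a small self-contained sketch of a semi-supervised k-means with restarts. It is a simplification and not the implementation in this repo: all names (`ss_kmeans_once`, `ss_kmeans`, `n_init`) and the toy data are hypothetical, new centroids are seeded from random unlabelled points, and the run with the lowest within-cluster squared distance is kept.

```python
import random


def sq_dist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def mean(points):
    n = len(points)
    return tuple(sum(coord) / n for coord in zip(*points))


def nearest(p, centroids, k):
    return min(range(k), key=lambda c: sq_dist(p, centroids[c]))


def ss_kmeans_once(labelled, unlabelled, k, rng, n_iter=20):
    """One run; `labelled` is a list of (point, class_id) with ids in [0, k)."""
    by_class = {}
    for p, c in labelled:
        by_class.setdefault(c, []).append(p)
    # Labelled classes start at their class mean; the rest at random unlabelled points.
    free = [c for c in range(k) if c not in by_class]
    centroids = {c: mean(ps) for c, ps in by_class.items()}
    centroids.update(zip(free, rng.sample(unlabelled, len(free))))
    for _ in range(n_iter):
        # Labelled points always stay in their own class cluster.
        members = {c: list(ps) for c, ps in by_class.items()}
        for p in unlabelled:
            members.setdefault(nearest(p, centroids, k), []).append(p)
        # Empty clusters keep their previous centroid.
        centroids = {c: mean(members[c]) if c in members else centroids[c]
                     for c in range(k)}
    assign = [nearest(p, centroids, k) for p in unlabelled]
    inertia = sum(sq_dist(p, centroids[c]) for p, c in zip(unlabelled, assign))
    inertia += sum(sq_dist(p, centroids[c]) for p, c in labelled)
    return inertia, assign


def ss_kmeans(labelled, unlabelled, k, n_init=10, seed=0):
    """Repeat with different random seeds for the new centroids; keep the best run."""
    rng = random.Random(seed)
    runs = (ss_kmeans_once(labelled, unlabelled, k, rng) for _ in range(n_init))
    return min(runs, key=lambda run: run[0])


# Toy 2-D data: one labelled class near the origin, two unseen groups.
labelled = [((0.0, 0.0), 0), ((0.1, -0.1), 0)]
unlabelled = [(0.2, 0.1), (-0.1, 0.3), (5.0, 5.1), (5.2, 4.9),
              (4.8, 5.0), (10.0, 0.1), (9.9, -0.2), (10.1, 0.0)]
inertia1, _ = ss_kmeans(labelled, unlabelled, k=3, n_init=1, seed=0)
inertia10, assign10 = ss_kmeans(labelled, unlabelled, k=3, n_init=10, seed=0)
print(inertia1, inertia10)
```

With a shared seed, the first restart of the `n_init=10` call is identical to the single `n_init=1` run, so extra restarts can only match or lower the objective in this sketch. Whether a lower objective translates into better 'v2' accuracy on real data is a separate question.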
Results from re-running models with this repo compared to reported numbers:
Dataset | All | Old | New |
---|---|---|---|
Stanford Cars (paper) | 39.0 | 57.6 | 29.9 |
Stanford Cars (repo) | 39.9 | 58.5 | 30.9 |
CIFAR100 (paper) | 70.8 | 77.6 | 57.0 |
CIFAR100 (repo) | 71.3 | 77.4 | 59.1 |
If you use this code in your research, please consider citing our paper:
@InProceedings{vaze2022gcd,
title={Generalized Category Discovery},
author={Sagar Vaze and Kai Han and Andrea Vedaldi and Andrew Zisserman},
booktitle={IEEE Conference on Computer Vision and Pattern Recognition},
year={2022}}