hummarow/URL (forked from VICO-UoE/URL)

Universal Representation Learning from Multiple Domains for Few-shot Classification - ICCV 2021, Cross-domain Few-shot Learning with Task-specific Adapters - CVPR 2022

Universal Representation Learning and Task-specific Adaptation for Few-shot Learning

A universal representation learning algorithm, which learns a set of well-generalized representations with a single universal network from multiple diverse visual datasets, together with task-specific adaptation techniques for few-shot learning.

Universal Representation Learning from Multiple Domains for Few-shot Classification,
Wei-Hong Li, Xialei Liu, Hakan Bilen,
ICCV 2021 (arXiv 2103.13841)

Cross-domain Few-shot Learning with Task-specific Adapters,
Wei-Hong Li, Xialei Liu, Hakan Bilen,
CVPR 2022 (arXiv 2107.00358)

Universal Representations: A Unified Look at Multiple Task and Domain Learning,
Wei-Hong Li, Xialei Liu, Hakan Bilen,
Preprint 2022 (arXiv 2204.02744)

Updates

Features at a glance

  • We train a single universal (task-agnostic) network on the 8 training datasets of Meta-Dataset (ImageNet, Omniglot, Aircraft, Birds, Textures, Quick Draw, Fungi and VGG Flower), achieving state-of-the-art performance on all 13 test datasets for few-shot learning.

  • During meta-testing, the universal representations can be efficiently adapted by our proposed pre-classifier alignment (PA), a linear transformation learned on the support set that maps the representations to a more discriminative space.

  • We propose attaching a set of lightweight task-specific adapters to the universal network (which can be learned from multiple datasets or from a single diverse dataset, e.g. ImageNet) and learning these adapters from scratch on the support set, in order to adapt the few-shot model to tasks from unseen domains.

  • We systematically study various combinations of design choices for task-specific adaptation that have not been explored before, including adapter connection types (serial or residual), parameterizations (a matrix and its decomposed variants, or channel-wise operations) and the estimation of task-specific parameters.

  • Attaching parametric adapters in matrix form to convolutional layers with residual connections significantly improves on the state of the art in most domains, with particularly strong performance on the unseen domains of Meta-Dataset.

  • In this repo, we provide code for URL, the best task-adaptation strategy from TSA, other baselines such as SDL and vanilla MDL, and other evaluation settings.

Main results on Meta-dataset

  • Multi-domain setting (meta-train on 8 datasets and meta-test on 13 datasets).

Test Datasets | TSA (Ours) | URL (Ours) | MDL | Best SDL | tri-M [8] | FLUTE [7] | URT [6] | SUR [5] | Transductive CNAPS [4] | Simple CNAPS [3] | CNAPS [2]
Avg rank | 1.5 | 2.7 | 7.1 | 6.7 | 5.5 | 5.1 | 6.7 | 6.9 | 5.7 | 7.2 | -
Avg Seen | 80.2 | 80.0 | 76.9 | 76.3 | 74.5 | 76.2 | 76.7 | 75.2 | 75.1 | 74.6 | 71.6
Avg Unseen | 77.2 | 69.3 | 61.7 | 61.9 | 69.9 | 69.9 | 62.4 | 63.1 | 66.5 | 65.8 | -
Avg All | 79.0 | 75.9 | 71.1 | 70.8 | 72.7 | 73.8 | 71.2 | 70.5 | 71.8 | 71.2 | -

  • Single-domain setting (meta-train on ImageNet and meta-test on 13 datasets).

Test Datasets | TSA-ResNet34 (Ours) | TSA-ResNet18 (Ours) | CTX-ResNet34 [10] | ProtoNet-ResNet34 [10] | FLUTE [7] | BOHB [9] | ALFA+fo-Proto-MAML [1] | fo-Proto-MAML [1] | ProtoNet [1] | Finetune [1]
Avg rank | 1.5 | 2.8 | 1.8 | 5.5 | 8.9 | 6.0 | 5.3 | 7.0 | 8.3 | 7.9
Avg Seen | 63.7 | 59.5 | 62.8 | 53.7 | 46.9 | 51.9 | 52.8 | 49.5 | 50.5 | 45.8
Avg Unseen | 76.2 | 71.9 | 75.6 | 61.1 | 53.2 | 60.0 | 62.4 | 58.4 | 56.7 | 58.2
Avg All | 74.9 | 70.7 | 74.3 | 60.4 | 52.6 | 59.2 | 61.4 | 57.5 | 56.1 | 57.0

Model Zoo

Dependencies

This code requires the following:

  • Python 3.6 or greater
  • PyTorch 1.0 or greater
  • TensorFlow 1.14 or greater

Installation

  • Clone or download this repository.
  • Configure Meta-Dataset:
    • Follow the "User instructions" in the Meta-Dataset repository for "Installation" and "Downloading and converting datasets".
    • Edit ./meta-dataset/data/reader.py in the meta-dataset repository to change dataset = dataset.batch(batch_size, drop_remainder=False) to dataset = dataset.batch(batch_size, drop_remainder=True). (The code also runs with drop_remainder=False, but in our work we drop the remainder so that very small batches are not used for some domains; we recommend dropping the remainder to reproduce our methods.)
    • To test unseen-domain (out-of-domain) performance on the additional datasets, i.e. MNIST, CIFAR-10 and CIFAR-100, follow the installation instructions in the CNAPs repository to obtain these datasets.
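To make the effect of the reader.py edit above concrete, here is a pure-Python illustration of the batching semantics. It mimics, but does not call, tf.data.Dataset.batch; the helper name batch_items is hypothetical.

```python
from typing import Iterable, List, TypeVar

T = TypeVar("T")


def batch_items(items: Iterable[T], batch_size: int, drop_remainder: bool) -> List[List[T]]:
    """Group items into consecutive batches of batch_size.

    With drop_remainder=True the final, smaller batch is discarded, which is
    the behaviour the reader.py edit switches on.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    batches: List[List[T]] = []
    current: List[T] = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            batches.append(current)
            current = []
    # A partial batch is left over only when the item count is not a multiple of batch_size.
    if current and not drop_remainder:
        batches.append(current)
    return batches


# Ten examples with batch_size=4: the last batch would contain only 2 examples.
print(batch_items(range(10), 4, drop_remainder=False))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
print(batch_items(range(10), 4, drop_remainder=True))   # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

When the number of examples is a multiple of the batch size, both settings produce the same batches; the flag only matters for the trailing partial batch.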

Initialization

  1. Before doing anything, first run the following commands.

    ulimit -n 50000
    export META_DATASET_ROOT=<root directory of the cloned or downloaded Meta-Dataset repository>
    export RECORDS=<the directory where tf-records of MetaDataset are stored>
    

    Note that the above commands need to be run every time you open a new command shell.

  2. Enter the root directory of this project, i.e. the directory where this project was cloned or downloaded.
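Before launching any script, it can help to confirm that the two variables and the open-file limit from step 1 are in place. The following is a minimal, hypothetical preflight check (not part of this repository); the 50000 threshold simply mirrors the ulimit command above, and the resource module is only available on Unix-like systems.

```python
import os
import resource
from typing import List, Mapping, Optional


def check_environment(env: Optional[Mapping[str, str]] = None,
                      min_open_files: int = 50000) -> List[str]:
    """Return a list of problems; an empty list means the shell looks ready."""
    env = os.environ if env is None else env
    problems: List[str] = []
    for name in ("META_DATASET_ROOT", "RECORDS"):
        path = env.get(name)
        if not path:
            problems.append(f"{name} is not set")
        elif not os.path.isdir(path):
            problems.append(f"{name}={path} is not a directory")
    # Soft limit on open file descriptors, i.e. the value `ulimit -n` reports.
    soft, _hard = resource.getrlimit(resource.RLIMIT_NOFILE)
    if soft != resource.RLIM_INFINITY and soft < min_open_files:
        problems.append(f"open-file limit is {soft}, expected at least {min_open_files}")
    return problems


if __name__ == "__main__":
    for problem in check_environment():
        print("WARNING:", problem)
```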

Universal Representation Learning from Multiple Domains for Few-shot Classification

Figure 1. URL - Universal Representation Learning.

Train the Universal Representation Learning Network

  1. The easiest way is to download our pre-trained URL model and evaluate its features using our Pre-classifier Alignment (PA). To download the pre-trained URL model, one can use gdown (installed with pip install gdown) and execute the following command in the root directory of this project:

    gdown https://drive.google.com/uc?id=1Dv8TX6iQ-BE2NMpfd0sQmH2q4mShmo1A && md5sum url.zip && unzip url.zip -d ./saved_results/ && rm url.zip
    
    

    This will download the URL model and place it in the ./saved_results directory. One can evaluate this model with our PA (see the Meta-Testing step).

  2. Alternatively, one can train the model from scratch: 1) train 8 single domain learning networks; 2) train the universal feature extractor, as follows.
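The download commands print an MD5 checksum with md5sum but do not compare it against anything, and no reference checksums are listed here. If you have a reference value, a standard-library sketch like the following (helper names are hypothetical) verifies an archive before extracting it, mirroring md5sum followed by unzip -d:

```python
import hashlib
import zipfile
from pathlib import Path
from typing import Optional


def md5_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Compute the hex MD5 digest of a file in chunks, like `md5sum`."""
    digest = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def verify_and_extract(archive: Path, dest: Path, expected_md5: Optional[str] = None) -> str:
    """Optionally check the checksum, then unzip `archive` into `dest` (like `unzip -d`)."""
    actual = md5_of(archive)
    if expected_md5 is not None and actual != expected_md5.lower():
        raise ValueError(f"checksum mismatch for {archive}: {actual} != {expected_md5}")
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(archive) as zf:
        zf.extractall(dest)
    return actual


# Example: verify_and_extract(Path("url.zip"), Path("./saved_results"), expected_md5="<checksum>")
```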

Train Single Domain Learning Networks

  1. The easiest way is to download our pre-trained models and use them to obtain a universal set of features directly. To download single domain learning networks, execute the following command in the root directory of this project:

    gdown https://drive.google.com/uc?id=1MvUcvQ8OQtoOk1MIiJmK6_G8p4h8cbY9 && md5sum sdl.zip && unzip sdl.zip -d ./saved_results/ && rm sdl.zip
    

    This will download all single domain learning models and place them in the ./saved_results directory of this project.

  2. Alternatively, instead of using the pre-trained models, one can train the models from scratch. To train the 8 single domain learning networks, run:

    ./scripts/train_resnet18_sdl.sh
    

Train the Universal Feature Extractor

To learn the universal feature extractor by distilling the knowledge from pre-trained single domain learning networks, run:

./scripts/train_resnet18_url.sh
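The script above performs the actual training. To make the idea of distilling several single-domain teachers into one network more concrete, here is a simplified, framework-free sketch of a per-domain objective: cross-entropy on the labels plus a temperature-softened KL term that pulls the universal network's logits for a domain toward that domain's teacher. This is an illustration under stated assumptions, not the loss implemented in the repository; the weight lam and the temperature are hypothetical, and the actual URL objective described in the ICCV 2021 paper may include further terms (for example on intermediate features).

```python
import math
from typing import Dict, List, Sequence


def softmax(logits: Sequence[float], temperature: float = 1.0) -> List[float]:
    """Numerically stable softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p: Sequence[float], q: Sequence[float]) -> float:
    """KL(p || q) for two discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def domain_distill_loss(student_logits: Sequence[float], teacher_logits: Sequence[float],
                        label: int, lam: float = 1.0, temperature: float = 2.0) -> float:
    """Cross-entropy on the label plus lam * KL(teacher || student) on softened logits."""
    ce = -math.log(softmax(student_logits)[label])
    p_teacher = softmax(teacher_logits, temperature)
    p_student = softmax(student_logits, temperature)
    return ce + lam * kl_divergence(p_teacher, p_student)


def multi_domain_loss(batches: Dict[str, List[tuple]], lam: float = 1.0) -> float:
    """Sum the mean per-domain losses; each domain has its own teacher and label space.

    `batches` maps a domain name to (student_logits, teacher_logits, label) triples,
    where the student logits come from that domain's head on the shared backbone.
    """
    total = 0.0
    for triples in batches.values():
        total += sum(domain_distill_loss(s, t, y, lam) for s, t, y in triples) / len(triples)
    return total
```

When the student already matches a teacher exactly, the KL term vanishes and only the cross-entropy remains; disagreement with the teacher increases the loss.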

Meta-Testing with Pre-classifier Alignment (PA)

Figure 2. PA - Pre-classifier Alignment for Adapting Features in Meta-test.

To run our Pre-classifier Alignment (PA) procedure on each task, which adapts the features to a discriminative space and builds a Nearest Centroid Classifier (NCC) on the support set to classify the query samples, run:

./scripts/test_resnet18_pa.sh
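The NCC step can be sketched as follows: class centroids are computed from the (transformed) support features, and each query is assigned to the centroid with the highest cosine similarity. In PA the linear transformation is learned on the support set of each task; that optimization is omitted here, and the transformation W is simply passed in (identity when None). The function names are illustrative and not the repository's code.

```python
import math
from collections import defaultdict
from typing import Dict, List, Optional, Sequence

Vector = List[float]


def apply_linear(W: Optional[Sequence[Sequence[float]]], x: Sequence[float]) -> Vector:
    """Apply the pre-classifier linear map W to a feature x; None means identity."""
    if W is None:
        return list(x)
    return [sum(w_ij * x_j for w_ij, x_j in zip(row, x)) for row in W]


def cosine(a: Sequence[float], b: Sequence[float]) -> float:
    """Cosine similarity, defined as 0 when either vector is all zeros."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm > 0 else 0.0


def ncc_predict(support: Sequence[Sequence[float]], labels: Sequence[int],
                queries: Sequence[Sequence[float]], W=None) -> List[int]:
    """Nearest centroid classification with cosine similarity in the transformed space."""
    grouped: Dict[int, List[Vector]] = defaultdict(list)
    for x, y in zip(support, labels):
        grouped[y].append(apply_linear(W, x))
    # Centroid = mean of the transformed support features of each class.
    centroids = {y: [sum(col) / len(feats) for col in zip(*feats)] for y, feats in grouped.items()}
    preds = []
    for q in queries:
        zq = apply_linear(W, q)
        preds.append(max(centroids, key=lambda y: cosine(zq, centroids[y])))
    return preds
```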

Cross-domain Few-shot Learning with Task-specific Adapters

Figure 3. Cross-domain Few-shot Learning with Task-specific Adapters (TSA).

We provide code for attaching task-specific adapters (TSA) to a single universal network learned during meta-training and for learning these adapters on the support set. To evaluate our pre-trained URL model (see "Train the Universal Representation Learning Network" above to download the URL or SDL models or to train them from scratch) with its features adapted by residual adapters in matrix form and pre-classifier alignment, run:

./scripts/test_resnet18_tsa.sh

Alternatively, one may want to train the backbone from scratch, starting from the meta-training step. For the single-domain learning setting, see here to learn a single network from ImageNet with ResNet-18. For the multi-domain learning setting, one can learn a URL model (see "Train the Universal Feature Extractor") or a vanilla MDL model (see "Train a Vanilla Multi-domain Learning Network"). Note that one may need to change --model.name and --model.dir in ./scripts/test_resnet18_tsa.sh to point to the model learned in meta-training, and set --test.mode to sdl if the backbone was learned from ImageNet only, before running TSA.

We also provide implementations of different options for task-specific adapters, including connection topology (serial or residual), parameterization (matrix or channel-wise) and weight initialization (identity or random). See ./scripts/test_resnet18_tsa.sh for more details. Note that you may obtain slightly different results from those in Table 3 of our TSA paper, as mentioned in google-research/meta-dataset#54. Setting shuffle_buffer_size to 0 in ./data/meta_dataset_reader.py reproduces the results in Table 3 of our TSA paper, but we strongly suggest re-running the experiments with our up-to-date code (the results with shuffle_buffer_size=1000 are slightly different from those with shuffle_buffer_size=0, and the rankings stay the same).
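To illustrate the adapter options listed above, here is a framework-free sketch that treats a layer's output at one spatial location as a vector of channels. A matrix adapter mixes channels with a matrix (equivalent to a 1×1 convolution); a channel-wise adapter scales each channel by its own weight. In the residual form the adapter runs in parallel on the layer input and its output is added to the frozen layer's output; in the serial form it is applied to the layer output. The function names and the per-location simplification are assumptions made for illustration; see the TSA paper and ./scripts/test_resnet18_tsa.sh for the actual implementation.

```python
import random
from typing import Callable, List, Sequence

Matrix = List[List[float]]
Layer = Callable[[Sequence[float]], List[float]]


def init_matrix_adapter(channels: int, init: str = "identity", seed: int = 0) -> Matrix:
    """Square adapter matrix, initialised as the identity or with small random values."""
    if init == "identity":
        return [[1.0 if i == j else 0.0 for j in range(channels)] for i in range(channels)]
    if init == "random":
        rng = random.Random(seed)
        return [[rng.gauss(0.0, 0.01) for _ in range(channels)] for _ in range(channels)]
    raise ValueError(f"unknown init: {init}")


def matvec(alpha: Sequence[Sequence[float]], x: Sequence[float]) -> List[float]:
    """Matrix parameterization: mix channels, like a 1x1 convolution."""
    return [sum(a * v for a, v in zip(row, x)) for row in alpha]


def channelwise(alpha: Sequence[float], x: Sequence[float]) -> List[float]:
    """Channel-wise parameterization: scale each channel independently."""
    return [a * v for a, v in zip(alpha, x)]


def adapted_layer(layer: Layer, x: Sequence[float], adapter: Layer,
                  connection: str = "residual") -> List[float]:
    """Combine a frozen layer with a task-specific adapter.

    residual: layer(x) + adapter(x)   (adapter runs in parallel on the input)
    serial:   adapter(layer(x))       (adapter runs after the layer)
    """
    if connection == "residual":
        return [f + r for f, r in zip(layer(x), adapter(x))]
    if connection == "serial":
        return adapter(layer(x))
    raise ValueError(f"unknown connection: {connection}")


def frozen_layer(x: Sequence[float]) -> List[float]:
    """Stand-in for a frozen convolutional layer: it just doubles its input."""
    return [2.0 * v for v in x]


alpha = init_matrix_adapter(3, init="identity")
x = [1.0, 2.0, 3.0]
print(adapted_layer(frozen_layer, x, lambda v: matvec(alpha, v), "residual"))  # [3.0, 6.0, 9.0]
print(adapted_layer(frozen_layer, x, lambda v: matvec(alpha, v), "serial"))    # [2.0, 4.0, 6.0]
```

With identity initialization, the serial adapter initially leaves the layer output unchanged, whereas the residual adapter adds the layer input to it; only the adapter weights would be learned on the support set while the layer stays frozen.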

Expected Results

Below are the results extracted from our papers. The results may vary by one or two percent from run to run, because the Meta-Dataset reader generates different tasks in each run and because of randomness in training the networks and in the TSA and PA optimization. Note that the results are updated with the up-to-date evaluation from Meta-Dataset. Make sure that you use the up-to-date code from the Meta-Dataset repository to convert the datasets and set shuffle_buffer_size=1000, as mentioned in google-research/meta-dataset#54.
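The ± values in the table below summarize variation over the evaluated test tasks. Assuming they are 95% confidence intervals computed as 1.96 × standard deviation / √(number of tasks) (an assumption; this README does not state the formula), they can be computed from per-task accuracies as in this sketch. The accuracies used here are hypothetical.

```python
import math
import statistics
from typing import Sequence, Tuple


def mean_and_ci95(task_accuracies: Sequence[float]) -> Tuple[float, float]:
    """Mean accuracy and half-width of a normal-approximation 95% confidence interval."""
    n = len(task_accuracies)
    if n < 2:
        raise ValueError("need at least two tasks")
    mean = statistics.mean(task_accuracies)
    std = statistics.stdev(task_accuracies)  # sample standard deviation
    return mean, 1.96 * std / math.sqrt(n)


accs = [55.0, 60.0, 57.5, 58.0, 56.5]  # hypothetical per-task accuracies (%)
mean, half_width = mean_and_ci95(accs)
print(f"{mean:.1f}±{half_width:.1f}")
```

Because the half-width shrinks with √n, evaluating on more tasks narrows the interval, but run-to-run differences from task sampling and optimization randomness remain.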

Models trained on all datasets

Test Datasets | TSA (Ours) | URL (Ours) | MDL | Best SDL | tri-M [8] | FLUTE [7] | URT [6] | SUR [5] | Transductive CNAPS [4] | Simple CNAPS [3] | CNAPS [2]
Avg rank | 1.5 | 2.7 | 7.1 | 6.7 | 5.5 | 5.1 | 6.7 | 6.9 | 5.7 | 7.2 | -
ImageNet | 57.4±1.1 | 57.5±1.1 | 52.9±1.2 | 54.3±1.1 | 58.6±1.0 | 51.8±1.1 | 55.0±1.1 | 54.5±1.1 | 57.9±1.1 | 56.5±1.1 | 50.8±1.1
Omniglot | 95.0±0.4 | 94.5±0.4 | 93.7±0.5 | 93.8±0.5 | 92.0±0.6 | 93.2±0.5 | 93.3±0.5 | 93.0±0.5 | 94.3±0.4 | 91.9±0.6 | 91.7±0.5
Aircraft | 89.3±0.4 | 88.6±0.5 | 84.9±0.5 | 84.5±0.5 | 82.8±0.7 | 87.2±0.5 | 84.5±0.6 | 84.3±0.5 | 84.7±0.5 | 83.8±0.6 | 83.7±0.6
Birds | 81.4±0.7 | 80.5±0.7 | 79.2±0.8 | 70.6±0.9 | 75.3±0.8 | 79.2±0.8 | 75.8±0.8 | 70.4±1.1 | 78.8±0.7 | 76.1±0.9 | 73.6±0.9
Textures | 76.7±0.7 | 76.2±0.7 | 70.9±0.8 | 72.1±0.7 | 71.2±0.8 | 68.8±0.8 | 70.6±0.7 | 70.5±0.7 | 66.2±0.8 | 70.0±0.8 | 59.5±0.7
Quick Draw | 82.0±0.6 | 81.9±0.6 | 81.7±0.6 | 82.6±0.6 | 77.3±0.7 | 79.5±0.7 | 82.1±0.6 | 81.6±0.6 | 77.9±0.6 | 78.3±0.7 | 74.7±0.8
Fungi | 67.4±1.0 | 68.8±0.9 | 63.2±1.1 | 65.9±1.0 | 48.5±1.0 | 58.1±1.1 | 63.7±1.0 | 65.0±1.0 | 48.9±1.2 | 49.1±1.2 | 50.2±1.1
VGG Flower | 92.2±0.5 | 92.1±0.5 | 88.7±0.6 | 86.7±0.6 | 90.5±0.5 | 91.6±0.6 | 88.3±0.6 | 82.2±0.8 | 92.3±0.4 | 91.3±0.6 | 88.9±0.5
Traffic Sign | 83.5±0.9 | 63.3±1.2 | 49.2±1.0 | 47.1±1.1 | 63.0±1.0 | 58.4±1.1 | 50.1±1.1 | 49.8±1.1 | 59.7±1.1 | 59.2±1.0 | 56.5±1.1
MSCOCO | 55.8±1.1 | 54.0±1.0 | 47.3±1.1 | 49.7±1.0 | 52.8±1.1 | 50.0±1.0 | 48.9±1.1 | 49.4±1.1 | 42.5±1.1 | 42.4±1.1 | 39.4±1.0
MNIST | 96.7±0.4 | 94.5±0.5 | 94.2±0.4 | 91.0±0.5 | 96.2±0.3 | 95.6±0.5 | 90.5±0.4 | 94.9±0.4 | 94.7±0.3 | 94.3±0.4 | -
CIFAR-10 | 80.6±0.8 | 71.9±0.7 | 63.2±0.8 | 65.4±0.8 | 75.4±0.8 | 78.6±0.7 | 65.1±0.8 | 64.2±0.9 | 73.6±0.7 | 72.0±0.8 | -
CIFAR-100 | 69.6±1.0 | 62.6±1.0 | 54.7±1.1 | 56.2±1.0 | 62.0±1.0 | 67.1±1.0 | 57.2±1.0 | 57.1±1.1 | 61.8±1.0 | 60.9±1.1 | -
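The Avg Seen, Avg Unseen and Avg All rows of the multi-domain summary table are plain means of these per-dataset accuracies (the first eight datasets are seen during meta-training, the last five are unseen). As a check, the TSA column can be reproduced as follows:

```python
from statistics import mean

# Per-dataset accuracies (%) of TSA (Ours), copied from the table above.
TSA = {
    "ImageNet": 57.4, "Omniglot": 95.0, "Aircraft": 89.3, "Birds": 81.4,
    "Textures": 76.7, "Quick Draw": 82.0, "Fungi": 67.4, "VGG Flower": 92.2,
    "Traffic Sign": 83.5, "MSCOCO": 55.8, "MNIST": 96.7, "CIFAR-10": 80.6, "CIFAR-100": 69.6,
}
SEEN = ["ImageNet", "Omniglot", "Aircraft", "Birds", "Textures", "Quick Draw", "Fungi", "VGG Flower"]
UNSEEN = [name for name in TSA if name not in SEEN]

avg_seen = mean(TSA[name] for name in SEEN)
avg_unseen = mean(TSA[name] for name in UNSEEN)
avg_all = mean(TSA.values())
print(f"Avg Seen {avg_seen:.1f}, Avg Unseen {avg_unseen:.1f}, Avg All {avg_all:.1f}")
# Avg Seen 80.2, Avg Unseen 77.2, Avg All 79.0
```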

Models trained on ImageNet only TODO

[1] Eleni Triantafillou, Tyler Zhu, Vincent Dumoulin, Pascal Lamblin, Utku Evci, Kelvin Xu, Ross Goroshin, Carles Gelada, Kevin Swersky, Pierre-Antoine Manzagol, Hugo Larochelle; Meta-Dataset: A Dataset of Datasets for Learning to Learn from Few Examples; ICLR 2020.

[2] James Requeima, Jonathan Gordon, John Bronskill, Sebastian Nowozin, Richard E. Turner; Fast and Flexible Multi-Task Classification Using Conditional Neural Adaptive Processes; NeurIPS 2019.

[3] Peyman Bateni, Raghav Goyal, Vaden Masrani, Frank Wood, Leonid Sigal; Improved Few-Shot Visual Classification; CVPR 2020.

[4] Peyman Bateni, Jarred Barber, Jan-Willem van de Meent, Frank Wood; Enhancing Few-Shot Image Classification with Unlabelled Examples; WACV 2022.

[5] Nikita Dvornik, Cordelia Schmid, Julien Mairal; Selecting Relevant Features from a Multi-domain Representation for Few-shot Classification; ECCV 2020.

[6] Lu Liu, William Hamilton, Guodong Long, Jing Jiang, Hugo Larochelle; Universal Representation Transformer Layer for Few-Shot Image Classification; ICLR 2021.

[7] Eleni Triantafillou, Hugo Larochelle, Richard Zemel, Vincent Dumoulin; Learning a Universal Template for Few-shot Dataset Generalization; ICML 2021.

[8] Yanbin Liu, Juho Lee, Linchao Zhu, Ling Chen, Humphrey Shi, Yi Yang; A Multi-Mode Modulator for Multi-Domain Few-Shot Classification; ICCV 2021.

[9] Tonmoy Saikia, Thomas Brox, Cordelia Schmid; Optimized Generic Feature Learning for Few-shot Classification across Domains; arXiv 2020.

[10] Carl Doersch, Ankush Gupta, Andrew Zisserman; CrossTransformers: spatially-aware few-shot transfer; NeurIPS 2020.

Other Usage

Train a Vanilla Multi-domain Learning Network

To train a vanilla multi-domain learning network (MDL) on Meta-Dataset, run:

./scripts/train_resnet18_mdl.sh

Other Classifiers for Meta-Testing (optional)

One can use other classifiers for meta-testing:
  • --test.loss-opt selects the classifier: nearest centroid classifier (ncc, default), support vector machine (svm), logistic regression (lr), Mahalanobis distance from Simple CNAPS (scm), or k-nearest neighbor (knn).
  • --test.feature-norm selects whether features are L2-normalized (l2) or not (none) for svm and lr.
  • --test.distance specifies the feature similarity function for NCC (l2 or cos).
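As a rough illustration of two of these options, the sketch below L2-normalizes features (the effect of the l2 choice) and classifies queries with a k-nearest-neighbor vote. It is a plain-Python stand-in; the repository's own knn option may use a different k, distance or tie-breaking rule.

```python
import math
from collections import Counter
from typing import List, Sequence


def l2_normalize(x: Sequence[float], eps: float = 1e-12) -> List[float]:
    """Scale a feature vector to unit Euclidean length."""
    norm = math.sqrt(sum(v * v for v in x))
    return [v / max(norm, eps) for v in x]


def euclidean(a: Sequence[float], b: Sequence[float]) -> float:
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def knn_predict(support: Sequence[Sequence[float]], labels: Sequence[int],
                queries: Sequence[Sequence[float]], k: int = 3,
                normalize: bool = True) -> List[int]:
    """Majority vote among the k nearest support features (Euclidean distance)."""
    if normalize:
        support = [l2_normalize(s) for s in support]
        queries = [l2_normalize(q) for q in queries]
    preds = []
    for q in queries:
        neighbours = sorted((euclidean(q, s), y) for s, y in zip(support, labels))
        votes = Counter(y for _, y in neighbours[:k])
        # most_common keeps first-encountered order on ties, i.e. favours the closer neighbour.
        preds.append(votes.most_common(1)[0][0])
    return preds
```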

To evaluate the feature extractor with NCC and cosine similarity, run:

python test_extractor.py --test.loss-opt ncc --test.feature-norm none --test.distance cos --model.name=url --model.dir <directory of url> 

Five-shot and Five-way-one-shot Meta-test (optional)

One can evaluate the feature extractor in meta-testing in the five-shot or five-way-one-shot setting by setting --test.type to '5shot' or '1shot', respectively.

To test the feature extractor for varying-way-five-shot on the test splits of all datasets, run:

python test_extractor.py --test.type 5shot --test.loss-opt ncc --test.feature-norm none --test.distance cos --model.name=url --model.dir <directory of url>

To evaluate our proposed URL and TSA methods in the 5-shot or 5-way-1-shot settings, use test_extractor_pa.py (URL with PA) or test_extractor_tsa.py (TSA) with --test.type set to '5shot' or '1shot'.
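For readers unfamiliar with the N-way K-shot terminology, the sketch below samples a single episode from a labelled pool: it picks N classes, then K support examples and a few query examples per class. It is only an illustration with hypothetical names; in this repository, tasks are generated by the Meta-Dataset reader, and in the varying-way five-shot setting the number of ways varies from task to task rather than being fixed.

```python
import random
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Example = Tuple[int, int]  # (example index, class label)


def sample_episode(labels: Sequence[int], n_way: int = 5, k_shot: int = 1,
                   n_query: int = 10, seed: int = 0) -> Tuple[List[Example], List[Example]]:
    """Sample an N-way K-shot episode; returns (support, query) lists of (index, label)."""
    rng = random.Random(seed)
    by_class: Dict[int, List[int]] = defaultdict(list)
    for idx, y in enumerate(labels):
        by_class[y].append(idx)
    # Only classes with enough examples for both the support and query sets qualify.
    eligible = sorted(y for y, idxs in by_class.items() if len(idxs) >= k_shot + n_query)
    if len(eligible) < n_way:
        raise ValueError("not enough classes with k_shot + n_query examples")
    support: List[Example] = []
    query: List[Example] = []
    for y in rng.sample(eligible, n_way):
        picked = rng.sample(by_class[y], k_shot + n_query)
        support += [(i, y) for i in picked[:k_shot]]
        query += [(i, y) for i in picked[k_shot:]]
    return support, query
```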

Acknowledgements

We thank the authors of Meta-Dataset, SUR and Residual Adapter for their source code.

Contact

For any question, you can contact Wei-Hong Li.

Citation

If you use this code, please cite our papers:

@article{li2022Universal,
    author    = {Li, Wei-Hong and Liu, Xialei and Bilen, Hakan},
    title     = {Universal Representations: A Unified Look at Multiple Task and Domain Learning},
    journal   = {arXiv preprint arXiv:2204.02744},
    year      = {2022}
}

@inproceedings{li2022TaskSpecificAdapter,
    author    = {Li, Wei-Hong and Liu, Xialei and Bilen, Hakan},
    title     = {Cross-domain Few-shot Learning with Task-specific Adapters},
    booktitle = {IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
    month     = {June},
    year      = {2022}
}

@inproceedings{li2021Universal,
    author    = {Li, Wei-Hong and Liu, Xialei and Bilen, Hakan},
    title     = {Universal Representation Learning From Multiple Domains for Few-Shot Classification},
    booktitle = {IEEE/CVF International Conference on Computer Vision (ICCV)},
    month     = {October},
    year      = {2021},
    pages     = {9526-9535}
}

@inproceedings{li2020knowledge,
    author    = {Li, Wei-Hong and Bilen, Hakan},
    title     = {Knowledge distillation for multi-task learning},
    booktitle = {European Conference on Computer Vision (ECCV) Workshop},
    year      = {2020}
}
