vpariza/open-hummingbird-eval

Description

This repository reproduces the Dense NN Retrieval Evaluation method introduced by Balažević et al. in "Towards In-context Scene Understanding" (NeurIPS 2023).

Briefly, it evaluates how well the spatial (patch) features produced by a vision encoder can be matched to relevant features from a validation dataset, using a k-NN classifier/retriever that operates across various proportions of the training data.

Figure: overview of the Hummingbird Evaluation. Image taken from "Towards In-context Scene Understanding", NeurIPS 2023.

This evaluation approach interprets scenes by comparing new images with ones that are already known. The encoder is first shown a set of densely labeled image-label examples (left part of the figure). It encodes these images densely, so that each encoded patch (top-left section) is paired with its label (bottom-left section). New images to be described (right part) are then encoded in the same dense way, but without labels. Each encoded patch of a new image is compared with the most similar patches among the known examples, and the labels of the closest matches determine the likely label for that patch, and hence what the new image might be showing. The approach is flexible because it makes no assumptions about the label set.
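
To make this concrete, below is a minimal, illustrative sketch of the patch-level k-NN retrieval step, assuming patch features and per-patch labels have already been extracted. The function and variable names here are made up for illustration; the actual implementation relies on the scann library (listed under Setup) for efficient approximate nearest-neighbour search rather than the exact, dense similarity computation shown here.

import torch
import torch.nn.functional as F

def dense_knn_predict(query_feats, memory_feats, memory_labels, num_classes, k=30):
    # query_feats:   (Q, D) patch features of the new, unlabeled images
    # memory_feats:  (M, D) patch features of the labeled training images (the "memory")
    # memory_labels: (M,)   per-patch class index for every memory patch
    q = F.normalize(query_feats, dim=-1)
    m = F.normalize(memory_feats, dim=-1)
    sims = q @ m.t()                                    # (Q, M) cosine similarity of every query/memory patch pair
    topk_sims, topk_idx = sims.topk(k, dim=-1)          # k most similar memory patches per query patch
    neighbour_labels = memory_labels[topk_idx]          # (Q, k) labels of those neighbours
    one_hot = F.one_hot(neighbour_labels, num_classes).float()
    weights = topk_sims.softmax(dim=-1).unsqueeze(-1)   # similarity-weighted soft voting
    return (one_hot * weights).sum(dim=1)               # (Q, num_classes) per-patch label scores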

Reproduction done by:

  • Valentinos Pariza
  • Mohammadreza Salehi
  • Yuki M. Asano

At the University of Amsterdam (UvA)

Notes

  • For any questions, issues, etc., please open a GitHub issue on this repository.
  • If you find this repository useful, please consider starring and citing.

Results of our implementation on Pascal VOC

For the experiments below we used two dataset augmentation epochs, and an image size of (512, 512) for DINO and (504, 504) for DINOv2.

| arch | model | PVOC mIoU (memory 1024·10²) | PVOC mIoU (memory 1024·10³) | PVOC mIoU (memory 1024·10⁴) | PVOC mIoU from orig. paper (memory 1024·10⁴) |
|---|---|---|---|---|---|
| ViT-S/16 | dino | 37.2 | 43.1 | 46.6 | - |
| ViT-B/16 | dino | 44.9 | 50.8 | 55.7 | 55.9 |
| ViT-S/14 | dinov2 | 70.2 | 74.9 | 77.0 | - |
| ViT-B/14 | dinov2 | 69.1 | 74.6 | 76.9 | - |
| ViT-L/14 | dinov2 | 64.6 | 71.7 | 74.8 | - |
| ViT-G/14 | dinov2 | 62.3 | 69.9 | 73.6 | - |
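
The memory sizes in the table are the number of patch features stored in the memory bank. Judging from the eval.py example further below (where --memory-size 102400 corresponds to 1024·10²), they presumably map to the memory_size argument of the evaluation, e.g.:

memory_size = 1024 * 10**2    # 102,400 patches: the smallest memory bank in the table
# 1024 * 10**3 = 1,024,000 and 1024 * 10**4 = 10,240,000 for the two larger memory banks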

Usage

Example of how to evaluate DINO with the Hummingbird (Dense NN Retrieval) Evaluation on Pascal VOC

import torch
from src.hbird_eval import hbird_evaluation
# Parameters for the model dino
device = 'cuda'
input_size = 224
batch_size = 64
patch_size = 16
embed_dim = 384
model = torch.hub.load('facebookresearch/dino:main', 'dino_vits16')

# Define the function to extract features from the model
# Input to the function is the model and the images
# Output of the function is the features extracted from the model 
# and optionally the attention maps
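# For DINO, get_intermediate_layers returns the token sequences of the last block(s);
# [:, 1:] drops the CLS token so that only the patch tokens are kept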
fn = lambda model, imgs: (model.get_intermediate_layers(imgs)[0][:, 1:], None)


# Evaluate the model using the Full In-Context Learning Hummingbird  
# or Dense k-NN Retrieval Evaluation on the Pascal VOC Dataset
hbird_miou = hbird_evaluation(model.to(device), 
        d_model=embed_dim,          # size of the embedding feature vectors of patches
        patch_size=patch_size, 
        batch_size=batch_size,
        input_size=input_size,      # resolution at which the input images are processed
        augmentation_epoch=1,       # how many iterations of augmentations to use on top of 
                                    # the training dataset in order to generate the memory
        device=device,              
        return_knn_details=False,   # whether to return additional NNs details
        n_neighbours=30,           # the number of neighbors to fetch per image patch
        nn_params=None,             # Other parameters to be used for the k-NN operator
        ftr_extr_fn=fn,             # function that extracts image patch features with 
                                    # a vision encoder
        dataset_name='voc',         # the name of the dataset to use, 
                                    # currently only Pascal VOC is included.
        data_dir='<the path to the Pascal VOC Dataset>',    # path to the dataset 
                                                            # to use for evaluation
        memory_size=None)           # How much you want to limit your dataset, 
                                    # None if to be left unbounded
print('Dense NN Ret - miou score:', hbird_miou) 

Example of how to evaluate DINOv2 with the Dense NN Retrieval Evaluation on Pascal VOC

import torch
from src.hbird_eval import hbird_evaluation
# Parameters for the model dinov2
device = 'cuda'
input_size = 224
batch_size = 256
patch_size = 14
embed_dim = 384
model = torch.hub.load('facebookresearch/dinov2', 'dinov2_vits14')

# Define the function to extract features from the model
# Input to the function is the model and the images
# Output of the function is the features extracted from the model 
# and optionally the attention maps
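# For DINOv2, forward_features returns a dict whose 'x_norm_patchtokens' entry already
# contains only the (normalized) patch tokens, so no CLS token needs to be dropped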
fn = lambda model, imgs: (model.forward_features(imgs)['x_norm_patchtokens'], None)


# Evaluate the model using the Full In-Context Learning Hummingbird  
# or Dense k-NN Retrieval Evaluation on the Pascal VOC Dataset
hbird_miou = hbird_evaluation(model.to(device), 
        d_model=embed_dim,          # size of the embedding feature vectors of patches
        patch_size=patch_size, 
        batch_size=batch_size,
        input_size=input_size,      # resolution at which the input images are processed
        augmentation_epoch=1,       # how many iterations of augmentations to use on top of 
                                    # the training dataset in order to generate the memory
        device=device,              
        return_knn_details=False,   # whether to return additional NNs details
        n_neighbours=30,           # the number of neighbors to fetch per image patch
        nn_params=None,             # Other parameters to be used for the k-NN operator
        ftr_extr_fn=fn,             # function that extracts image patch features with 
                                    # a vision encoder
        dataset_name='voc',         # the name of the dataset to use, 
                                    # currently only Pascal VOC is included.
        data_dir='<the path to the Pascal VOC Dataset>',    # path to the dataset 
                                                            # to use for evaluation
        memory_size=None)           # How much you want to limit your dataset, 
                                    # None if to be left unbounded
print('Dense NN Ret - miou score:', hbird_miou) 
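
The same pattern extends to other backbones: ftr_extr_fn only needs to return a tensor of patch features of shape (batch_size, num_patches, embed_dim), plus optional attention maps (or None). As an illustration, a hypothetical extractor for a timm ViT could look like the sketch below; the model name and the assumption that timm's forward_features puts the CLS token first are not taken from this repository and should be verified for the backbone you use.

import timm

# Hypothetical feature extractor for a timm ViT-S/16 (illustrative only)
timm_model = timm.create_model('vit_small_patch16_224', pretrained=True)

def timm_ftr_extr_fn(model, imgs):
    tokens = model.forward_features(imgs)   # (B, 1 + N, D) token sequence for a ViT with a CLS token
    patch_tokens = tokens[:, 1:]            # drop the CLS token, keep the N patch tokens
    return patch_tokens, None               # no attention maps returned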

Ready-to-use script

We also provide a ready-to-use Python script for running evaluations with DINO backbones. For example, to evaluate a ViT-S/16 on the whole Pascal VOC dataset with a memory bank of size 1024·10² (i.e. 102400), you can run the following command:

python eval.py                  \
    --seed 42                   \
    --batch-size 64             \
    --input-size 512            \
    --patch-size 16             \
    --memory-size 102400        \
    --embeddings-size 384       \
    --data-dir VOCSegmentation  \
    --model dino_vits16

Setup

This section describes what is required to run the Dense NN Retrieval Evaluation.

Python Libraries

The most prevalent libraries used (an install sketch follows the list):

  • torch + torchvision
  • torchmetrics
  • scann
  • numpy
  • joblib
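
These can typically be installed with pip, for example (a sketch: package names are assumed to match their PyPI names and no specific versions are pinned here):

pip install torch torchvision torchmetrics scann numpy joblib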

Dataset Setup

Pascal VOC

We provide a zipped version of the whole dataset, as well as two smaller versions of it.

The structure of the Pascal VOC dataset folder should be as follows:

dataset root.
└───SegmentationClass
│   │   *.png
│   │   ...
└───SegmentationClassAug # contains segmentation masks from trainaug extension 
│   │   *.png
│   │   ...
└───images
│   │   *.jpg
│   │   ...
└───sets
│   │   train.txt
│   │   trainaug.txt
│   │   val.txt
ADE20K

You can download the ADE20K dataset from Kaggle.

The structure of the ADE20K dataset folder should be as follows:

dataset root.
└───annotations
│   └───training
│   |   | *.png
│   |   │   ...
│   └───validation
│       | *.png
│       │   ...
└───images
│   └───training
│   |   | *.jpg
│   |   │   ...
│   └───validation
│       | *.jpg
│       │   ...
└───objectInfo150.txt
└───sceneCategories.txt

Examples

A basic example of how to download any of our dataset versions and evaluate a vision encoder with our implementation of the Hummingbird evaluation can be found in hbird_eval_example in the examples folder.

You can also open it in Google Colab via the "Open In Colab" badge.

Upcoming/Future Features

Stay tuned: we plan to bring more support and extensions of our implementation, including the features below.

| Feature | Description |
|---|---|
| Cityscapes | Code and results for the Cityscapes dataset |
| NYUv2 | Code and results for the NYUv2 dataset |

Contributors

  • @vpariza
  • @Smsd75
  • @yukimasano

Citations

If you find this repo helpful, please consider citing these works:

The original paper:

@inproceedings{balazevic2023towards,
      title={Towards In-context Scene Understanding},
      author={Ivana Balazevic and David Steiner and Nikhil Parthasarathy and Relja Arandjelovic and Olivier J Henaff},
      booktitle={Thirty-seventh Conference on Neural Information Processing Systems},
      year={2023},
      url={https://openreview.net/forum?id=FasIQqsJhe}
}

Our work and repository:

@misc{pariza2024hbird,
      author = {Pariza, Valentinos and Salehi, Mohammadreza and Asano, Yuki},
      month = {4},
      title = {Hummingbird Evaluation for vision encoders},
      url = {https://github.com/vpariza/open-hummingbird-eval},
      year = {2024}
}
