Representation quality metrics for pretrained deep models! ⭐


Reptrix

About

Reptrix, short for Representation Metrics, is a PyTorch library designed to simplify the evaluation of representation quality in pretrained deep neural networks. Reptrix offers a suite of recently proposed metrics, predominantly from the vision self-supervised learning literature, that are essential for researchers and engineers focused on the design, deployment, evaluation and interpretability of deep neural networks in computer vision settings.

Key Features:

  • Comprehensive Metric Suite: Includes a variety of metrics that assess complementary aspects of representation quality, indicative of capacity, robustness and downstream task performance.
  • PyTorch Integration: Seamlessly integrates with existing PyTorch models and workflows, allowing for straightforward monitoring of learned representations with minimal setup.
  • Open Source: Open for contributions and enhancements from the community, including any new metrics that are proposed.

Reptrix is the perfect tool for machine learning practitioners looking to quantitatively analyze learned representations and enhance the interpretability of their deep learning models, especially models trained in a self-supervised learning framework. To learn more about why these metrics are essential in modern DL workflows, check out our blog post on Assessing Representation Quality in SSL.

List of metrics currently supported

  • $\alpha$-ReQ: Computes the eigenvalues of the covariance matrix of the representations and fits a power-law distribution to them. The power-law exponent, $\alpha$, measures the heavy-tailedness of the eigenspectrum; a lower $\alpha$ indicates more discriminative representations.
  • RankMe: Computes the effective rank of the covariance matrix of the representations. A higher rank indicates representations of higher capacity.
  • LiDAR: Computes the rank of the linear discriminant analysis (LDA) matrix. A higher rank indicates a higher degree of separability among object manifolds.
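These spectral quantities are easy to prototype outside the library. Below is a minimal sketch in plain NumPy (the function names are illustrative, not the reptrix API): the $\alpha$ exponent is estimated as the negated slope of a log-log fit to the eigenspectrum, and RankMe-style effective rank as the exponential of the entropy of the normalized singular values.

```python
import numpy as np

def eigenspectrum(features: np.ndarray) -> np.ndarray:
    """Eigenvalues of the feature covariance matrix, sorted descending."""
    centered = features - features.mean(axis=0, keepdims=True)
    cov = centered.T @ centered / (features.shape[0] - 1)
    eigvals = np.linalg.eigvalsh(cov)[::-1]
    return np.clip(eigvals, 0.0, None)

def alpha_exponent(eigvals: np.ndarray) -> float:
    """Negated slope of log-eigenvalue vs. log-rank (power-law fit)."""
    ranks = np.arange(1, len(eigvals) + 1)
    mask = eigvals > 1e-12
    slope, _ = np.polyfit(np.log(ranks[mask]), np.log(eigvals[mask]), deg=1)
    return -slope

def effective_rank(features: np.ndarray, eps: float = 1e-7) -> float:
    """Exponential of the entropy of the normalized singular values."""
    svals = np.linalg.svd(features, compute_uv=False)
    p = svals / svals.sum() + eps
    return float(np.exp(-(p * np.log(p)).sum()))
```

As a sanity check on the intuition: a well-spread (e.g. white-noise) feature matrix has a flat spectrum and an effective rank near the feature dimension, while collapsed representations drive the effective rank toward 1.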

TODO: Fill out the numbers

ResNet50

| Metric | Time to compute (s) | Memory requirement (GB) |
| --- | --- | --- |
| $\alpha$-ReQ | 2.400 | |
| RankMe | 2.364 | |
| LiDAR | 7.929 | |

ViT

| Metric | Time to compute (s) | Memory requirement (GB) |
| --- | --- | --- |
| $\alpha$-ReQ | 0.137 | |
| RankMe | 0.091 | |
| LiDAR | 0.162 | |
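Timings like those above can be reproduced with a small harness. This is a generic sketch (not part of reptrix) that reports the median wall-clock time of any metric function over a few runs:

```python
import time

def time_metric(metric_fn, features, repeats=3):
    """Median wall-clock time (seconds) to evaluate metric_fn(features)."""
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        metric_fn(features)  # result discarded; we only measure latency
        timings.append(time.perf_counter() - start)
    timings.sort()
    return timings[len(timings) // 2]
```

For example, `time_metric(rankme.get_rankme, all_representations)` would time the RankMe computation on a precomputed feature tensor.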

Using Reptrix in your own workflow

  1. Load your favourite pretrained network.

```python
import torch

encoder = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
# Remove the final fully connected layer so the model outputs the 2048-d feature vector
encoder = torch.nn.Sequential(*(list(encoder.children())[:-1]))
encoder.eval()
```
  2. Extract features from the pretrained network.

```python
import torch
from tqdm import tqdm

def get_features(encoder_network, dataloader, transform=None, num_augmentations=10):
    # Loop over the dataset and collect the representations
    all_features = []
    for inputs, _ in tqdm(dataloader):
        if transform:
            # Stack num_augmentations augmented views of the batch along the batch dimension
            inputs = torch.cat([transform(inputs) for _ in range(num_augmentations)], dim=0)
        with torch.no_grad():
            features = encoder_network(inputs)
        # Collapse trailing spatial dimensions (e.g. ResNet outputs N x 2048 x 1 x 1)
        features = torch.flatten(features, start_dim=1)
        if transform:
            # Put the augmentations in an additional dimension: (batch, augmentations, features)
            features = features.reshape(num_augmentations, -1, features.shape[-1]).permute(1, 0, 2)
        all_features.append(features)

    # Concatenate all the features
    all_features = torch.cat(all_features, dim=0)
    return all_features

all_representations = get_features(encoder, loader)
num_augmentations = 10
all_representations_lidar = get_features(encoder, loader,
                                         transform=transform_augs,
                                         num_augmentations=num_augmentations)
num_samples = all_representations_lidar.shape[0]
```
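The snippet above assumes an augmentation pipeline `transform_augs`, which is not defined here. Any stochastic transform works; as a minimal stand-in (additive Gaussian noise, in place of the crop/flip/color-jitter pipeline the encoder was pretrained with):

```python
import torch

def make_noise_augmentation(std=0.1):
    """Return a stochastic transform: each call adds fresh Gaussian noise.

    Illustrative only -- for LiDAR you would normally reuse the augmentation
    pipeline the encoder was pretrained with (random crops, flips,
    color jitter, etc.).
    """
    def transform(x):
        return x + std * torch.randn_like(x)
    return transform

transform_augs = make_noise_augmentation(std=0.1)
```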
  3. Compute the representation metrics.

```python
from reptrix import alpha, rankme, lidar

metric_alpha = alpha.get_alpha(all_representations)
metric_rankme = rankme.get_rankme(all_representations)
metric_lidar = lidar.get_lidar(all_representations_lidar, num_samples,
                               num_augmentations,
                               del_sigma_augs=0.00001)
```

Installation

TODO: Update and test this!

Using PyPI

You can install the latest version of reptrix using:

```shell
pip install reptrix
```

Manual installation

You can also install it directly from the repository with:

```shell
pip install git+https://github.com/arnab39/reptrix
```

Setup Conda environment for examples

You can incorporate reptrix in your existing conda environment or create a new environment with the necessary packages:

```shell
conda env create -f conda_env.yaml
conda activate reptrix
pip install -e .
```

Example code for Reptrix

We provide a tutorial Jupyter notebook that shows how to incorporate metrics from the Reptrix library into your own code.

Related papers and Citations

This library currently supports metrics proposed in three different papers:

  1. $\alpha$-ReQ: Assessing Representation Quality in Self-Supervised Learning by measuring eigenspectrum decay (NeurIPS 2022)
  2. RankMe: Assessing the Downstream Performance of Pretrained Self-Supervised Representations by Their Rank (ICML 2023)
  3. LiDAR: Sensing Linear Probing Performance in Joint Embedding SSL Architectures (ICLR 2024)

Contact

For questions related to this code, please raise an issue or email us: Arna Ghosh, Arnab K Mondal, Kumar K Agrawal

Contributing

You can check out the contributor's guide.

This project uses pre-commit; install it before making any changes:

```shell
pip install pre-commit
cd reptrix
pre-commit install
```

It is a good idea to update the hooks to the latest version:

```shell
pre-commit autoupdate
```