Reptrix, short for Representation Metrics, is a PyTorch library designed to simplify the evaluation of representation quality metrics in pretrained deep neural networks. Reptrix offers a suite of recently proposed metrics, predominantly from the vision self-supervised learning literature, that are essential for researchers and engineers working on the design, deployment, evaluation and interpretability of deep neural networks in computer vision settings.
Key Features:
- Comprehensive Metric Suite: Includes a variety of metrics that assess different aspects of representation quality and are indicative of capacity, robustness and downstream task performance.
- PyTorch Integration: Seamlessly integrates with existing PyTorch models and workflows, allowing for straightforward monitoring of learned representations with minimal setup.
- Open Source: Open for contributions and enhancements from the community, including any new metrics that are proposed.
Reptrix is a tool for machine learning practitioners who want to quantitatively analyze learned representations and improve the interpretability of their deep learning models, especially models trained with self-supervised learning. To learn more about why these metrics matter in modern DL workflows, check out our blogpost on Assessing Representation Quality in SSL.
- $\alpha$-ReQ: This metric computes the eigenvalues of the covariance matrix of the representations and fits a power law to them. The exponent of this power law, called the $\alpha$ exponent, measures how quickly the eigenspectrum decays (its heavy-tailedness). A lower $\alpha$ exponent indicates more discriminative representations.
- RankMe: This metric computes the effective rank of the representations (the exponential of the entropy of their normalized singular values). A higher rank indicates representations of higher capacity.
- LiDAR: This metric computes the effective rank of the linear discriminant analysis (LDA) matrix. A higher rank indicates representations with a higher degree of separability among object manifolds.
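To build intuition for what the three metrics measure, here is a from-scratch sketch in plain PyTorch. It is not Reptrix's implementation and may differ from it in details: all function names are hypothetical, the $\alpha$ fit uses every positive eigenvalue rather than a chosen rank range, RankMe and LiDAR are computed as effective ranks following our reading of the respective papers, and `delta` is an assumed small ridge term added to the within-class scatter.

```python
import torch


def effective_rank(values: torch.Tensor) -> float:
    """exp(Shannon entropy) of non-negative values normalized to sum to 1."""
    values = values.clamp_min(0)
    p = values / values.sum()
    p = p[p > 0]  # drop exact zeros to avoid log(0)
    return float(torch.exp(-(p * p.log()).sum()))


def power_law_exponent(eigvals: torch.Tensor) -> float:
    """Negated least-squares slope of log(eigenvalue) against log(rank)."""
    eigvals = torch.sort(eigvals, descending=True).values
    eigvals = eigvals[eigvals > 0]
    log_rank = torch.log(torch.arange(1, eigvals.numel() + 1, dtype=eigvals.dtype))
    log_eig = torch.log(eigvals)
    x = log_rank - log_rank.mean()
    y = log_eig - log_eig.mean()
    slope = (x * y).sum() / (x * x).sum()
    return float(-slope)


def alpha_req_sketch(features: torch.Tensor) -> float:
    """features: (N, D). Power-law decay exponent of the covariance eigenspectrum."""
    centered = features - features.mean(dim=0, keepdim=True)
    cov = centered.T @ centered / (features.shape[0] - 1)
    return power_law_exponent(torch.linalg.eigvalsh(cov))


def rankme_sketch(features: torch.Tensor) -> float:
    """features: (N, D). Effective rank computed from the singular values."""
    return effective_rank(torch.linalg.svdvals(features))


def lidar_sketch(features: torch.Tensor, delta: float = 1e-5) -> float:
    """features: (N, K, D), i.e. K augmentations of each of N samples."""
    n, k, d = features.shape
    class_means = features.mean(dim=1)                  # (N, D): one "class" per sample
    grand_mean = class_means.mean(dim=0, keepdim=True)  # (1, D)
    between = class_means - grand_mean
    sigma_b = between.T @ between / n                   # between-class scatter
    within = (features - class_means.unsqueeze(1)).reshape(n * k, d)
    sigma_w = within.T @ within / (n * k) + delta * torch.eye(d, dtype=features.dtype)
    # Sigma_w^{-1/2} via eigendecomposition of the symmetric positive-definite matrix
    w_vals, w_vecs = torch.linalg.eigh(sigma_w)
    w_inv_sqrt = w_vecs @ torch.diag(w_vals.rsqrt()) @ w_vecs.T
    lda = w_inv_sqrt @ sigma_b @ w_inv_sqrt
    return effective_rank(torch.linalg.eigvalsh(lda))
```

The inputs mirror the shapes used in the feature-extraction steps of this README: an (N, D) matrix for $\alpha$-ReQ and RankMe, and an (N, K, D) tensor of K augmentations per sample for LiDAR.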
TODO: Fill out the numbers
ResNet50
| Metric | Time to compute (s) | Memory requirement (GB) |
|---|---|---|
| $\alpha$-ReQ | 2.400 | |
| RankMe | 2.364 | |
| LiDAR | 7.929 | |
ViT
| Metric | Time to compute (s) | Memory requirement (GB) |
|---|---|---|
| $\alpha$-ReQ | 0.137 | |
| RankMe | 0.091 | |
| LiDAR | 0.162 | |
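One possible way to collect numbers like these is a small timing helper. `benchmark` is a hypothetical name, not part of Reptrix: it reports the median wall-clock time over a few runs and, only when at least one input is a CUDA tensor, the peak memory PyTorch allocated on the GPU in GB. The hardware behind the numbers above is not stated, so measurements from this sketch will not be directly comparable.

```python
import statistics
import time

import torch


def benchmark(metric_fn, *args, repeats: int = 5):
    """Median wall-clock seconds of metric_fn(*args) and peak CUDA memory in GB (or None)."""
    use_cuda = any(isinstance(a, torch.Tensor) and a.is_cuda for a in args)
    if use_cuda:
        torch.cuda.reset_peak_memory_stats()
    timings = []
    for _ in range(repeats):
        if use_cuda:
            torch.cuda.synchronize()  # make sure earlier GPU work does not leak into the timing
        start = time.perf_counter()
        metric_fn(*args)
        if use_cuda:
            torch.cuda.synchronize()  # wait for the metric's GPU kernels to finish
        timings.append(time.perf_counter() - start)
    # Peak memory is only tracked for GPU tensors; CPU runs report None
    peak_gb = torch.cuda.max_memory_allocated() / 1e9 if use_cuda else None
    return statistics.median(timings), peak_gb
```

For example, `benchmark(rankme.get_rankme, all_representations)` would time Reptrix's RankMe on features extracted as in the steps that follow.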
- Load your favourite pretrained network.
import torch

encoder = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
# Remove the final fully connected layer and flatten, so that the model outputs 2048-dimensional feature vectors
encoder = torch.nn.Sequential(*list(encoder.children())[:-1], torch.nn.Flatten())
encoder.eval()
- Extract features from the pretrained network.
import torch
from tqdm import tqdm

def get_features(encoder_network, dataloader, transform=None, num_augmentations=10):
    # Loop over the dataset and collect the representations
    all_features = []
    for inputs, _ in tqdm(dataloader):
        if transform:
            # Stack num_augmentations augmented copies of the batch along the batch dimension
            inputs = torch.cat([transform(inputs) for _ in range(num_augmentations)], dim=0)
        with torch.no_grad():
            features = encoder_network(inputs)
        if transform:
            # Rows are ordered (augmentation, sample); regroup them as (sample, augmentation, feature)
            features = features.reshape(num_augmentations, -1, features.shape[1]).transpose(0, 1)
        all_features.append(features)
    # Concatenate all the features
    all_features = torch.cat(all_features, dim=0)
    return all_features
# `loader` is your DataLoader and `transform_augs` your augmentation pipeline
all_representations = get_features(encoder, loader)
num_augmentations = 10
all_representations_lidar = get_features(encoder, loader,
                                         transform=transform_augs,
                                         num_augmentations=num_augmentations)
num_samples = all_representations_lidar.shape[0]
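When a transform is given, `torch.cat(..., dim=0)` stacks whole augmented copies of the batch, so augmentation k of sample b lands in row k * B + b. The toy check below uses pure PyTorch, an identity "encoder" and a hypothetical noise augmentation `hypothetical_aug` standing in for `transform_augs`. It shows that `reshape(num_augmentations, -1, D).transpose(0, 1)` groups each sample's augmentations together, whereas `reshape(-1, num_augmentations, D)` puts different samples into the same group.

```python
import torch

torch.manual_seed(0)
batch_size, num_augmentations, dim = 3, 4, 5
batch = torch.randn(batch_size, dim)  # identity "encoder": features == inputs


def hypothetical_aug(x):
    # Stand-in augmentation: small Gaussian noise applied to the whole batch
    return x + 0.01 * torch.randn_like(x)


# Same stacking as get_features: row k * batch_size + b is augmentation k of sample b
features = torch.cat([hypothetical_aug(batch) for _ in range(num_augmentations)], dim=0)

wrong = features.reshape(-1, num_augmentations, dim)
right = features.reshape(num_augmentations, -1, dim).transpose(0, 1)

target = batch.unsqueeze(1).expand(-1, num_augmentations, -1)  # each sample repeated K times
print(right.shape)  # torch.Size([3, 4, 5])
print((right - target).abs().max())  # only the augmentation noise
print((wrong - target).abs().max())  # large: groups mix different samples
```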
- Compute the representation metrics
from reptrix import alpha, rankme, lidar

metric_alpha = alpha.get_alpha(all_representations)
metric_rankme = rankme.get_rankme(all_representations)
metric_lidar = lidar.get_lidar(all_representations_lidar, num_samples,
                               num_augmentations,
                               del_sigma_augs=0.00001)
TODO: Update and test this!
You can install the latest version of reptrix using:
pip install reptrix
Alternatively, you can install it directly from the GitHub repository with:
pip install git+https://github.com/arnab39/reptrix
To work from a local clone, you can either incorporate reptrix into your existing conda environment or create a new environment with the necessary packages:
git clone https://github.com/arnab39/reptrix
cd reptrix
conda env create -f conda_env.yaml
conda activate reptrix
pip install -e .
We provide a tutorial IPython notebook that shows how you can incorporate metrics from our Reptrix library into your own code.
This library currently supports metrics proposed in three different papers:
- $\alpha$-ReQ: Assessing Representation Quality in Self-Supervised Learning by measuring eigenspectrum decay (NeurIPS 2022)
- RankMe: Assessing the Downstream Performance of Pretrained Self-Supervised Representations by Their Rank (ICML 2023)
- LiDAR: Sensing Linear Probing Performance in Joint Embedding SSL Architectures (ICLR 2024)
For questions related to this code, please raise an issue or email us: Arna Ghosh, Arnab K Mondal, Kumar K Agrawal
You can check out the contributor's guide.
This project uses pre-commit; please install it before making any changes:
pip install pre-commit
cd reptrix
pre-commit install
It is a good idea to update the hooks to the latest version:
pre-commit autoupdate