Enhancing Diversity in Bayesian Deep Learning via Hyperspherical Energy Minimization of CKA

This repository is the official implementation of our NeurIPS 2024 paper.

Requirements

Set up the environment and install the requirements:

mamba env create --name he-cka-ensemble --file environment.yml

To install mamba, a fast drop-in replacement for conda, see the mamba installation guide. If you prefer to use Anaconda, just replace mamba with conda.
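For example, with conda:

conda env create --name he-cka-ensemble --file environment.yml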

Datasets

Most datasets will automatically be downloaded to the data/ folder in the repo directory.
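If you want to pre-fetch a dataset manually, the snippet below is a minimal sketch assuming the standard torchvision download pattern (the repo's own loaders may differ):

import torchvision

# downloads MNIST into data/ on first use; later runs reuse the cached copy
train_set = torchvision.datasets.MNIST('data/', train=True, download=True)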

Basic Usage

A basic ensemble, with feature tracking, can be implemented with this library as shown below. More in-depth examples are given by the training configurations in configs/.

import torch

from hyper.generators.ensemble import FixedEnsembleModel as Ensemble
from hyper.layers import Conv2d, Linear, Flatten, SequentialModule as Sequential
from hyper.util.collections import flatten_keys


# construct a simple CNN for MNIST (28x28x1) input
# all modules in hyper.layers support batched weights
lenet = Sequential(
    Conv2d(1, 6, 5, act='relu', pooling='max', track='detach'),  # track features but detach from graph
    Conv2d(6, 16, 5, act='relu', pooling='max'),
    Flatten(track=False),  # track=False => features from this module are not tracked
    Linear(256, 120, act='relu'),
    Linear(120, 84, act='relu'),
    Linear(84, 10)
)

# create ensemble or equivalently a hypernetwork using MLPLayerModelGenerator in generators.mlp
ens = Ensemble(
    target=lenet,
    ensemble_size=5  # make an ensemble of 5 members
)

# test input
X = torch.randn((8, 1, 28, 28))

# feed through all 5 ensemble members
feat, Y = ens(5, X)
print(f'Output shape {Y.shape}') # output of [5, 8, 10] or [models, batch size, classes]


# flatten_keys flattens the nested ordered dictionaries of features into a single dictionary with '.'-separated keys
for mod_name, mod_feat in flatten_keys(feat).items():
    if mod_feat is not None:  # features are None for modules that are not tracked
        print(f'Module {mod_name} shape {mod_feat.shape} variance {mod_feat.var()}')
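To connect the tracked features to the paper's objective: pairwise representation similarity between ensemble members can be measured with centered kernel alignment (CKA) and then repelled with a hyperspherical-energy-style penalty. Below is a minimal, self-contained sketch using the standard linear-CKA formula; linear_cka and pairwise_cka_energy are hypothetical names, and the energy here is a generic inverse-distance potential rather than the paper's exact formulation (see configs/ and train.py for the real loss).

import torch

# Illustrative only: standard linear CKA between two members' features.
# This is NOT this library's built-in loss.
def linear_cka(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x, y: [batch, features] activations from two ensemble members
    x = x - x.mean(dim=0, keepdim=True)  # center features
    y = y - y.mean(dim=0, keepdim=True)
    hsic = (y.T @ x).norm() ** 2  # ||Y^T X||_F^2
    return hsic / ((x.T @ x).norm() * (y.T @ y).norm())

def pairwise_cka_energy(member_feats: torch.Tensor) -> torch.Tensor:
    # member_feats: [models, batch, features]; sum a simple repulsive
    # potential over member pairs so that high similarity => high energy
    m = member_feats.shape[0]
    energy = member_feats.new_zeros(())
    for i in range(m):
        for j in range(i + 1, m):
            cka = linear_cka(member_feats[i], member_feats[j])
            energy = energy + 1.0 / (1.0 - cka + 1e-6)
    return energy

feats = torch.randn(5, 8, 120)  # e.g. a tracked Linear output for 5 members
print(pairwise_cka_energy(feats))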

Training

All experiments are trained with the train.py script, using the configurations listed in configs/.

Example usage for the MNIST ensemble:

python train.py configs/mnist/ens/ensemble.yml

Checkpoints and plots are saved to the outputs/ folder.

Use --help for more options, such as multiple runs (--runs) or Weights & Biases logging (--wandb).
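For example, a hypothetical invocation combining both flags (the exact argument syntax may differ; check --help):

python train.py configs/mnist/ens/ensemble.yml --runs 3 --wandb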

Evaluation

To evaluate the models and produce the relevant plots, use eval.py, which accepts three arguments:

python eval.py <training configuration file> <experiment type: [mnist/cifar10]> <weight file name>.pt

Example usage on an MNIST ensemble, evaluating the model checkpoint from epoch 50:

python eval.py configs/mnist/ens/ensemble.yml mnist model-50.pt

TODO

  • Rewrite hypernet and training modules
  • Rewrite Toy experiments
  • Rewrite MNIST experiments
  • Rewrite ResNet18
  • CIFAR10 experiments (see configs/cifar10/ens/test/README.md)
  • Build ResNet32 model
  • Rewrite CIFAR100 experiments
  • Include hypernetwork examples/documentation
  • Fix seeding issue with ood examples
