CERM is a deep learning framework for training neural networks with constraints. Here, we briefly explain how to use the general framework and how to run the examples accompanying our paper "Constrained Empirical Risk Minimization".
Download CERM directly from the repository. The CERM framework has been tested with PyTorch 2.0.1.
The CERM framework preserves the usual flow of building models in PyTorch as much as possible. Here we provide a brief example of how one would use our framework to build a model with constraints.
We provide an abstract class `Constraint` whose implementation is to be completed by the user; only the specific constraint of interest needs to be implemented. A constraint can be applied to different groups of parameters, assuming each group has the same dimensionality. The user is required to provide the following data to the constructor of the `Constraint` class:
- `num_params` (int): dimension of the input of the zero map
- `num_eqs` (int): number of equations
- `num_groups` (int): number of groups (layers) to which the same set of equations is applied
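To make these arguments concrete, here is a framework-independent sketch (plain NumPy, not CERM code) of the zero-map view of a spherical constraint: each of `num_groups` parameter groups of dimension `num_params` must satisfy the single equation (`num_eqs = 1`) that its squared norm equals one.

```python
import numpy as np

# Hypothetical zero map for a spherical constraint (illustration only):
# num_params = 3 (dimension of each group), num_eqs = 1 (one equation
# per group), num_groups = 4 (e.g. four rows of a weight matrix).
def sphere_zero_map(params: np.ndarray) -> np.ndarray:
    """Map of shape (num_groups, num_params) -> (num_groups, num_eqs)."""
    return np.sum(params**2, axis=-1, keepdims=True) - 1.0

rng = np.random.default_rng(0)
params = rng.standard_normal((4, 3))

# Normalizing each group puts it on the sphere; the zero map then vanishes
params_on_sphere = params / np.linalg.norm(params, axis=-1, keepdims=True)
residual = sphere_zero_map(params_on_sphere)
print(np.max(np.abs(residual)))  # essentially zero
```

The parameters satisfy the constraint precisely when the zero map evaluates to zero on every group.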
Next, we build our model as usual by extending the `torch.nn.Module` class, the only difference being that we use `ConstrainedParameter` instead of `torch.nn.Parameter` wherever we wish to impose constraints.
The constructor of `ConstrainedParameter` requires the following data:

- `constraint`: an instance of the `Constraint` class.
- `init_params` (optional): an initial guess for the parameters.
The constructor of `ConstrainedParameter` will refine the initial guess and constrain it to the constraint manifold. A constrained parameter is explicitly constructed as follows:

```python
from cerm.network.constrained_params import ConstrainedParameter

constrained_params = ConstrainedParameter(constraint=constraint, init_params=params)
```
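The refinement step can be pictured as a Newton-type iteration on the zero map. The sketch below is plain NumPy, not CERM's actual implementation: it pulls a rough initial guess onto the sphere defined by f(w) = |w|^2 - 1 = 0 by repeatedly solving the linearized constraint.

```python
import numpy as np

def refine(w: np.ndarray, num_iter: int = 20) -> np.ndarray:
    """Gauss-Newton refinement of w onto the constraint f(w) = |w|^2 - 1 = 0."""
    for _ in range(num_iter):
        f = np.sum(w**2) - 1.0  # residual of the zero map
        jac = 2.0 * w           # Jacobian (gradient) of f at w
        # Minimum-norm update dw solving jac @ dw = -f
        w = w - jac * f / np.dot(jac, jac)
    return w

w0 = np.array([2.0, -1.0, 0.5])      # rough initial guess off the manifold
w = refine(w0)
print(abs(np.linalg.norm(w) - 1.0))  # essentially zero
```

After a few iterations the refined parameter lies on the constraint manifold up to machine precision.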
Models are trained using a custom optimizer for the constrained parameters, which can be found in `/CERM/cerm/optimizer` and works like any standard optimizer in PyTorch. When writing training routines, first use the `split_params` function to split the learnable parameters of the model into constrained and unconstrained ones; see the example below.
```python
import torch

from cerm.optimizer.riemannian_sgd import RSGD
from cerm.network.constrained_params import split_params

# Split parameters in your model (model is an instance of torch.nn.Module)
unconstrained_params, constrained_params = split_params(model)

# Initialize optimizer for the constrained parameters
constrained_optimizer = RSGD(constrained_params, lr=1e-03)

# Now use the optimizer as usual to update the constrained parameters
```
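As a rough mental model of what a Riemannian SGD step does (this sketch is not CERM's `RSGD` implementation), consider a parameter living on the unit sphere: the Euclidean gradient is projected onto the tangent space at the current point, a step is taken, and the result is retracted back onto the manifold.

```python
import numpy as np

def rsgd_step_sphere(w: np.ndarray, grad: np.ndarray, lr: float = 0.1) -> np.ndarray:
    """One hypothetical Riemannian SGD step for a parameter on the unit sphere."""
    # Project the Euclidean gradient onto the tangent space at w
    riem_grad = grad - np.dot(grad, w) * w
    # Gradient step followed by retraction (normalization) back onto the sphere
    w_new = w - lr * riem_grad
    return w_new / np.linalg.norm(w_new)

rng = np.random.default_rng(1)
w = rng.standard_normal(5)
w /= np.linalg.norm(w)
w = rsgd_step_sphere(w, rng.standard_normal(5))
print(np.linalg.norm(w))  # stays 1 up to rounding
```

The key point is that the update never leaves the constraint manifold, so no penalty terms are needed in the loss.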
In `/CERM/cerm/examples/sphere` we define a fully connected layer whose weight-matrix rows are constrained to the unit sphere. This is a minimal toy example illustrating how to use our framework.
```python
import torch

from cerm.examples.sphere import spherical_constraints

# Construct MLP with spherical constraints
dim_in = 64
dim_out = 5
dim_latent = 96
num_hidden_layers = 4
mlp = spherical_constraints.MLP(dim_in, dim_latent, dim_out, num_hidden_layers)

# Evaluate
bsize = 6
x = torch.rand(bsize, dim_in)
y = mlp(x)
```
In `/CERM/cerm/examples/stiefel` we define a fully connected layer whose weight matrix is constrained to lie on the so-called Stiefel manifold; that is, the columns of the weight matrix form an orthonormal set.
```python
import torch

from cerm.examples.stiefel import stiefel

# Initialize Stiefel layer
dim_in = 10
dim_out = 15
stiefel_layer = stiefel.StiefelLayer(dim_in, dim_out, bias=True)

# Apply model
bsize = 16
x = torch.rand(bsize, dim_in)
y = stiefel_layer(x)
```
This setup also contains an example of a domain-specific method to initialize parameters randomly (points on the Stiefel manifold).
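A standard way to sample a random point on the Stiefel manifold (the example's own initialization method may differ in details) is to take the QR decomposition of a Gaussian matrix:

```python
import numpy as np

def random_stiefel(n: int, p: int, rng: np.random.Generator) -> np.ndarray:
    """Random n x p matrix with orthonormal columns (a point on the Stiefel manifold)."""
    a = rng.standard_normal((n, p))
    q, r = np.linalg.qr(a)
    # Fix the column signs so the distribution is uniform (Haar) on the manifold
    return q * np.sign(np.diag(r))

rng = np.random.default_rng(0)
w = random_stiefel(15, 10, rng)
print(np.max(np.abs(w.T @ w - np.eye(10))))  # columns are orthonormal
```

Starting from such a point, constrained training only needs to keep the iterates on the manifold rather than find it first.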
In `/CERM/cerm/examples/equivariance` we define a fully connected layer that is equivariant to circular shifts. The only purpose of this example is to illustrate the flexibility of our framework; we know, of course, precisely which architecture achieves shift equivariance (CNNs).
```python
import logging

import torch
from torch import Tensor

from cerm.examples.equivariance.shift import ShiftEquivariantLayer

# Module logger
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def left_shift(x: Tensor, num_shifts: int) -> Tensor:
    """Shift sequence to the left (circular).

    Parameters
    ----------
    x: Tensor
        batch of 1d sequences (shape [batch_size, len_seq])
    num_shifts: int
        number of shifts to the left

    Returns
    -------
    y: Tensor
        input sequence shifted to the left (shape [batch_size, len_seq])
    """
    len_seq = x.shape[1]
    assert num_shifts < len_seq, "Number of shifts exceeds sequence length"
    idx = (torch.arange(len_seq) + num_shifts) % len_seq
    return x[:, idx]


# Initialize shift equivariant layer
dim = 35
bsize = 5
layer = ShiftEquivariantLayer(dim)

# Test equivariance
num_shifts = 7
x = 100 * torch.rand(bsize, dim)
y1 = left_shift(layer(x), num_shifts)
y2 = layer(left_shift(x, num_shifts))
logger.info(f"Discrepancy equivariance: {torch.max(torch.abs(y1 - y2))}")
```
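For intuition (this is general linear algebra, not CERM-specific): a linear layer commutes with circular shifts exactly when its weight matrix is circulant, i.e. each row is a circular shift of the previous one. A quick NumPy check of that fact:

```python
import numpy as np

def circulant(c: np.ndarray) -> np.ndarray:
    """Circulant matrix whose i-th row is c circularly shifted i steps."""
    return np.stack([np.roll(c, i) for i in range(len(c))])

def left_shift(x: np.ndarray, k: int) -> np.ndarray:
    """Circular shift of a 1d sequence k steps to the left."""
    return np.roll(x, -k)

rng = np.random.default_rng(0)
w = circulant(rng.standard_normal(8))
x = rng.standard_normal(8)

# Shifting then applying W equals applying W then shifting
y1 = w @ left_shift(x, 3)
y2 = left_shift(w @ x, 3)
print(np.max(np.abs(y1 - y2)))  # zero up to rounding
```

The constrained layer effectively learns this circulant structure from the equivariance equations instead of having it hard-coded.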
A more elaborate example implementing learnable wavelet layers can be found in `/CERM/cerm/examples/wavelets`. Here we summarize how to use the wavelet layers within the CERM framework. The constructor of the 1d wavelet layer requires the following data:
- `order` (int): order of the filter
- `number_of_decomps` (int): number of levels in the wavelet decomposition
- `num_filters_per_channel` (int, optional): number of wavelet filters
- `num_channels` (int, optional): number of channels in the input signal
- `periodic_signal` (bool, optional): indicates whether the input signal is periodic
A minimal example using a wavelet layer is given below:
```python
import torch

from cerm.examples.wavelets.wavelet_layer import WaveletLayer1d

# Construct learnable wavelet layer
order = 4
num_levels_down = 3
num_channels = 2
wavelet_layer = WaveletLayer1d(order, num_levels_down, num_channels=num_channels)

# Compute decomposition
bsize = 4
signal_len = 157
signal = torch.rand(bsize, num_channels, signal_len)
approx, detail = wavelet_layer(signal)
```
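To see what the approximation/detail split computes, here is a single level of the Haar transform in plain NumPy; the learnable CERM layer uses higher-order filters satisfying the constraint equations, so this is only the simplest instance of the idea.

```python
import numpy as np

def haar_level(signal: np.ndarray):
    """One level of the Haar wavelet transform of an even-length 1d signal."""
    even, odd = signal[0::2], signal[1::2]
    approx = (even + odd) / np.sqrt(2.0)  # low-pass: local averages
    detail = (even - odd) / np.sqrt(2.0)  # high-pass: local differences
    return approx, detail

signal = np.arange(8, dtype=float)
approx, detail = haar_level(signal)
print(approx.shape, detail.shape)  # each half the input length
```

Each level halves the length of the signal, which is why `number_of_decomps` is bounded by the signal length.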
The examples from the paper can be run using the supplied Hydra configs. We provide a conda environment in `/CERM/cerm/examples/mra_segmentation/conda_env.yml` to install the required dependencies:

```shell
conda env create -f /CERM/cerm/examples/mra_segmentation/conda_env.yml
```
For training purposes the scans and associated segmentations need to be stored in `h5` format. A scan and its associated masks need to be stored in one folder. The folder with training data should consist of subfolders containing `nrrd` files; each subfolder corresponds to a separate scan with its associated segmentations. The filenames of a scan and its associated masks need to be identical across subfolders. For example, the folder setup for a training set may look as follows:
```
train/
├── image_masks_1/
│   ├── scan.nrrd
│   ├── mask_1.nrrd
│   ├── mask_2.nrrd
│   └── mask_3.nrrd
├── image_masks_2/
│   ├── scan.nrrd
│   ├── mask_1.nrrd
│   ├── mask_2.nrrd
│   └── mask_3.nrrd
├── ...
└── image_masks_n/
    ├── scan.nrrd
    ├── mask_1.nrrd
    ├── mask_2.nrrd
    └── mask_3.nrrd
```
The folder setup for a validation set should follow the same structure.
After the appropriate folder structures have been set up, the contents of each folder need to be converted to a single `h5` dataset; the names of the scans and masks will be used as keys. After conversion the final folder structure should be as depicted below:
```
train/
├── image_masks_1/
│   └── scan_with_masks.h5
├── image_masks_2/
│   └── scan_with_masks.h5
├── ...
└── image_masks_n/
    └── scan_with_masks.h5
```
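The conversion essentially bundles the files of one folder into a single `h5` file with the original filenames as keys. A rough sketch of that idea with `h5py` (the actual `nrrd_to_h5.py` script may differ; the arrays below are dummy stand-ins for volumes read from `nrrd` files):

```python
import h5py
import numpy as np

# Dummy stand-ins for the volumes read from one folder's nrrd files
volumes = {
    "scan": np.zeros((4, 4, 4)),
    "mask_1": np.ones((4, 4, 4)),
}

# Bundle all arrays of the folder into a single h5 dataset, filenames as keys
with h5py.File("scan_with_masks.h5", "w") as f:
    for name, array in volumes.items():
        f.create_dataset(name, data=array)

with h5py.File("scan_with_masks.h5", "r") as f:
    print(sorted(f.keys()))  # ['mask_1', 'scan']
```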
A script `nrrd_to_h5.py` for performing the conversion to `h5` is available in the tools folder. If the data folders are structured as prescribed above, the following call will set up the required `h5` datasets:

```shell
python nrrd_to_h5.py $nrrd_dir $h5_dir
```
Models can be trained using the provided configs; e.g., the models for the spleen can be trained using

```shell
python main.py --multirun \
    task=spleen \
    dataset.train_dir=/path/to/train_dir \
    dataset.val_dir=/path/to/val_dir \
    dataset.test_dir=/path/to/test_dir \
    network.decoder.order_wavelet=3,4,5,6,7,8
```
We refer the reader to `/CERM/cerm/examples/mra_segmentation/mra/configs` for configuration details and the settings that can be overridden, and to `/CERM/cerm/examples/mra_segmentation/mra/experiments` for bash scripts containing detailed examples to reproduce the results presented in the paper.