diff --git a/direct/nn/registration/config.py b/direct/nn/registration/config.py
index b42c2756..ba471312 100644
--- a/direct/nn/registration/config.py
+++ b/direct/nn/registration/config.py
@@ -50,3 +50,12 @@ class UnetRegistration2dModelConfig(RegistrationModelConfig):
     unet_dropout_probability: float = 0.0
     unet_normalized: bool = False
     train_end_to_end: bool = True
+
+
+@dataclass
+class VxmDenseModelConfig(RegistrationModelConfig):
+    inshape: tuple = (512, 246)
+    nb_unet_features: int = 16
+    nb_unet_levels: int = 4
+    nb_unet_conv_per_level: int = 1
+    int_downsize: int = 2
diff --git a/direct/nn/registration/registration.py b/direct/nn/registration/registration.py
index 960de235..79e2fb4b 100644
--- a/direct/nn/registration/registration.py
+++ b/direct/nn/registration/registration.py
@@ -8,17 +8,20 @@
 import torch
 import torch.nn as nn
 
+from direct.nn.registration.voxelmorph import VxmDense
 from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d
 from direct.registration.demons import DemonsFilterType, multiscale_demons_displacement
 from direct.registration.optical_flow import OpticalFlowEstimatorType, optical_flow_displacement
 from direct.registration.registration import DISCPLACEMENT_FIELD_2D_DIMENSIONS
 from direct.registration.warp import warp
 
+
 __all__ = [
     "OpticalFlowILKRegistration2dModel",
     "OpticalFlowTVL1Registration2dModel",
     "DemonsRegistration2dModel",
     "UnetRegistration2dModel",
+    "VxmDense",
 ]
diff --git a/direct/nn/registration/voxelmorph.py b/direct/nn/registration/voxelmorph.py
new file mode 100644
index 00000000..f06905ba
--- /dev/null
+++ b/direct/nn/registration/voxelmorph.py
@@ -0,0 +1,377 @@
+from __future__ import annotations
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributions.normal import Normal
+
+
+class SpatialTransformer(nn.Module):
+    """
+    N-D spatial transformer that warps a source image with a dense flow field.
+    """
+
+    def __init__(self, size, mode="bilinear"):
+        super().__init__()
+
+        self.size = size
+        self.mode = mode
+
+    def forward(self, src, flow):
+        # create the identity sampling grid
+        vectors = [torch.arange(0, s) for s in self.size]
+        grids = torch.meshgrid(vectors, indexing="ij")
+        grid = torch.stack(grids)
+        grid = torch.unsqueeze(grid, 0)
+        grid = grid.float().to(src.device)
+
+        # new locations
+        new_locs = grid + flow
+        shape = flow.shape[2:]
+
+        # need to normalize grid values to [-1, 1] for the resampler
+        for i in range(len(shape)):
+            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
+
+        # move the channels dim to the last position and reverse the channel order,
+        # since grid_sample expects coordinates in (x, y[, z]) order
+        if len(shape) == 2:
+            new_locs = new_locs.permute(0, 2, 3, 1)
+            new_locs = new_locs[..., [1, 0]]
+        elif len(shape) == 3:
+            new_locs = new_locs.permute(0, 2, 3, 4, 1)
+            new_locs = new_locs[..., [2, 1, 0]]
+
+        return F.grid_sample(src, new_locs, align_corners=True, mode=self.mode)
+
+
+class VecInt(nn.Module):
+    """
+    Integrates a vector field via scaling and squaring.
+    """
+
+    def __init__(self, inshape, nsteps):
+        super().__init__()
+
+        assert nsteps >= 0, "nsteps should be >= 0, found: %d" % nsteps
+        self.nsteps = nsteps
+        self.scale = 1.0 / (2**self.nsteps)
+        self.transformer = SpatialTransformer(inshape)
+
+    def forward(self, vec):
+        vec = vec * self.scale
+        for _ in range(self.nsteps):
+            vec = vec + self.transformer(vec, vec)
+        return vec
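+
+
+# Worked note (a sketch; shapes are illustrative, not part of the DIRECT API):
+# with nsteps=7, VecInt scales the field by 1/128 and then composes it with
+# itself seven times, approximating the group exponential of the velocity field:
+#
+#     vec = torch.randn(1, 2, 64, 64) * 0.1  # (batch, ndims, *inshape)
+#     flow = VecInt(inshape=(64, 64), nsteps=7)(vec)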
+
+
+class ResizeTransform(nn.Module):
+    """
+    Resize a transform, which involves resizing the vector field *and* rescaling it.
+    """
+
+    def __init__(self, vel_resize, ndims):
+        super().__init__()
+        self.factor = 1.0 / vel_resize
+        self.mode = "linear"
+        if ndims == 2:
+            self.mode = "bi" + self.mode
+        elif ndims == 3:
+            self.mode = "tri" + self.mode
+
+    def forward(self, x):
+        if self.factor < 1:
+            # resize first to save memory
+            x = F.interpolate(x, align_corners=True, scale_factor=self.factor, mode=self.mode)
+            x = self.factor * x
+
+        elif self.factor > 1:
+            # multiply first to save memory
+            x = self.factor * x
+            x = F.interpolate(x, align_corners=True, scale_factor=self.factor, mode=self.mode)
+
+        # don't do anything if the resize factor is 1
+        return x
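+
+
+# Example (a sketch; the factors follow from the constructor above): VxmDense
+# below pairs two of these so flow integration can run at reduced resolution:
+#
+#     half = ResizeTransform(2, ndims=2)      # factor 0.5: flow at half resolution
+#     full = ResizeTransform(1 / 2, ndims=2)  # factor 2.0: back to full resolution
+#
+# The vectors are rescaled by the same factor, keeping displacements in voxel units.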
found: %d" % ndims + + # cache some parameters + self.half_res = half_res + + enc_nf = [nb_features * (2 ** i) for i in range(nb_levels)] + dec_nf = enc_nf[::-1] + [nb_features] + + enc_nf = [nb_features * (2 ** i) for i in range(nb_levels)] + dec_nf = enc_nf[::-1] + [nb_features] + + nb_dec_convs = len(enc_nf) + final_convs = dec_nf[nb_dec_convs:] + dec_nf = dec_nf[:nb_dec_convs] + self.nb_levels = int(nb_dec_convs / nb_conv_per_level) + 1 + + if isinstance(max_pool, int): + max_pool = [max_pool] * self.nb_levels + + # cache downsampling / upsampling operations + MaxPooling = getattr(nn, "MaxPool%dd" % ndims) + self.pooling = [MaxPooling(s) for s in max_pool] + self.upsampling = [nn.Upsample(scale_factor=s, mode="nearest") for s in max_pool] + + # configure encoder (down-sampling path) + prev_nf = infeats + encoder_nfs = [prev_nf] + self.encoder = nn.ModuleList() + for level in range(self.nb_levels - 1): + convs = nn.ModuleList() + for conv in range(nb_conv_per_level): + nf = enc_nf[level * nb_conv_per_level + conv] + convs.append(ConvBlock(ndims, prev_nf, nf)) + prev_nf = nf + self.encoder.append(convs) + encoder_nfs.append(prev_nf) + + # configure decoder (up-sampling path) + encoder_nfs = np.flip(encoder_nfs) + self.decoder = nn.ModuleList() + for level in range(self.nb_levels - 1): + convs = nn.ModuleList() + for conv in range(nb_conv_per_level): + nf = dec_nf[level * nb_conv_per_level + conv] + convs.append(ConvBlock(ndims, prev_nf, nf)) + prev_nf = nf + self.decoder.append(convs) + if not half_res or level < (self.nb_levels - 2): + prev_nf += encoder_nfs[level] + + # now we take care of any remaining convolutions + self.remaining = nn.ModuleList() + for num, nf in enumerate(final_convs): + self.remaining.append(ConvBlock(ndims, prev_nf, nf)) + prev_nf = nf + + # cache final number of features + self.final_nf = prev_nf + + def forward(self, x) -> torch.Tensor: + """Forward pass of :class:`VoxelmorphUnet`. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Output tensor. + """ + # encoder forward pass + x_history = [x] + for level, convs in enumerate(self.encoder): + for conv in convs: + x = conv(x) + x_history.append(x) + x = self.pooling[level](x) + + # decoder forward pass with upsampling and concatenation + for level, convs in enumerate(self.decoder): + for conv in convs: + x = conv(x) + if not self.half_res or level < (self.nb_levels - 2): + x = self.upsampling[level](x) + # Ensure the dimensions of x and x_history[-1] match for concatenation + diff_dims = [hx - ux for hx, ux in zip(x_history[-1].shape[2:], x.shape[2:])] + pad = [(d // 2, d - d // 2) for d in diff_dims] + pad = [p for sublist in reversed(pad) for p in sublist] # flatten and reverse + x = F.pad(x, pad) + x = torch.cat([x, x_history.pop()], dim=1) + + # remaining convs at full resolution + for conv in self.remaining: + x = conv(x) + + return x + + +class VxmDense(nn.Module): + """ + VoxelMorph network for (unsupervised) nonlinear registration between two images. + """ + + def __init__( + self, + inshape, + nb_unet_features=8, + nb_unet_levels=4, + nb_unet_conv_per_level=1, + warp_num_integration_steps=1, + int_downsize=2, + src_feats=1, + trg_feats=1, + ) -> None: + super().__init__() + + # internal flag indicating whether to return flow or integrated warp during inference + self.training = True + + # ensure correct dimensionality + ndims = len(inshape) + assert ndims in [1, 2, 3], "ndims should be one of 1, 2, or 3. 
found: %d" % ndims + + # configure core unet model + self.unet_model = VoxelmorphUnet( + inshape, + infeats=(src_feats + trg_feats), + nb_features=nb_unet_features, + nb_levels=nb_unet_levels, + nb_conv_per_level=nb_unet_conv_per_level, + ) + + # configure unet to flow field layer + Conv = getattr(nn, "Conv%dd" % ndims) + self.flow = Conv(self.unet_model.final_nf, ndims, kernel_size=3, padding=1) + + # init flow layer with small weights and bias + self.flow.weight = nn.Parameter(Normal(0, 1e-5).sample(self.flow.weight.shape)) + self.flow.bias = nn.Parameter(torch.zeros(self.flow.bias.shape)) + + # configure optional resize layers (downsize) + if warp_num_integration_steps > 0 and int_downsize > 1: + self.resize = ResizeTransform(int_downsize, ndims) + else: + self.resize = None + + # resize to full res + if warp_num_integration_steps > 0 and int_downsize > 1: + self.fullsize = ResizeTransform(1 / int_downsize, ndims) + else: + self.fullsize = None + + # configure optional integration layer for diffeomorphic warp + down_shape = [int(dim / int_downsize) for dim in inshape] + self.integrate = VecInt(down_shape, warp_num_integration_steps) if warp_num_integration_steps > 0 else None + + # configure transformer + self.transformer = SpatialTransformer(inshape) + + def forward(self, moving_image, reference_image) -> torch.Tensor: + _, seq_len, _, _ = moving_image.shape + + displacement_field = [] + registered_image = [] + for t in range(seq_len): + # concatenate inputs and propagate unet + x = torch.cat([moving_image[:, t : t + 1], reference_image.unsqueeze(1)], dim=1) + x = self.unet_model(x) + + # transform into flow field + flow_field = self.flow(x) + + # resize flow for integration + pos_flow = flow_field + if self.resize: + pos_flow = self.resize(pos_flow) + + # integrate to produce diffeomorphic warp + if self.integrate: + pos_flow = self.integrate(pos_flow) + + # resize to final resolution + if self.fullsize: + pos_flow = self.fullsize(pos_flow) + + displacement_field.append(pos_flow) + registered_image.append(self.transformer(moving_image[:, t : t + 1], pos_flow)) + + displacement_field = torch.stack(displacement_field, dim=1) + registered_image = torch.cat(registered_image, dim=1) + + return registered_image, displacement_field + + +class ConvBlock(nn.Module): + """ + Specific convolutional block followed by leakyrelu for unet. + """ + + def __init__(self, ndims, in_channels, out_channels, stride=1): + super().__init__() + + Conv = getattr(nn, "Conv%dd" % ndims) + self.main = Conv(in_channels, out_channels, 3, stride, 1) + self.activation = nn.LeakyReLU(0.2) + + def forward(self, x): + out = self.main(x) + out = self.activation(out) + return out