Skip to content

Commit

Permalink
Add voxelmorph model
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Nov 4, 2024
1 parent adb18fd commit b16eff6
Show file tree
Hide file tree
Showing 3 changed files with 389 additions and 0 deletions.
9 changes: 9 additions & 0 deletions direct/nn/registration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,12 @@ class UnetRegistration2dModelConfig(RegistrationModelConfig):
unet_dropout_probability: float = 0.0
unet_normalized: bool = False
train_end_to_end: bool = True


@dataclass
class VxmDenseModelConfig(RegistrationModelConfig):
inshape: tuple = (512, 246)
nb_unet_features: int = 16
nb_unet_levels: int = 4
nb_unet_conv_per_level: int = 1
int_downsize: int = 2
3 changes: 3 additions & 0 deletions direct/nn/registration/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,20 @@
import torch
import torch.nn as nn

from direct.nn.registration.voxelmorph import VxmDense
from direct.nn.unet.unet_2d import NormUnetModel2d, UnetModel2d
from direct.registration.demons import DemonsFilterType, multiscale_demons_displacement
from direct.registration.optical_flow import OpticalFlowEstimatorType, optical_flow_displacement
from direct.registration.registration import DISCPLACEMENT_FIELD_2D_DIMENSIONS
from direct.registration.warp import warp


__all__ = [
"OpticalFlowILKRegistration2dModel",
"OpticalFlowTVL1Registration2dModel",
"DemonsRegistration2dModel",
"UnetRegistration2dModel",
"VxmDense",
]


Expand Down
Loading

0 comments on commit b16eff6

Please sign in to comment.