Skip to content

Commit

Permalink
Allow for non end to end training
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Oct 24, 2024
1 parent 5c2fea3 commit adb18fd
Show file tree
Hide file tree
Showing 4 changed files with 73 additions and 22 deletions.
39 changes: 34 additions & 5 deletions direct/nn/mri_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,16 @@ def _do_iteration(
output_kspace: TensorOrNone

with autocast(enabled=self.mixed_precision):

if self.ndim == 3 and "registration_model" in self.models:
# Freeze registration model weights
if self.cfg.additional_models.registration_model.train_end_to_end:
if len(list(self.models["registration_model"].parameters())) > 0:
for param in self.models["registration_model"].parameters():
param.requires_grad = False

data["sensitivity_map"] = self.compute_sensitivity_map(data["sensitivity_map"])
data = self.perform_sampling(data)

output_image, output_kspace = self.forward_function(data)
output_image = T.modulus_if_complex(output_image, complex_axis=self._complex_dim)
Expand All @@ -137,9 +146,33 @@ def _do_iteration(
k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in regularizer_fns.keys()
}

loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image, output_kspace)
regularizer_dict = self.compute_loss_on_data(
regularizer_dict, regularizer_fns, data, output_image, output_kspace
)

if self.ndim == 3 and "registration_model" in self.models:

if self.cfg.additional_models.registration_model.train_end_to_end:
if len(list(self.models["registration_model"].parameters())) > 0:
for param in self.models["registration_model"].parameters():
param.requires_grad = True
for param in self.model.parameters():
param.requires_grad = False
for model in self.models:
if model != "registration_model":
for param in self.models[model].parameters():
param.requires_grad = False

# Perform registration and compute loss on registered image and displacement field
registered_image, displacement_field = self.do_registration(data, output_image)
registered_image, displacement_field = self.do_registration(
data,
(
output_image.detach()
if self.cfg.additional_models.registration_model.train_end_to_end
else output_image
),
)

# If DL-based model calculate loss
if len(list(self.models["registration_model"].parameters())) > 0:
Expand All @@ -166,10 +199,6 @@ def _do_iteration(
output_displacement_field=displacement_field,
target_displacement_field=target_displacement_field,
)
loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image, output_kspace)
regularizer_dict = self.compute_loss_on_data(
regularizer_dict, regularizer_fns, data, output_image, output_kspace
)

loss = sum(loss_dict.values()) + sum(regularizer_dict.values()) # type: ignore

Expand Down
1 change: 1 addition & 0 deletions direct/nn/registration/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,4 @@ class UnetRegistration2dModelConfig(RegistrationModelConfig):
unet_num_pool_layers: int = 4
unet_dropout_probability: float = 0.0
unet_normalized: bool = False
train_end_to_end: bool = True
1 change: 1 addition & 0 deletions direct/nn/registration/registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ def __init__(
unet_dropout_probability: float = 0.0,
unet_normalized: bool = False,
warp_num_integration_steps: int = 1,
**kwargs,
) -> None:
"""Inits :class:`UnetRegistration2dModel`.
Expand Down
54 changes: 37 additions & 17 deletions direct/nn/vsharp/vsharp_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,14 +108,49 @@ def _do_iteration(
output_kspace: TensorOrNone

with autocast(enabled=self.mixed_precision):

if "registration_model" in self.models:
# Freeze registration model weights
if self.cfg.additional_models.registration_model.train_end_to_end:
if len(list(self.models["registration_model"].parameters())) > 0:
for param in self.models["registration_model"].parameters():
param.requires_grad = False

output_images, output_kspace = self.forward_function(data)
output_images = [T.modulus_if_complex(_, complex_axis=self._complex_dim) for _ in output_images]

loss_dict = {k: torch.tensor([0.0], dtype=data["target"].dtype).to(self.device) for k in loss_fns.keys()}

auxiliary_loss_weights = torch.logspace(-1, 0, steps=len(output_images)).to(output_images[0])
for i, output_image in enumerate(output_images):
loss_dict = self.compute_loss_on_data(
loss_dict,
loss_fns,
data,
output_image=output_image,
output_kspace=None,
weight=auxiliary_loss_weights[i],
)
# Compute loss on k-space
loss_dict = self.compute_loss_on_data(
loss_dict, loss_fns, data, output_image=None, output_kspace=output_kspace
)

if "registration_model" in self.models:
# Perform registration and compute loss on registered image and displacement field
registered_image, displacement_field = self.do_registration(data, output_images[-1])
if self.cfg.additional_models.registration_model.train_end_to_end:
if len(list(self.models["registration_model"].parameters())) > 0:
for param in self.models["registration_model"].parameters():
param.requires_grad = True
for param in self.model.parameters():
param.requires_grad = False
for model in self.models:
if model != "registration_model":
for param in self.models[model].parameters():
param.requires_grad = False
# Perform registration and compute loss on registered image and displacement field
registered_image, displacement_field = self.do_registration(data, output_images[-1].detach())
else:
registered_image, displacement_field = self.do_registration(data, output_images[-1])

# If DL-based model calculate loss
if len(list(self.models["registration_model"].parameters())) > 0:
Expand All @@ -139,21 +174,6 @@ def _do_iteration(
target_displacement_field=data["displacement_field"],
)

auxiliary_loss_weights = torch.logspace(-1, 0, steps=len(output_images)).to(output_images[0])
for i, output_image in enumerate(output_images):
loss_dict = self.compute_loss_on_data(
loss_dict,
loss_fns,
data,
output_image=output_image,
output_kspace=None,
weight=auxiliary_loss_weights[i],
)
# Compute loss on k-space
loss_dict = self.compute_loss_on_data(
loss_dict, loss_fns, data, output_image=None, output_kspace=output_kspace
)

loss = sum(loss_dict.values()) # type: ignore

if self.model.training:
Expand Down

0 comments on commit adb18fd

Please sign in to comment.