From d7ca4a5b1121a2d4d0bd14be70d4d32b7d432390 Mon Sep 17 00:00:00 2001 From: George Yiasemis Date: Sun, 20 Oct 2024 22:28:53 +0200 Subject: [PATCH] Minor fix --- direct/nn/mri_models.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/direct/nn/mri_models.py b/direct/nn/mri_models.py index fc549601..e0541209 100644 --- a/direct/nn/mri_models.py +++ b/direct/nn/mri_models.py @@ -155,12 +155,16 @@ def _do_iteration( else data["reference_image"].tile((1, registered_image.shape[1], *([1] * len(shape[1:])))) ), ) + if "displacement_field" in data: + target_displacement_field = data["displacement_field"] + else: + target_displacement_field = None loss_dict = self.compute_loss_on_data( loss_dict, loss_fns, data, output_displacement_field=displacement_field, - target_displacement_field=data["displacement_field"], + target_displacement_field=target_displacement_field, ) loss_dict = self.compute_loss_on_data(loss_dict, loss_fns, data, output_image, output_kspace) regularizer_dict = self.compute_loss_on_data( @@ -1347,9 +1351,7 @@ def compute_loss_on_data( elif "displacement_field" in key: if output_displacement_field is not None: output = output_displacement_field - target = ( - data["displacement_field"] if target_displacement_field is None else target_displacement_field - ) + target = target_displacement_field reconstruction_size = data.get("reconstruction_size", None) else: continue