From d912b9e35b3e8da770f26967bbd9b62971bebc1a Mon Sep 17 00:00:00 2001 From: georgeyiasemis Date: Wed, 21 Aug 2024 17:31:30 +0200 Subject: [PATCH] Minor fix --- direct/nn/registration/registration.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/direct/nn/registration/registration.py b/direct/nn/registration/registration.py index 455bd7e3..9a50c778 100644 --- a/direct/nn/registration/registration.py +++ b/direct/nn/registration/registration.py @@ -114,7 +114,7 @@ def forward(self, moving_image: torch.Tensor, reference_image: torch.Tensor) -> # Estimate the displacement field displacement_field = [ - self.displacement_transform(reference_image[_].cpu(), moving_image[_].cpu()) + self.displacement_transform(reference_image[_].detach().cpu(), moving_image[_].detach().cpu()) for _ in range(moving_image.shape[0]) ] displacement_field = torch.stack(displacement_field, dim=0)