diff --git a/diffdrr/registration.py b/diffdrr/registration.py index bf1d0707c..0f7f7f76b 100644 --- a/diffdrr/registration.py +++ b/diffdrr/registration.py @@ -93,7 +93,9 @@ def forward(self, x): x = self.backbone(x) rot = self.rot_regression(x) xyz = self.xyz_regression(x) - return rot, xyz + return convert( + rot, xyz, convention=self.convention, parameterization=self.parameterization + ) # %% ../notebooks/api/08_registration.ipynb 7 N_ANGULAR_COMPONENTS = { diff --git a/notebooks/api/08_registration.ipynb b/notebooks/api/08_registration.ipynb index e7cd9e4b9..46e82d6f7 100644 --- a/notebooks/api/08_registration.ipynb +++ b/notebooks/api/08_registration.ipynb @@ -163,7 +163,9 @@ " x = self.backbone(x)\n", " rot = self.rot_regression(x)\n", " xyz = self.xyz_regression(x)\n", - " return rot, xyz" + " return convert(\n", + " rot, xyz, convention=self.convention, parameterization=self.parameterization\n", + " )" ] }, {