Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix translation in double geodesic formula #310

Merged
merged 1 commit into from
Jul 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 0 additions & 8 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,6 @@
'diffdrr/metrics.py'),
'diffdrr.metrics.DoubleGeodesicSE3.forward': ( 'api/metrics.html#doublegeodesicse3.forward',
'diffdrr/metrics.py'),
'diffdrr.metrics.GeodesicSO3': ('api/metrics.html#geodesicso3', 'diffdrr/metrics.py'),
'diffdrr.metrics.GeodesicSO3.__init__': ('api/metrics.html#geodesicso3.__init__', 'diffdrr/metrics.py'),
'diffdrr.metrics.GeodesicSO3.forward': ('api/metrics.html#geodesicso3.forward', 'diffdrr/metrics.py'),
'diffdrr.metrics.GeodesicTranslation': ('api/metrics.html#geodesictranslation', 'diffdrr/metrics.py'),
'diffdrr.metrics.GeodesicTranslation.__init__': ( 'api/metrics.html#geodesictranslation.__init__',
'diffdrr/metrics.py'),
'diffdrr.metrics.GeodesicTranslation.forward': ( 'api/metrics.html#geodesictranslation.forward',
'diffdrr/metrics.py'),
'diffdrr.metrics.GradientNormalizedCrossCorrelation2d': ( 'api/metrics.html#gradientnormalizedcrosscorrelation2d',
'diffdrr/metrics.py'),
'diffdrr.metrics.GradientNormalizedCrossCorrelation2d.__init__': ( 'api/metrics.html#gradientnormalizedcrosscorrelation2d.__init__',
Expand Down
56 changes: 14 additions & 42 deletions diffdrr/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def forward(self, x1, x2):
return super().forward(self.sobel(x1), self.sobel(x2))

# %% ../notebooks/api/05_metrics.ipynb 14
from .pose import RigidTransform, convert, so3_log_map
from .pose import RigidTransform, convert


class LogGeodesicSE3(torch.nn.Module):
Expand All @@ -119,40 +119,10 @@ def forward(
) -> Float[torch.Tensor, "b"]:
return pose_2.compose(pose_1.inverse()).get_se3_log().norm(dim=1)

# %% ../notebooks/api/05_metrics.ipynb 16
class GeodesicSO3(torch.nn.Module):
"""Calculate the angular distance between two rotations in SO(3)."""
# %% ../notebooks/api/05_metrics.ipynb 17
from .pose import so3_log_map

def __init__(self):
super().__init__()

def forward(
self,
pose_1: RigidTransform,
pose_2: RigidTransform,
) -> Float[torch.Tensor, "b"]:
r1 = pose_1.matrix[..., :3, :3]
r2 = pose_2.matrix[..., :3, :3]
rdiff = r1.transpose(-1, -2) @ r2
return so3_log_map(rdiff).norm(dim=-1)


class GeodesicTranslation(torch.nn.Module):
"""Calculate the angular distance between two translations in R^3."""

def __init__(self):
super().__init__()

def forward(
self,
pose_1: RigidTransform,
pose_2: RigidTransform,
) -> Float[torch.Tensor, "b"]:
t1 = pose_1.matrix[..., :3, 3]
t2 = pose_2.matrix[..., :3, 3]
return (t1 - t2).norm(dim=1)

# %% ../notebooks/api/05_metrics.ipynb 18
class DoubleGeodesicSE3(torch.nn.Module):
"""
Calculate the angular and translational geodesics between two SE(3) transformation matrices.
Expand All @@ -161,19 +131,21 @@ class DoubleGeodesicSE3(torch.nn.Module):
def __init__(
self,
sdd: float, # Source-to-detector distance
eps: float = 1e-4, # Avoid overflows in sqrt
eps: float = 1e-6, # Avoid overflows in sqrt
):
super().__init__()
self.sdr = sdd / 2
self.eps = eps

self.rotation = GeodesicSO3()
self.translation = GeodesicTranslation()
self.rot_geo = lambda r1, r2: self.sdr * so3_log_map(
r1.transpose(-1, -2) @ r2
).norm(dim=-1)
self.xyz_geo = lambda t1, t2: (t1 - t2).norm(dim=-1)

def forward(self, pose_1: RigidTransform, pose_2: RigidTransform):
angular_geodesic = self.sdr * self.rotation(pose_1, pose_2)
translation_geodesic = self.translation(pose_1, pose_2)
double_geodesic = (
(angular_geodesic).square() + translation_geodesic.square() + self.eps
).sqrt()
return angular_geodesic, translation_geodesic, double_geodesic
r1, t1 = pose_1.convert("matrix")
r2, t2 = pose_2.convert("matrix")
rot = self.rot_geo(r1, r2)
xyz = self.xyz_geo(t1, t2)
dou = (rot.square() + xyz.square() + self.eps).sqrt()
return rot, xyz, dou
72 changes: 17 additions & 55 deletions notebooks/api/05_metrics.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@
{
"data": {
"text/plain": [
"tensor([ 0.0060, -0.0284, -0.0047, -0.0036, -0.0012, -0.0156, 0.0054, 0.0014])"
"tensor([-0.0019, -0.0004, 0.0035, -0.0198, -0.0078, -0.0175, 0.0171, 0.0019])"
]
},
"execution_count": null,
Expand Down Expand Up @@ -296,7 +296,7 @@
"outputs": [],
"source": [
"#| export\n",
"from diffdrr.pose import RigidTransform, convert, so3_log_map\n",
"from diffdrr.pose import RigidTransform, convert\n",
"\n",
"\n",
"class LogGeodesicSE3(torch.nn.Module):\n",
Expand Down Expand Up @@ -352,47 +352,6 @@
"geodesic_se3(pose_1, pose_2)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3d034430-c6fd-4d9a-8bb6-180a1be9b68e",
"metadata": {},
"outputs": [],
"source": [
"#| exporti\n",
"class GeodesicSO3(torch.nn.Module):\n",
" \"\"\"Calculate the angular distance between two rotations in SO(3).\"\"\"\n",
"\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" def forward(\n",
" self,\n",
" pose_1: RigidTransform,\n",
" pose_2: RigidTransform,\n",
" ) -> Float[torch.Tensor, \"b\"]:\n",
" r1 = pose_1.matrix[..., :3, :3]\n",
" r2 = pose_2.matrix[..., :3, :3]\n",
" rdiff = r1.transpose(-1, -2) @ r2\n",
" return so3_log_map(rdiff).norm(dim=-1)\n",
"\n",
"\n",
"class GeodesicTranslation(torch.nn.Module):\n",
" \"\"\"Calculate the angular distance between two translations in R^3.\"\"\"\n",
"\n",
" def __init__(self):\n",
" super().__init__()\n",
"\n",
" def forward(\n",
" self,\n",
" pose_1: RigidTransform,\n",
" pose_2: RigidTransform,\n",
" ) -> Float[torch.Tensor, \"b\"]:\n",
" t1 = pose_1.matrix[..., :3, 3]\n",
" t2 = pose_2.matrix[..., :3, 3]\n",
" return (t1 - t2).norm(dim=1)"
]
},
{
"cell_type": "markdown",
"id": "bfdc8eba-cc2b-4ac8-a6d5-99f4d3715c2a",
Expand Down Expand Up @@ -423,6 +382,9 @@
"outputs": [],
"source": [
"#| export\n",
"from diffdrr.pose import so3_log_map\n",
"\n",
"\n",
"class DoubleGeodesicSE3(torch.nn.Module):\n",
" \"\"\"\n",
" Calculate the angular and translational geodesics between two SE(3) transformation matrices.\n",
Expand All @@ -431,22 +393,22 @@
" def __init__(\n",
" self,\n",
" sdd: float, # Source-to-detector distance\n",
" eps: float = 1e-4, # Avoid overflows in sqrt\n",
" eps: float = 1e-6, # Avoid overflows in sqrt\n",
" ):\n",
" super().__init__()\n",
" self.sdr = sdd / 2\n",
" self.eps = eps\n",
"\n",
" self.rotation = GeodesicSO3()\n",
" self.translation = GeodesicTranslation()\n",
" self.rot_geo = lambda r1, r2: self.sdr * so3_log_map(r1.transpose(-1, -2) @ r2).norm(dim=-1)\n",
" self.xyz_geo = lambda t1, t2: (t1 - t2).norm(dim=-1)\n",
"\n",
" def forward(self, pose_1: RigidTransform, pose_2: RigidTransform):\n",
" angular_geodesic = self.sdr * self.rotation(pose_1, pose_2)\n",
" translation_geodesic = self.translation(pose_1, pose_2)\n",
" double_geodesic = (\n",
" (angular_geodesic).square() + translation_geodesic.square() + self.eps\n",
" ).sqrt()\n",
" return angular_geodesic, translation_geodesic, double_geodesic"
" r1, t1 = pose_1.convert(\"matrix\")\n",
" r2, t2 = pose_2.convert(\"matrix\")\n",
" rot = self.rot_geo(r1, r2)\n",
" xyz = self.xyz_geo(t1, t2)\n",
" dou = (rot.square() + xyz.square() + self.eps).sqrt()\n",
" return rot, xyz, dou"
]
},
{
Expand All @@ -458,7 +420,7 @@
{
"data": {
"text/plain": [
"(tensor([25.5000]), tensor([1.7321]), tensor([25.5588]))"
"(tensor([51.0000]), tensor([1.7321]), tensor([51.0294]))"
]
},
"execution_count": null,
Expand All @@ -468,7 +430,7 @@
],
"source": [
"# Angular distance and translational distance both in mm\n",
"double_geodesic = DoubleGeodesicSE3(1020 / 2)\n",
"double_geodesic = DoubleGeodesicSE3(1020.0)\n",
"\n",
"pose_1 = convert(\n",
" torch.tensor([[0.1, 1.0, torch.pi]]),\n",
Expand Down Expand Up @@ -502,7 +464,7 @@
{
"cell_type": "code",
"execution_count": null,
"id": "f043c1d4",
"id": "538daf03-4a98-405e-9ddf-448b3c831af7",
"metadata": {},
"outputs": [],
"source": []
Expand Down
Loading