Skip to content

Commit

Permalink
Provide simpler access to camera intrinsics for optimization (#240)
Browse files Browse the repository at this point in the history
* Reparameterize camera intrinsics

* Fix CT partitioning from labelmap

* Rerun timings following changes to detector
  • Loading branch information
eigenvivek authored May 20, 2024
1 parent 5a88775 commit 53a149a
Show file tree
Hide file tree
Showing 8 changed files with 103 additions and 53 deletions.
2 changes: 2 additions & 0 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
'diffdrr.detector.Detector.__init__': ('api/detector.html#detector.__init__', 'diffdrr/detector.py'),
'diffdrr.detector.Detector._initialize_carm': ( 'api/detector.html#detector._initialize_carm',
'diffdrr/detector.py'),
'diffdrr.detector.Detector.calibration': ( 'api/detector.html#detector.calibration',
'diffdrr/detector.py'),
'diffdrr.detector.Detector.forward': ('api/detector.html#detector.forward', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.intrinsic': ('api/detector.html#detector.intrinsic', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.reorient': ('api/detector.html#detector.reorient', 'diffdrr/detector.py')},
Expand Down
2 changes: 1 addition & 1 deletion diffdrr/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def read(
mask = torch.any(
torch.stack([subject.mask.data.squeeze() == idx for idx in labels]), dim=0
)
subject.density = subject.density * mask
subject.density.data = subject.density.data * mask

return subject

Expand Down
61 changes: 42 additions & 19 deletions diffdrr/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def __init__(
reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis
):
super().__init__()
self.sdd = sdd
# self.sdd = sdd
self.height = height
self.width = width
self.delx = delx
self.dely = dely
self.x0 = x0
self.y0 = y0
# self.delx = delx
# self.dely = dely
# self.x0 = x0
# self.y0 = y0
self.n_subsample = n_subsample
if self.n_subsample is not None:
self.subsamples = []
Expand All @@ -52,20 +52,39 @@ def __init__(
# Create a pose to reorient the scanner
self.register_buffer("_reorient", reorient)

# Create a calibration matrix that holds the detector's intrinsic parameters
self.register_buffer(
"_calibration",
torch.tensor(
[
[delx, 0, 0, x0],
[0, dely, 0, y0],
[0, 0, sdd, 0],
[0, 0, 0, 1],
]
),
)

@property
def reorient(self):
return RigidTransform(self._reorient)

@property
def calibration(self):
"""A 4x4 matrix that rescales the detector plane to world coordinates."""
return RigidTransform(self._calibration)

@property
def intrinsic(self):
"""The 3x3 intrinsic matrix."""
return make_intrinsic_matrix(
self.sdd,
self.delx,
self.dely,
self._calibration[2, 2].item(),
self._calibration[0, 0].item(),
self._calibration[1, 1].item(),
self.height,
self.width,
self.x0,
self.y0,
self._calibration[0, -1].item(),
self._calibration[1, -1].item(),
).to(self.source)

# %% ../notebooks/api/02_detector.ipynb 6
Expand All @@ -79,7 +98,7 @@ def _initialize_carm(self: Detector):

# Initialize the source at the origin and the center of the detector plane on the positive z-axis
source = torch.tensor([[0.0, 0.0, 0.0]], device=device)
center = torch.tensor([[0.0, 0.0, 1.0]], device=device) * self.sdd
center = torch.tensor([[0.0, 0.0, 1.0]], device=device) # * self.sdd

# Use the standard basis for the detector plane
basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)
Expand All @@ -91,10 +110,10 @@ def _initialize_carm(self: Detector):
# Construct equally spaced points along the basis vectors
t = (
torch.arange(-self.height // 2, self.height // 2, device=device) + h_off
) * self.delx
) # * self.delx
s = (
torch.arange(-self.width // 2, self.width // 2, device=device) + w_off
) * self.dely
) # * self.dely
if self.reverse_x_axis:
s = -s
coefs = torch.cartesian_prod(t, s).reshape(-1, 2)
Expand All @@ -105,9 +124,9 @@ def _initialize_carm(self: Detector):
source = source.unsqueeze(0)
target = target.unsqueeze(0)

# Apply principal point offset
target[..., 1] -= self.x0
target[..., 0] -= self.y0
# # Apply principal point offset
# target[..., 1] -= self.x0
# target[..., 0] -= self.y0

if self.n_subsample is not None:
sample = torch.randperm(self.height * self.width)[: int(self.n_subsample)]
Expand All @@ -120,9 +139,13 @@ def _initialize_carm(self: Detector):


@patch
def forward(self: Detector, pose: RigidTransform):
def forward(self: Detector, extrinsic: RigidTransform, calibration: RigidTransform):
"""Create source and target points for X-rays to trace through the volume."""
pose = self.reorient.compose(pose)
if calibration is None:
target = self.calibration(self.target)
else:
target = calibration(self.target)
pose = self.reorient.compose(extrinsic)
source = pose(self.source)
target = pose(self.target)
target = pose(target)
return source, target
3 changes: 2 additions & 1 deletion diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ def forward(
*args, # Some batched representation of SE(3)
parameterization: str = None, # Specifies the representation of the rotation
convention: str = None, # If parameterization is Euler angles, specify convention
calibration: RigidTransform = None, # Optional calibration matrix with the detector's intrinsic parameters
mask_to_channels: bool = False, # If True, structures from the CT mask are rendered in separate channels
**kwargs, # Passed to the renderer
):
Expand All @@ -122,7 +123,7 @@ def forward(
pose = args[0]
else:
pose = convert(*args, parameterization=parameterization, convention=convention)
source, target = self.detector(pose)
source, target = self.detector(pose, calibration)

# Render the DRR
kwargs["mask"] = self.mask if mask_to_channels else None
Expand Down
3 changes: 2 additions & 1 deletion notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@
" *args, # Some batched representation of SE(3)\n",
" parameterization: str = None, # Specifies the representation of the rotation\n",
" convention: str = None, # If parameterization is Euler angles, specify convention\n",
" calibration: RigidTransform = None, # Optional calibration matrix with the detector's intrinsic parameters\n",
" mask_to_channels: bool = False, # If True, structures from the CT mask are rendered in separate channels\n",
" **kwargs, # Passed to the renderer\n",
"):\n",
Expand All @@ -239,7 +240,7 @@
" pose = args[0]\n",
" else:\n",
" pose = convert(*args, parameterization=parameterization, convention=convention)\n",
" source, target = self.detector(pose)\n",
" source, target = self.detector(pose, calibration)\n",
"\n",
" # Render the DRR\n",
" kwargs[\"mask\"] = self.mask if mask_to_channels else None\n",
Expand Down
61 changes: 42 additions & 19 deletions notebooks/api/02_detector.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,13 @@
" reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis\n",
" ):\n",
" super().__init__()\n",
" self.sdd = sdd\n",
" # self.sdd = sdd\n",
" self.height = height\n",
" self.width = width\n",
" self.delx = delx\n",
" self.dely = dely\n",
" self.x0 = x0\n",
" self.y0 = y0\n",
" # self.delx = delx\n",
" # self.dely = dely\n",
" # self.x0 = x0\n",
" # self.y0 = y0\n",
" self.n_subsample = n_subsample\n",
" if self.n_subsample is not None:\n",
" self.subsamples = []\n",
Expand All @@ -107,20 +107,39 @@
" # Create a pose to reorient the scanner\n",
" self.register_buffer(\"_reorient\", reorient)\n",
"\n",
" # Create a calibration matrix that holds the detector's intrinsic parameters\n",
" self.register_buffer(\n",
" \"_calibration\",\n",
" torch.tensor(\n",
" [\n",
" [delx, 0, 0, x0],\n",
" [0, dely, 0, y0],\n",
" [0, 0, sdd, 0],\n",
" [0, 0, 0, 1],\n",
" ]\n",
" ),\n",
" )\n",
"\n",
" @property\n",
" def reorient(self):\n",
" return RigidTransform(self._reorient)\n",
"\n",
" @property\n",
" def calibration(self):\n",
" \"\"\"A 4x4 matrix that rescales the detector plane to world coordinates.\"\"\"\n",
" return RigidTransform(self._calibration)\n",
"\n",
" @property\n",
" def intrinsic(self):\n",
" \"\"\"The 3x3 intrinsic matrix.\"\"\"\n",
" return make_intrinsic_matrix(\n",
" self.sdd,\n",
" self.delx,\n",
" self.dely,\n",
" self._calibration[2, 2].item(),\n",
" self._calibration[0, 0].item(),\n",
" self._calibration[1, 1].item(),\n",
" self.height,\n",
" self.width,\n",
" self.x0,\n",
" self.y0,\n",
" self._calibration[0, -1].item(),\n",
" self._calibration[1, -1].item(),\n",
" ).to(self.source)"
]
},
Expand All @@ -142,7 +161,7 @@
"\n",
" # Initialize the source at the origin and the center of the detector plane on the positive z-axis\n",
" source = torch.tensor([[0.0, 0.0, 0.0]], device=device)\n",
" center = torch.tensor([[0.0, 0.0, 1.0]], device=device) * self.sdd\n",
" center = torch.tensor([[0.0, 0.0, 1.0]], device=device) # * self.sdd\n",
"\n",
" # Use the standard basis for the detector plane\n",
" basis = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]], device=device)\n",
Expand All @@ -154,10 +173,10 @@
" # Construct equally spaced points along the basis vectors\n",
" t = (\n",
" torch.arange(-self.height // 2, self.height // 2, device=device) + h_off\n",
" ) * self.delx\n",
" ) # * self.delx\n",
" s = (\n",
" torch.arange(-self.width // 2, self.width // 2, device=device) + w_off\n",
" ) * self.dely\n",
" ) # * self.dely\n",
" if self.reverse_x_axis:\n",
" s = -s\n",
" coefs = torch.cartesian_prod(t, s).reshape(-1, 2)\n",
Expand All @@ -168,9 +187,9 @@
" source = source.unsqueeze(0)\n",
" target = target.unsqueeze(0)\n",
"\n",
" # Apply principal point offset\n",
" target[..., 1] -= self.x0\n",
" target[..., 0] -= self.y0\n",
" # # Apply principal point offset\n",
" # target[..., 1] -= self.x0\n",
" # target[..., 0] -= self.y0\n",
"\n",
" if self.n_subsample is not None:\n",
" sample = torch.randperm(self.height * self.width)[: int(self.n_subsample)]\n",
Expand All @@ -191,11 +210,15 @@
"\n",
"\n",
"@patch\n",
"def forward(self: Detector, pose: RigidTransform):\n",
"def forward(self: Detector, extrinsic: RigidTransform, calibration: RigidTransform):\n",
" \"\"\"Create source and target points for X-rays to trace through the volume.\"\"\"\n",
" pose = self.reorient.compose(pose)\n",
" if calibration is None:\n",
" target = self.calibration(self.target)\n",
" else:\n",
" target = calibration(self.target)\n",
" pose = self.reorient.compose(extrinsic)\n",
" source = pose(self.source)\n",
" target = pose(self.target)\n",
" target = pose(target)\n",
" return source, target"
]
},
Expand Down
2 changes: 1 addition & 1 deletion notebooks/api/03_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@
" mask = torch.any(\n",
" torch.stack([subject.mask.data.squeeze() == idx for idx in labels]), dim=0\n",
" )\n",
" subject.density = subject.density * mask\n",
" subject.density.data = subject.density.data * mask\n",
"\n",
" return subject"
]
Expand Down
22 changes: 11 additions & 11 deletions notebooks/tutorials/introduction.ipynb

Large diffs are not rendered by default.

0 comments on commit 53a149a

Please sign in to comment.