Skip to content

Commit

Permalink
Improve docstrings (#356)
Browse files Browse the repository at this point in the history
* Add a combined function to plot images and mask

* Fix circular imports

* Allow setting of reverse_x_axis and n_subsamples

* Make a deepcopy of the DRR module

* Expose the DRR's device and dtype

* Make intrinsic matrix calculation more explicit

* Simplify the Detector

* Implement a custom PinholeCamera class

* Add the projection matrix and pose
  • Loading branch information
eigenvivek authored Dec 13, 2024
1 parent 6988bfd commit e41c365
Show file tree
Hide file tree
Showing 9 changed files with 403 additions and 277 deletions.
21 changes: 16 additions & 5 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,19 @@
'diffdrr.detector.Detector.reorient': ('api/detector.html#detector.reorient', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.sdd': ('api/detector.html#detector.sdd', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.x0': ('api/detector.html#detector.x0', 'diffdrr/detector.py'),
'diffdrr.detector.Detector.y0': ('api/detector.html#detector.y0', 'diffdrr/detector.py')},
'diffdrr.detector.Detector.y0': ('api/detector.html#detector.y0', 'diffdrr/detector.py'),
'diffdrr.detector.get_focal_length': ('api/detector.html#get_focal_length', 'diffdrr/detector.py'),
'diffdrr.detector.get_principal_point': ('api/detector.html#get_principal_point', 'diffdrr/detector.py'),
'diffdrr.detector.make_intrinsic_matrix': ( 'api/detector.html#make_intrinsic_matrix',
'diffdrr/detector.py'),
'diffdrr.detector.parse_intrinsic_matrix': ( 'api/detector.html#parse_intrinsic_matrix',
'diffdrr/detector.py')},
'diffdrr.drr': { 'diffdrr.drr.DRR': ('api/drr.html#drr', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.__init__': ('api/drr.html#drr.__init__', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.affine': ('api/drr.html#drr.affine', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.affine_inverse': ('api/drr.html#drr.affine_inverse', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.device': ('api/drr.html#drr.device', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.dtype': ('api/drr.html#drr.dtype', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.forward': ('api/drr.html#drr.forward', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.inverse_projection': ('api/drr.html#drr.inverse_projection', 'diffdrr/drr.py'),
'diffdrr.drr.DRR.n_patches': ('api/drr.html#drr.n_patches', 'diffdrr/drr.py'),
Expand Down Expand Up @@ -167,11 +175,12 @@
'diffdrr.renderers._get_voxel': ('api/renderers.html#_get_voxel', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_xyzs': ('api/renderers.html#_get_xyzs', 'diffdrr/renderers.py'),
'diffdrr.renderers.reduce': ('api/renderers.html#reduce', 'diffdrr/renderers.py')},
'diffdrr.utils': { 'diffdrr.utils.get_focal_length': ('api/utils.html#get_focal_length', 'diffdrr/utils.py'),
'diffdrr.utils': { 'diffdrr.utils.PinholeCamera': ('api/utils.html#pinholecamera', 'diffdrr/utils.py'),
'diffdrr.utils.PinholeCamera.__init__': ('api/utils.html#pinholecamera.__init__', 'diffdrr/utils.py'),
'diffdrr.utils.PinholeCamera.center': ('api/utils.html#pinholecamera.center', 'diffdrr/utils.py'),
'diffdrr.utils.PinholeCamera.pose': ('api/utils.html#pinholecamera.pose', 'diffdrr/utils.py'),
'diffdrr.utils.PinholeCamera.projmat': ('api/utils.html#pinholecamera.projmat', 'diffdrr/utils.py'),
'diffdrr.utils.get_pinhole_camera': ('api/utils.html#get_pinhole_camera', 'diffdrr/utils.py'),
'diffdrr.utils.get_principal_point': ('api/utils.html#get_principal_point', 'diffdrr/utils.py'),
'diffdrr.utils.make_intrinsic_matrix': ('api/utils.html#make_intrinsic_matrix', 'diffdrr/utils.py'),
'diffdrr.utils.parse_intrinsic_matrix': ('api/utils.html#parse_intrinsic_matrix', 'diffdrr/utils.py'),
'diffdrr.utils.resample': ('api/utils.html#resample', 'diffdrr/utils.py')},
'diffdrr.visualization': { 'diffdrr.visualization._make_camera_frustum_mesh': ( 'api/visualization.html#_make_camera_frustum_mesh',
'diffdrr/visualization.py'),
Expand All @@ -184,6 +193,8 @@
'diffdrr.visualization.labelmap_to_mesh': ( 'api/visualization.html#labelmap_to_mesh',
'diffdrr/visualization.py'),
'diffdrr.visualization.plot_drr': ('api/visualization.html#plot_drr', 'diffdrr/visualization.py'),
'diffdrr.visualization.plot_img_and_mask': ( 'api/visualization.html#plot_img_and_mask',
'diffdrr/visualization.py'),
'diffdrr.visualization.plot_mask': ('api/visualization.html#plot_mask', 'diffdrr/visualization.py'),
'diffdrr.visualization.visualize_scene': ( 'api/visualization.html#visualize_scene',
'diffdrr/visualization.py')}}}
77 changes: 58 additions & 19 deletions diffdrr/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,26 +8,25 @@
from torch.nn.functional import normalize

# %% auto 0
__all__ = ['Detector']
__all__ = ['Detector', 'get_focal_length', 'get_principal_point', 'parse_intrinsic_matrix', 'make_intrinsic_matrix']

# %% ../notebooks/api/02_detector.ipynb 5
from .pose import RigidTransform
from .utils import make_intrinsic_matrix


class Detector(torch.nn.Module):
"""Construct a 6 DoF X-ray detector system. This model is based on a C-Arm."""

def __init__(
self,
sdd: float, # Source-to-detector distance (i.e., focal length)
height: int, # Height of the X-ray detector
width: int, # Width of the X-ray detector
delx: float, # Pixel spacing in the X-direction
dely: float, # Pixel spacing in the Y-direction
x0: float, # Principal point X-offset
y0: float, # Principal point Y-offset
reorient: torch.tensor, # Frame-of-reference change matrix
sdd: float, # Source-to-detector distance (in units length)
height: int, # Y-direction length (in units pixels)
width: int, # X-direction length (in units pixels)
delx: float, # X-direction spacing (in units length / pixel)
dely: float, # Y-direction spacing (in units length / pixel)
x0: float, # Principal point x-coordinate (in units length)
y0: float, # Principal point y-coordinate (in units length)
reorient: torch.Tensor, # Frame-of-reference change matrix
n_subsample: int | None = None, # Number of target points to randomly sample
reverse_x_axis: bool = False, # If pose includes reflection (in E(3) not SE(3)), reverse x-axis
):
Expand Down Expand Up @@ -92,15 +91,7 @@ def calibration(self):
@property
def intrinsic(self):
"""The 3x3 intrinsic matrix."""
return make_intrinsic_matrix(
self.sdd,
self.delx,
self.dely,
self.width,
self.height,
self.y0,
self.x0,
).to(self.source)
return make_intrinsic_matrix(self).to(self.source)

# %% ../notebooks/api/02_detector.ipynb 6
@patch
Expand Down Expand Up @@ -157,3 +148,51 @@ def forward(self: Detector, extrinsic: RigidTransform, calibration: RigidTransfo
source = pose(self.source)
target = pose(target)
return source, target

# %% ../notebooks/api/02_detector.ipynb 9
def get_focal_length(
intrinsic, # Intrinsic matrix (3 x 3 tensor)
delx: float, # X-direction spacing (in units length)
dely: float, # Y-direction spacing (in units length)
) -> float: # Focal length (in units length)
fx = intrinsic[0, 0]
fy = intrinsic[1, 1]
return abs((fx * delx) + (fy * dely)).item() / 2.0

# %% ../notebooks/api/02_detector.ipynb 10
def get_principal_point(
intrinsic, # Intrinsic matrix (3 x 3 tensor)
height: int, # Y-direction length (in units pixels)
width: int, # X-direction length (in units pixels)
delx: float, # X-direction spacing (in units length)
dely: float, # Y-direction spacing (in units length)
):
x0 = delx * (intrinsic[0, 2] - width / 2)
y0 = dely * (intrinsic[1, 2] - height / 2)
return x0.item(), y0.item()

# %% ../notebooks/api/02_detector.ipynb 11
def parse_intrinsic_matrix(
intrinsic, # Intrinsic matrix (3 x 3 tensor)
height: int, # Y-direction length (in units pixels)
width: int, # X-direction length (in units pixels)
delx: float, # X-direction spacing (in units length)
dely: float, # Y-direction spacing (in units length)
):
focal_length = get_focal_length(intrinsic, delx, dely)
x0, y0 = get_principal_point(intrinsic, height, width, delx, dely)
return focal_length, x0, y0

# %% ../notebooks/api/02_detector.ipynb 12
def make_intrinsic_matrix(detector: Detector):
fx = detector.sdd / detector.delx
fy = detector.sdd / detector.dely
u0 = detector.x0 / detector.delx + detector.width / 2
v0 = detector.y0 / detector.dely + detector.height / 2
return torch.tensor(
[
[fx, 0.0, u0],
[0.0, fy, v0],
[0.0, 0.0, 1.0],
]
)
16 changes: 13 additions & 3 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,14 @@ def affine_inverse(self):
def n_patches(self):
return (self.detector.height * self.detector.width) // (self.patch_size**2)

@property
def device(self):
return self.density.device

@property
def dtype(self):
return self.density.dtype

# %% ../notebooks/api/00_drr.ipynb 8
def reshape_subsampled_drr(img: torch.Tensor, detector: Detector, batch_size: int):
n_points = detector.height * detector.width
Expand Down Expand Up @@ -212,6 +220,8 @@ def set_intrinsics_(
dely: float = None,
x0: float = None,
y0: float = None,
n_subsample: int = None,
reverse_x_axis: bool = None,
):
"""Set new intrinsic parameters (inplace)."""
self.detector = Detector(
Expand All @@ -222,9 +232,9 @@ def set_intrinsics_(
dely if dely is not None else self.detector.dely,
x0 if x0 is not None else self.detector.x0,
y0 if y0 is not None else self.detector.y0,
n_subsample=self.detector.n_subsample,
reverse_x_axis=self.detector.reverse_x_axis,
reorient=self.subject.reorient,
self.subject.reorient,
n_subsample if n_subsample is not None else self.detector.n_subsample,
reverse_x_axis if reverse_x_axis is not None else self.detector.reverse_x_axis,
).to(self.density)

# %% ../notebooks/api/00_drr.ipynb 12
Expand Down
Loading

0 comments on commit e41c365

Please sign in to comment.