Skip to content

Commit

Permalink
Implement alpha filtering in Trilinear (#281)
Browse files Browse the repository at this point in the history
* Implement separate alpha filtering function

* Update tutorials
  • Loading branch information
eigenvivek authored Jun 16, 2024
1 parent e26bb37 commit 9bda058
Show file tree
Hide file tree
Showing 7 changed files with 154 additions and 91 deletions.
2 changes: 2 additions & 0 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@
'diffdrr/renderers.py'),
'diffdrr.renderers.Trilinear.dims': ('api/renderers.html#trilinear.dims', 'diffdrr/renderers.py'),
'diffdrr.renderers.Trilinear.forward': ('api/renderers.html#trilinear.forward', 'diffdrr/renderers.py'),
'diffdrr.renderers._filter_intersections_outside_volume': ( 'api/renderers.html#_filter_intersections_outside_volume',
'diffdrr/renderers.py'),
'diffdrr.renderers._get_alpha_minmax': ('api/renderers.html#_get_alpha_minmax', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_alphas': ('api/renderers.html#_get_alphas', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_voxel': ('api/renderers.html#_get_voxel', 'diffdrr/renderers.py'),
Expand Down
50 changes: 31 additions & 19 deletions diffdrr/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,24 +114,31 @@ def _get_alphas(

# Sort the intersections
alphas = torch.sort(alphas, dim=-1).values

# Remove interesections that are outside of the volume for all rays
if filter_intersections_outside_volume:
alphamin, alphamax = _get_alpha_minmax(
source, target, origin, spacing, dims, eps
alphas = _filter_intersections_outside_volume(
alphas, source, target, origin, spacing, dims, eps
)
good_idxs = torch.logical_and(alphamin <= alphas, alphas <= alphamax)
alphas = alphas[..., good_idxs.any(dim=[0, 1])]

return alphas


def _filter_intersections_outside_volume(
alphas, source, target, origin, spacing, dims, eps
):
"""Remove interesections that are outside of the volume for all rays."""
alphamin, alphamax = _get_alpha_minmax(source, target, origin, spacing, dims, eps)
good_idxs = torch.logical_and(alphamin <= alphas, alphas <= alphamax)
alphas = alphas[..., good_idxs.any(dim=[0, 1])]
return alphas


def _get_alpha_minmax(source, target, origin, spacing, dims, eps):
"""Calculate the first and last intersections of each ray with the volume."""
sdd = target - source + eps

planes = torch.zeros(3).to(source)
planes = torch.zeros(3).to(source) - 0.5
alpha0 = (planes * spacing + origin - source) / sdd
planes = (dims - 1).to(source)
planes = dims.to(source) - 0.5
alpha1 = (planes * spacing + origin - source) / sdd
alphas = torch.stack([alpha0, alpha1]).to(source)

Expand Down Expand Up @@ -177,13 +184,15 @@ def __init__(
self,
near=0.0,
far=1.0,
mode="bilinear",
eps=1e-8,
mode: str = "bilinear", # Interpolation mode for grid_sample
filter_intersections_outside_volume: bool = True, # Use alphamin/max to filter the intersections
eps: float = 1e-8, # Small constant to avoid div by zero errors
):
super().__init__()
self.near = near
self.far = far
self.mode = mode
self.filter_intersections_outside_volume = filter_intersections_outside_volume
self.eps = eps

def dims(self, volume):
Expand All @@ -200,16 +209,18 @@ def forward(
align_corners=True,
mask=None,
):
# Get the raylength and reshape sources
raylength = (source - target + self.eps).norm(dim=-1).unsqueeze(1)

# Sample points along the rays and rescale to [-1, 1]
alphas = torch.linspace(self.near, self.far, n_points).to(volume)
alphas = alphas[None, None, :]

# Render the DRR
# Sample points along the rays
dims = self.dims(volume)
alphas = torch.linspace(self.near, self.far, n_points)[None, None].to(volume)
if self.filter_intersections_outside_volume:
alphas = _filter_intersections_outside_volume(
alphas, source, target, origin, spacing, dims, self.eps
)

# Get the XYZ coordinate of each alpha, normalized for grid_sample
xyzs = _get_xyzs(alphas, source, target, origin, spacing, dims, self.eps)

# Sample the volume with trilinear interpolation
img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)

# Handle optional masking
Expand All @@ -227,6 +238,7 @@ def forward(
.scatter_add_(1, channels.transpose(-1, -2), img.transpose(-1, -2))
)

# Multiply by raylength
# Multiply by raylength and return the drr
raylength = (target - source + self.eps).norm(dim=-1).unsqueeze(1)
img *= raylength / n_points
return img
50 changes: 32 additions & 18 deletions notebooks/api/01_renderers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -211,22 +211,31 @@
"\n",
" # Sort the intersections\n",
" alphas = torch.sort(alphas, dim=-1).values\n",
"\n",
" # Remove interesections that are outside of the volume for all rays\n",
" if filter_intersections_outside_volume:\n",
" alphamin, alphamax = _get_alpha_minmax(source, target, origin, spacing, dims, eps)\n",
" good_idxs = torch.logical_and(alphamin <= alphas, alphas <= alphamax)\n",
" alphas = alphas[..., good_idxs.any(dim=[0, 1])]\n",
" alphas = _filter_intersections_outside_volume(\n",
" alphas, source, target, origin, spacing, dims, eps\n",
" )\n",
" \n",
" return alphas\n",
"\n",
"\n",
"def _filter_intersections_outside_volume(\n",
" alphas, source, target, origin, spacing, dims, eps\n",
"):\n",
" \"\"\"Remove interesections that are outside of the volume for all rays.\"\"\"\n",
" alphamin, alphamax = _get_alpha_minmax(source, target, origin, spacing, dims, eps)\n",
" good_idxs = torch.logical_and(alphamin <= alphas, alphas <= alphamax)\n",
" alphas = alphas[..., good_idxs.any(dim=[0, 1])]\n",
" return alphas\n",
"\n",
"\n",
"def _get_alpha_minmax(source, target, origin, spacing, dims, eps):\n",
" \"\"\"Calculate the first and last intersections of each ray with the volume.\"\"\"\n",
" sdd = target - source + eps\n",
" \n",
" planes = torch.zeros(3).to(source)\n",
" planes = torch.zeros(3).to(source) - 0.5\n",
" alpha0 = (planes * spacing + origin - source) / sdd\n",
" planes = (dims - 1).to(source)\n",
" planes = dims.to(source) - 0.5\n",
" alpha1 = (planes * spacing + origin - source) / sdd\n",
" alphas = torch.stack([alpha0, alpha1]).to(source)\n",
"\n",
Expand Down Expand Up @@ -294,13 +303,15 @@
" self,\n",
" near=0.0,\n",
" far=1.0,\n",
" mode=\"bilinear\",\n",
" eps=1e-8,\n",
" mode: str = \"bilinear\", # Interpolation mode for grid_sample\n",
" filter_intersections_outside_volume: bool = True, # Use alphamin/max to filter the intersections\n",
" eps: float = 1e-8, # Small constant to avoid div by zero errors\n",
" ):\n",
" super().__init__()\n",
" self.near = near\n",
" self.far = far\n",
" self.mode = mode\n",
" self.filter_intersections_outside_volume = filter_intersections_outside_volume\n",
" self.eps = eps\n",
"\n",
" def dims(self, volume):\n",
Expand All @@ -317,16 +328,18 @@
" align_corners=True,\n",
" mask=None,\n",
" ):\n",
" # Get the raylength and reshape sources\n",
" raylength = (source - target + self.eps).norm(dim=-1).unsqueeze(1)\n",
"\n",
" # Sample points along the rays and rescale to [-1, 1]\n",
" alphas = torch.linspace(self.near, self.far, n_points).to(volume)\n",
" alphas = alphas[None, None, :]\n",
"\n",
" # Render the DRR\n",
" # Sample points along the rays\n",
" dims = self.dims(volume)\n",
" alphas = torch.linspace(self.near, self.far, n_points)[None, None].to(volume)\n",
" if self.filter_intersections_outside_volume:\n",
" alphas = _filter_intersections_outside_volume(\n",
" alphas, source, target, origin, spacing, dims, self.eps\n",
" )\n",
"\n",
" # Get the XYZ coordinate of each alpha, normalized for grid_sample\n",
" xyzs = _get_xyzs(alphas, source, target, origin, spacing, dims, self.eps)\n",
"\n",
" # Sample the volume with trilinear interpolation\n",
" img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)\n",
"\n",
" # Handle optional masking\n",
Expand All @@ -344,7 +357,8 @@
" .scatter_add_(1, channels.transpose(-1, -2), img.transpose(-1, -2))\n",
" )\n",
"\n",
" # Multiply by raylength\n",
" # Multiply by raylength and return the drr\n",
" raylength = (target - source + self.eps).norm(dim=-1).unsqueeze(1)\n",
" img *= raylength / n_points\n",
" return img"
]
Expand Down
18 changes: 9 additions & 9 deletions notebooks/tutorials/introduction.ipynb

Large diffs are not rendered by default.

42 changes: 21 additions & 21 deletions notebooks/tutorials/reconstruction.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/tutorials/registration.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,7 @@
"id": "ebf4a6da-a6d2-421a-bb36-c8157716f776",
"metadata": {},
"source": [
"L-BFGS with line search converges so quickly that a GIF with ~30 FPS is imperceptable. Here's the same GIF but at 1 FPS."
"L-BFGS with line search converges so quickly that a GIF with ~30 FPS is imperceptible. Here's the same GIF but at 1 FPS."
]
},
{
Expand Down
81 changes: 58 additions & 23 deletions notebooks/tutorials/trilinear.ipynb

Large diffs are not rendered by default.

0 comments on commit 9bda058

Please sign in to comment.