Skip to content

Commit

Permalink
Fix gradients in Siddon's method (#296)
Browse files Browse the repository at this point in the history
* Fix bugs in Siddon's method

* Remove warnings

* Revert default renderer to 'siddon'

* Rerun registration tutorial
  • Loading branch information
eigenvivek authored Jul 3, 2024
1 parent b44ad22 commit 06a76b6
Show file tree
Hide file tree
Showing 6 changed files with 86 additions and 69 deletions.
7 changes: 1 addition & 6 deletions diffdrr/drr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
# %% ../notebooks/api/00_drr.ipynb 3
from __future__ import annotations

import warnings

import numpy as np
import torch
import torch.nn as nn
Expand Down Expand Up @@ -39,7 +37,7 @@ def __init__(
reshape: bool = True, # Return DRR with shape (b, 1, h, w)
reverse_x_axis: bool = True, # If True, obey radiologic convention (e.g., heart on right)
patch_size: int | None = None, # Render patches of the DRR in series
renderer: str = "trilinear", # Rendering backend, either "siddon" or "trilinear"
renderer: str = "siddon", # Rendering backend, either "siddon" or "trilinear"
persistent: bool = True, # Set persistent value in `torch.nn.Module.register_buffer`
**renderer_kwargs, # Kwargs for the renderer
):
Expand Down Expand Up @@ -98,9 +96,6 @@ def __init__(
# Initialize the renderer
if renderer == "siddon":
self.renderer = Siddon(**renderer_kwargs)
warnings.warn(
"Gradients from Siddon's method are currently unstable for 2D/3D registration. Use 'trilinear' instead."
)
elif renderer == "trilinear":
self.renderer = Trilinear(**renderer_kwargs)
else:
Expand Down
8 changes: 4 additions & 4 deletions diffdrr/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ def forward(
def _get_alphas(source, target, dims, eps, filter_intersections_outside_volume):
"""Calculates the parametric intersections of each ray with the planes of the CT volume."""
# Parameterize the parallel XYZ planes that comprise the CT volumes
alphax = torch.arange(dims[0] + 1).to(source) - 0.5
alphay = torch.arange(dims[1] + 1).to(source) - 0.5
alphaz = torch.arange(dims[2] + 1).to(source) - 0.5
alphax = torch.arange(dims[0] + 1).to(source)
alphay = torch.arange(dims[1] + 1).to(source)
alphaz = torch.arange(dims[2] + 1).to(source)

# Calculate the parametric intersection of each ray with every plane
sx, sy, sz = source[..., 0:1], source[..., 1:2], source[..., 2:3]
Expand Down Expand Up @@ -121,7 +121,7 @@ def _get_alpha_minmax(source, target, dims, eps):
sdd = target - source + eps

alpha0 = (torch.zeros(3).to(source) - source) / sdd
alpha1 = ((dims - 1).to(source) - source) / sdd
alpha1 = ((dims + 1).to(source) - source) / sdd
alphas = torch.stack([alpha0, alpha1])

alphamin = alphas.min(dim=0).values.max(dim=-1).values.unsqueeze(-1)
Expand Down
5 changes: 1 addition & 4 deletions notebooks/api/00_drr.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@
"#| export\n",
"from __future__ import annotations\n",
"\n",
"import warnings\n",
"\n",
"import numpy as np\n",
"import torch\n",
"import torch.nn as nn\n",
Expand Down Expand Up @@ -134,7 +132,7 @@
" reshape: bool = True, # Return DRR with shape (b, 1, h, w)\n",
" reverse_x_axis: bool = True, # If True, obey radiologic convention (e.g., heart on right)\n",
" patch_size: int | None = None, # Render patches of the DRR in series\n",
" renderer: str = \"trilinear\", # Rendering backend, either \"siddon\" or \"trilinear\"\n",
" renderer: str = \"siddon\", # Rendering backend, either \"siddon\" or \"trilinear\"\n",
" persistent: bool = True, # Set persistent value in `torch.nn.Module.register_buffer`\n",
" **renderer_kwargs, # Kwargs for the renderer\n",
" ):\n",
Expand Down Expand Up @@ -193,7 +191,6 @@
" # Initialize the renderer\n",
" if renderer == \"siddon\":\n",
" self.renderer = Siddon(**renderer_kwargs)\n",
" warnings.warn(\"Gradients from Siddon's method are currently unstable for 2D/3D registration. Use 'trilinear' instead.\")\n",
" elif renderer == \"trilinear\":\n",
" self.renderer = Trilinear(**renderer_kwargs)\n",
" else:\n",
Expand Down
8 changes: 4 additions & 4 deletions notebooks/api/01_renderers.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,9 @@
"def _get_alphas(source, target, dims, eps, filter_intersections_outside_volume):\n",
" \"\"\"Calculates the parametric intersections of each ray with the planes of the CT volume.\"\"\"\n",
" # Parameterize the parallel XYZ planes that comprise the CT volumes\n",
" alphax = torch.arange(dims[0] + 1).to(source) - 0.5\n",
" alphay = torch.arange(dims[1] + 1).to(source) - 0.5\n",
" alphaz = torch.arange(dims[2] + 1).to(source) - 0.5\n",
" alphax = torch.arange(dims[0] + 1).to(source)\n",
" alphay = torch.arange(dims[1] + 1).to(source)\n",
" alphaz = torch.arange(dims[2] + 1).to(source)\n",
"\n",
" # Calculate the parametric intersection of each ray with every plane\n",
" sx, sy, sz = source[..., 0:1], source[..., 1:2], source[..., 2:3]\n",
Expand Down Expand Up @@ -230,7 +230,7 @@
" sdd = target - source + eps\n",
"\n",
" alpha0 = (torch.zeros(3).to(source) - source) / sdd\n",
" alpha1 = ((dims - 1).to(source) - source) / sdd\n",
" alpha1 = ((dims + 1).to(source) - source) / sdd\n",
" alphas = torch.stack([alpha0, alpha1])\n",
"\n",
" alphamin = alphas.min(dim=0).values.max(dim=-1).values.unsqueeze(-1)\n",
Expand Down
125 changes: 75 additions & 50 deletions notebooks/tutorials/registration.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion notebooks/tutorials/registration_runs.html

Large diffs are not rendered by default.

0 comments on commit 06a76b6

Please sign in to comment.