Skip to content

Commit

Permalink
siddon revised (#272)
Browse files Browse the repository at this point in the history
* implemented grid_sample and a bit of memory managment

* formatted using black

* Remove obsolete nan operations

* Add more comments

* Ensure mask is torch.float32

* Fix argument error

* Make argument names consistent with PyTorch

* Update README

* Update metrics tutorial

* Fix grid_sample indexing

* Update README.md

* Rerun tutorials with new Siddon's implementation

* Reuse inplace ops and remove unecessary gradient calculation for grid_sample

* Bump version

---------

Co-authored-by: Vivek Gopalakrishnan <[email protected]>
  • Loading branch information
hossein-momeni and eigenvivek authored Jun 14, 2024
1 parent 858bb1a commit e04e1d8
Show file tree
Hide file tree
Showing 14 changed files with 325 additions and 355 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ plt.show()

On a single NVIDIA RTX 2080 Ti GPU, producing such an image takes

37.9 ms ± 19.6 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
29.5 ms ± 45.2 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)

The full example is available at
[`introduction.ipynb`](https://vivekg.dev/DiffDRR/tutorials/introduction.html).
Expand Down
2 changes: 1 addition & 1 deletion diffdrr/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.4.0"
__version__ = "0.4.1"
6 changes: 2 additions & 4 deletions diffdrr/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,14 @@
'diffdrr.renderers.Siddon.__init__': ('api/renderers.html#siddon.__init__', 'diffdrr/renderers.py'),
'diffdrr.renderers.Siddon.dims': ('api/renderers.html#siddon.dims', 'diffdrr/renderers.py'),
'diffdrr.renderers.Siddon.forward': ('api/renderers.html#siddon.forward', 'diffdrr/renderers.py'),
'diffdrr.renderers.Siddon.maxidx': ('api/renderers.html#siddon.maxidx', 'diffdrr/renderers.py'),
'diffdrr.renderers.Trilinear': ('api/renderers.html#trilinear', 'diffdrr/renderers.py'),
'diffdrr.renderers.Trilinear.__init__': ( 'api/renderers.html#trilinear.__init__',
'diffdrr/renderers.py'),
'diffdrr.renderers.Trilinear.dims': ('api/renderers.html#trilinear.dims', 'diffdrr/renderers.py'),
'diffdrr.renderers.Trilinear.forward': ('api/renderers.html#trilinear.forward', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_alpha_minmax': ('api/renderers.html#_get_alpha_minmax', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_alphas': ('api/renderers.html#_get_alphas', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_index': ('api/renderers.html#_get_index', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_voxel': ('api/renderers.html#_get_voxel', 'diffdrr/renderers.py')},
'diffdrr.renderers._get_voxel': ('api/renderers.html#_get_voxel', 'diffdrr/renderers.py'),
'diffdrr.renderers._get_xyzs': ('api/renderers.html#_get_xyzs', 'diffdrr/renderers.py')},
'diffdrr.utils': { 'diffdrr.utils.get_focal_length': ('api/utils.html#get_focal_length', 'diffdrr/utils.py'),
'diffdrr.utils.get_principal_point': ('api/utils.html#get_principal_point', 'diffdrr/utils.py'),
'diffdrr.utils.make_intrinsic_matrix': ('api/utils.html#make_intrinsic_matrix', 'diffdrr/utils.py'),
Expand Down
200 changes: 89 additions & 111 deletions diffdrr/renderers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,134 +5,128 @@

# %% ../notebooks/api/01_renderers.ipynb 3
import torch
from torch.nn.functional import grid_sample

# %% ../notebooks/api/01_renderers.ipynb 7
class Siddon(torch.nn.Module):
"""Differentiable X-ray renderer implemented with Siddon's method for exact raytracing."""

def __init__(self, eps=1e-8):
def __init__(
self,
mode="nearest",
eps=1e-8,
):
super().__init__()
self.mode = mode
self.eps = eps

def dims(self, volume):
return torch.tensor(volume.shape).to(volume) + 1

def maxidx(self, volume):
return volume.numel() - 1
return torch.tensor(volume.shape).to(volume)

def forward(self, volume, origin, spacing, source, target, mask=None):
def forward(
self,
volume,
origin,
spacing,
source,
target,
align_corners=True,
mask=None,
):
dims = self.dims(volume)
maxidx = self.maxidx(volume)
origin = origin.to(torch.float64)
origin = origin.to(
torch.float64
) # Somehow dramatically improves rendering quality (https://github.com/eigenvivek/DiffDRR/issues/202)

# Calculate the intersections of each ray with the planes comprising the CT volume
alphas = _get_alphas(source, target, origin, spacing, dims, self.eps)

# Calculate the midpoint of every pair of adjacent intersections
# These midpoints lie exclusively in a single voxel
alphamid = (alphas[..., 0:-1] + alphas[..., 1:]) / 2
voxels, idxs = _get_voxel(
alphamid, source, target, volume, origin, spacing, dims, maxidx, self.eps
)

# Step length for alphas out of range will be nan
# These nans cancel out voxels convereted to 0 index
step_length = torch.diff(alphas, dim=-1)
weighted_voxels = voxels * step_length
# Get the XYZ coordinate of each midpoint (normalized to [-1, +1]^3)
xyzs = _get_xyzs(alphamid, source, target, origin, spacing, dims, self.eps)

# Use torch.nn.functional.grid_sample to lookup the values of each intersected voxel
with torch.no_grad():
img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)

# Weight each intersected voxel by the length of the ray's intersection with the voxel
intersection_length = torch.diff(alphas, dim=-1)
img = img * intersection_length

# Handle optional masking
if mask is None:
img = torch.nansum(weighted_voxels, dim=-1)
img = img.sum(dim=-1)
img = img.unsqueeze(1)
else:
# Thanks to @Ivan for the clutch assist w/ pytorch tensor ops
# https://stackoverflow.com/questions/78323859/broadcast-pytorch-array-across-channels-based-on-another-array/78324614#78324614
channels = torch.take(mask, idxs) # B D N
weighted_voxels = weighted_voxels.nan_to_num()
B, D, N = weighted_voxels.shape
B, D, _ = img.shape
C = mask.max().item() + 1
channels = _get_voxel(
mask.to(torch.float32), xyzs, align_corners=align_corners
).long()
img = (
torch.zeros(B, C, D)
.to(volume)
.scatter_add_(
1, channels.transpose(-1, -2), weighted_voxels.transpose(-1, -2)
)
.scatter_add_(1, channels.transpose(-1, -2), img.transpose(-1, -2))
)

# Finish rendering the DRR
# Multiply by ray length such that the proportion of attenuated energy is unitless
raylength = (target - source + self.eps).norm(dim=-1)
img *= raylength.unsqueeze(1)
return img

# %% ../notebooks/api/01_renderers.ipynb 8
def _get_alphas(source, target, origin, spacing, dims, eps):
# Get the CT sizing and spacing parameters
alphax = torch.arange(dims[0]).to(source) * spacing[0] + origin[0]
alphay = torch.arange(dims[1]).to(source) * spacing[1] + origin[1]
alphaz = torch.arange(dims[2]).to(source) * spacing[2] + origin[2]

# Get the alpha at each plane intersection
sx, sy, sz = source[..., 0], source[..., 1], source[..., 2]
alphax = alphax.expand(len(source), 1, -1) - sx.unsqueeze(-1)
alphay = alphay.expand(len(source), 1, -1) - sy.unsqueeze(-1)
alphaz = alphaz.expand(len(source), 1, -1) - sz.unsqueeze(-1)

sdd = target - source + eps
alphax = alphax / sdd[..., 0].unsqueeze(-1)
alphay = alphay / sdd[..., 1].unsqueeze(-1)
alphaz = alphaz / sdd[..., 2].unsqueeze(-1)
"""Calculates the parametric intersections of each ray with the planes of the CT volume."""
# Parameterize the parallel XYZ planes that comprise the CT volumes
alphax = (torch.arange(dims[0] + 1).to(source) - 0.5) * spacing[0] + origin[0]
alphay = (torch.arange(dims[1] + 1).to(source) - 0.5) * spacing[1] + origin[1]
alphaz = (torch.arange(dims[2] + 1).to(source) - 0.5) * spacing[2] + origin[2]

# Calculate the parametric intersection of each ray with every plane
sx, sy, sz = source[..., 0:1], source[..., 1:2], source[..., 2:3]
tx, ty, tz = target[..., 0:1], target[..., 1:2], target[..., 2:3]
alphax = (alphax.expand(len(source), 1, -1) - sx) / (tx - sx + eps)
alphay = (alphay.expand(len(source), 1, -1) - sy) / (ty - sy + eps)
alphaz = (alphaz.expand(len(source), 1, -1) - sz) / (tz - sz + eps)
alphas = torch.cat([alphax, alphay, alphaz], dim=-1)

# Get the alphas within the range [alphamin, alphamax]
alphamin, alphamax = _get_alpha_minmax(sdd, source, target, origin, spacing, dims)
good_idxs = torch.logical_and(alphas >= alphamin, alphas <= alphamax)
alphas[~good_idxs] = torch.nan

# Sort the alphas by ray, putting nans at the end of the list
# Sort the intersections
alphas = torch.sort(alphas, dim=-1).values

# Drop indices where alphas for all rays are nan
alphas = alphas[..., ~alphas.isnan().all(dim=0).all(dim=0)]

return alphas


def _get_alpha_minmax(sdd, source, target, origin, spacing, dims):
planes = torch.zeros(3).to(source)
alpha0 = (planes * spacing + origin - source) / sdd
planes = (dims - 1).to(source)
alpha1 = (planes * spacing + origin - source) / sdd
alphas = torch.stack([alpha0, alpha1]).to(source)

alphamin = alphas.min(dim=0).values.max(dim=-1).values.unsqueeze(-1)
alphamax = alphas.max(dim=0).values.min(dim=-1).values.unsqueeze(-1)

alphamin = torch.where(alphamin < 0.0, 0.0, alphamin)
alphamax = torch.where(alphamax > 1.0, 1.0, alphamax)
return alphamin, alphamax


def _get_voxel(alpha, source, target, volume, origin, spacing, dims, maxidx, eps):
idxs = _get_index(alpha, source, target, origin, spacing, dims, maxidx, eps)
return torch.take(volume, idxs), idxs


def _get_index(alpha, source, target, origin, spacing, dims, maxidx, eps):
sdd = target - source + eps
idxs = source.unsqueeze(2) + alpha.unsqueeze(-1) * sdd.unsqueeze(2)
idxs = (idxs - origin) / spacing
idxs = idxs.floor()
# Conversion to long makes nan->-inf, so temporarily replace them with 0
# This is cancelled out later by multiplication by nan step_length
idxs = (
idxs[..., 0] * (dims[1] - 1) * (dims[2] - 1)
+ idxs[..., 1] * (dims[2] - 1)
+ idxs[..., 2]
).long()
idxs[idxs < 0] = 0
idxs[idxs > maxidx] = maxidx
return idxs
def _get_xyzs(alpha, source, target, origin, spacing, dims, eps):
"""Given a set of rays and parametric coordinates, calculates the XYZ coordinates."""
# Get the world coordinates of every midpoint
xyzs = (
source.unsqueeze(-2)
+ alpha.unsqueeze(-1) * (target - source + eps).unsqueeze(2)
- origin.to(torch.float32)
)

# Normalize coordinates to be in [-1, +1]
# Use inplace operations to minimize memory overhead
xyzs.mul_(2).div_(spacing * (dims - 1)).sub_(1)
return xyzs.unsqueeze(1)


def _get_voxel(volume, xyzs, mode="nearest", align_corners=True):
"""Wraps torch.nn.functional.grid_sample to sample a volume at XYZ coordinates."""
batch_size = len(xyzs)
voxels = grid_sample(
input=volume.permute(2, 1, 0)[None, None].expand(batch_size, -1, -1, -1, -1),
grid=xyzs,
mode=mode,
align_corners=align_corners,
)[:, 0, 0]
return voxels

# %% ../notebooks/api/01_renderers.ipynb 10
from torch.nn.functional import grid_sample


class Trilinear(torch.nn.Module):
"""Differentiable X-ray renderer implemented with trilinear interpolation."""

Expand All @@ -150,7 +144,7 @@ def __init__(
self.eps = eps

def dims(self, volume):
return torch.tensor(volume.shape).to(volume) - 1
return torch.tensor(volume.shape).to(volume)

def forward(
self,
Expand All @@ -165,41 +159,25 @@ def forward(
):
# Get the raylength and reshape sources
raylength = (source - target + self.eps).norm(dim=-1).unsqueeze(1)
source = source[:, None, :, None, :] - origin
target = target[:, None, :, None, :] - origin

# Sample points along the rays and rescale to [-1, 1]
alphas = torch.linspace(self.near, self.far, n_points).to(volume)
alphas = alphas[None, None, None, :, None]
rays = source + alphas * (target - source)
rays = 2 * rays / (spacing * self.dims(volume)) - 1

# Reorder array to match torch conventions
volume = volume.permute(2, 1, 0)
if mask is not None:
mask = mask.permute(2, 1, 0)
alphas = alphas[None, None, :]

# Render the DRR
batch_size = len(rays)
img = grid_sample(
volume[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1),
rays,
mode=self.mode,
align_corners=align_corners,
)[:, 0, 0]
dims = self.dims(volume)
xyzs = _get_xyzs(alphas, source, target, origin, spacing, dims, self.eps)
img = _get_voxel(volume, xyzs, self.mode, align_corners=align_corners)

# Handle optional masking
if mask is None:
img = img.sum(dim=-1).unsqueeze(1)
else:
B, D, N = img.shape
B, D, _ = img.shape
C = mask.max().item() + 1
channels = grid_sample(
mask[None, None, :, :, :].expand(batch_size, -1, -1, -1, -1).float(),
rays,
mode="nearest",
align_corners=align_corners,
).long()[:, 0, 0]
channels = _get_voxel(
mask.to(torch.float32), xyzs, align_corners=align_corners
).long()
img = (
torch.zeros(B, C, D)
.to(volume)
Expand Down
Loading

0 comments on commit e04e1d8

Please sign in to comment.