Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rsg devel #54

Open
wants to merge 21 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 12 additions & 9 deletions src/libtilt/ctf/ctf_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ def calculate_ctf(
image_shape: Tuple[int, int],
rfft: bool,
fftshift: bool,
device: torch.device | None = None
):
"""

Expand Down Expand Up @@ -56,20 +57,22 @@ def calculate_ctf(
Whether to apply fftshift on the resulting CTF images.
"""
# to torch.Tensor and unit conversions
defocus = torch.atleast_1d(torch.as_tensor(defocus, dtype=torch.float))
if bool(rfft) + bool(fftshift) > 1:
raise ValueError("Only one of `rfft` and `fftshift` may be `True`.")
defocus = torch.atleast_1d(torch.as_tensor(defocus, dtype=torch.float, device=device))
defocus *= 1e4 # micrometers -> angstroms
astigmatism = torch.atleast_1d(torch.as_tensor(astigmatism, dtype=torch.float))
astigmatism = torch.atleast_1d(torch.as_tensor(astigmatism, dtype=torch.float, device=device))
astigmatism *= 1e4 # micrometers -> angstroms
astigmatism_angle = torch.atleast_1d(torch.as_tensor(astigmatism_angle, dtype=torch.float))
astigmatism_angle = torch.atleast_1d(torch.as_tensor(astigmatism_angle, dtype=torch.float, device=device))
astigmatism_angle *= (C.pi / 180) # degrees -> radians
pixel_size = torch.atleast_1d(torch.as_tensor(pixel_size))
voltage = torch.atleast_1d(torch.as_tensor(voltage, dtype=torch.float))
pixel_size = torch.atleast_1d(torch.as_tensor(pixel_size, device=device))
voltage = torch.atleast_1d(torch.as_tensor(voltage, dtype=torch.float, device=device))
voltage *= 1e3 # kV -> V
spherical_aberration = torch.atleast_1d(
torch.as_tensor(spherical_aberration, dtype=torch.float)
torch.as_tensor(spherical_aberration, dtype=torch.float, device=device)
)
spherical_aberration *= 1e7 # mm -> angstroms
image_shape = torch.as_tensor(image_shape)
image_shape = torch.as_tensor(image_shape, device=device)

# derived quantities used in CTF calculation
defocus_u = defocus + astigmatism
Expand All @@ -79,10 +82,10 @@ def calculate_ctf(
k2 = C.pi / 2 * spherical_aberration * _lambda ** 3
k3 = torch.tensor(np.deg2rad(phase_shift))
k4 = -b_factor / 4
k5 = np.arctan(amplitude_contrast / np.sqrt(1 - amplitude_contrast ** 2))
k5 = torch.arctan(amplitude_contrast / torch.sqrt(1 - amplitude_contrast ** 2))

# construct 2D frequency grids and rescale cycles / px -> cycles / Å
fftfreq_grid = _construct_fftfreq_grid_2d(image_shape=image_shape, rfft=rfft) # (h, w, 2)
fftfreq_grid = _construct_fftfreq_grid_2d(image_shape=image_shape, rfft=rfft, device=device) # (h, w, 2)
fftfreq_grid = fftfreq_grid / einops.rearrange(pixel_size, 'b -> b 1 1 1')
fftfreq_grid_squared = fftfreq_grid ** 2

Expand Down
11 changes: 7 additions & 4 deletions src/libtilt/fft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ def dft_center(
device: torch.device | None = None,
) -> torch.LongTensor:
"""Return the position of the DFT center for a given input shape."""
_rfft_shape = rfft_shape(image_shape)
fft_center = torch.zeros(size=(len(image_shape),), device=device)
image_shape = torch.as_tensor(image_shape).float()
if rfft is True:
image_shape = torch.tensor(rfft_shape(image_shape))
image_shape = torch.tensor(_rfft_shape, device=device)
if fftshifted is True:
fft_center = torch.divide(image_shape, 2, rounding_mode='floor')
if rfft is True:
Expand Down Expand Up @@ -438,17 +439,19 @@ def fftfreq_to_dft_coordinates(
coordinates: torch.Tensor
`(..., d)` array of coordinates into a fftshifted DFT.
"""
_image_shape = image_shape
image_shape = torch.as_tensor(
image_shape, device=frequencies.device, dtype=frequencies.dtype
_image_shape, device=frequencies.device, dtype=frequencies.dtype
)
_rfft_shape = rfft_shape(_image_shape)
_rfft_shape = torch.as_tensor(
rfft_shape(image_shape), device=frequencies.device, dtype=frequencies.dtype
_rfft_shape, device=frequencies.device, dtype=frequencies.dtype
)
coordinates = torch.empty_like(frequencies)
coordinates[..., :-1] = frequencies[..., :-1] * image_shape[:-1]
if rfft is True:
coordinates[..., -1] = frequencies[..., -1] * 2 * (_rfft_shape[-1] - 1)
else:
coordinates[..., -1] = frequencies[..., -1] * image_shape[-1]
dc = dft_center(image_shape, rfft=rfft, fftshifted=True, device=frequencies.device)
dc = dft_center(_image_shape, rfft=rfft, fftshifted=True, device=frequencies.device)
return coordinates + dc
3 changes: 2 additions & 1 deletion src/libtilt/interpolation/interpolate_dft_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def sample_dft_3d(
samples = torch.view_as_complex(samples.contiguous()) # (b, )

# pack data back up and return
[samples] = einops.unpack(samples, pattern='*', packed_shapes=ps)
# [samples] = einops.unpack(samples, pattern='*', packed_shapes=ps)
samples = samples.reshape(*ps) # replaces commented line above, for performance
return samples # (...)


Expand Down
73 changes: 52 additions & 21 deletions src/libtilt/projection/project_fourier.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Tuple

import torch
import torch.nn.functional as F
import einops
Expand Down Expand Up @@ -33,30 +35,12 @@ def project_fourier(
projections: torch.Tensor
`(..., d, d)` array of projection images.
"""
# padding
if pad is True:
pad_length = volume.shape[-1] // 2
volume = F.pad(volume, pad=[pad_length] * 6, mode='constant', value=0)

# premultiply by sinc2
grid = fftfreq_grid(
image_shape=volume.shape,
rfft=False,
fftshift=True,
norm=True,
device=volume.device
)
volume = volume * torch.sinc(grid) ** 2

# calculate DFT
dft = torch.fft.fftshift(volume, dim=(-3, -2, -1)) # volume center to array origin
dft = torch.fft.rfftn(dft, dim=(-3, -2, -1))
dft = torch.fft.fftshift(dft, dim=(-3, -2,)) # actual fftshift of rfft
dft, vol_shape, pad_length = compute_vol_dtf(volume, pad)

# make projections by taking central slices
projections = extract_central_slices_rfft(
dft=dft,
image_shape=volume.shape,
image_shape=vol_shape,
rotation_matrices=rotation_matrices,
rotation_matrix_zyx=rotation_matrix_zyx
) # (..., h, w) rfft
Expand Down Expand Up @@ -92,7 +76,8 @@ def extract_central_slices_rfft(

# flip coordinates in redundant half transform
conjugate_mask = grid[..., 2] < 0
conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3')
# conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3')
conjugate_mask.unsqueeze(-1).repeat(1, 1, 1, 3)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

repeat should be creating a view here and shouldn't be memory intensive even though the tensor is huge - is this not the case?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If I am correct, is expand and not repeat the one that is a view. This change was again a requirement for the compiler.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you're right for the torch API but einops it creates a view where possible - regardless, compilation is super important.

I'm a little hesitant to lose the rank polymorphism here and it looks like this unsqueeze/repeat is specific to b h w 3 rather than ... h w 3 -> could you try adding some code to intepret the current shape and unsqueeze/repeat according to that? This should allow us to maintain the current flexibility and have compatibility with the compiler

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should work conjugate_mask.unsqueeze(-1).expand(*[-1] * len(conjugate_mask.shape), 3) and being more memory efficient, since it is a view.

grid[conjugate_mask] *= -1
conjugate_mask = conjugate_mask[..., 0] # un-repeat

Expand All @@ -107,3 +92,49 @@ def extract_central_slices_rfft(
# take complex conjugate of values from redundant half transform
projections[conjugate_mask] = torch.conj(projections[conjugate_mask])
return projections

def compute_vol_dtf( #TODO: Is this the best place to have this?
volume: torch.Tensor,
pad: bool = True,
pad_length: int | None = None
) -> Tuple[torch.Tensor, Tuple[int,int,int], int]:
"""Project a cubic volume by sampling a central slice through its DFT.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this docstring needs fixing


Parameters
----------
volume: torch.Tensor
`(d, d, d)` volume.
pad: bool
Whether to pad the volume with zeros to increase sampling in the DFT.
pad_length: bool
The lenght used for padding. If None, volume.shape[-1] // 2 is used instead

Returns
-------
projections: Tuple[torch.Tensor, torch.Tensor, int]
`(..., d, d, d)` dft of the volume. fftshifted rfft
Tuple[int,int,int] the shape of the volume after padding
int with the padding length
"""
# padding
if pad is True:
if pad_length is None:
pad_length = volume.shape[-1] // 2
volume = F.pad(volume, pad=[pad_length] * 6, mode='constant', value=0)

# premultiply by sinc2
grid = fftfreq_grid(
image_shape=volume.shape,
rfft=False,
fftshift=True,
norm=True,
device=volume.device
)
volume = volume * torch.sinc(grid) ** 2

# calculate DFT
dft = torch.fft.fftshift(volume, dim=(-3, -2, -1)) # volume center to array origin
dft = torch.fft.rfftn(dft, dim=(-3, -2, -1))
dft = torch.fft.fftshift(dft, dim=(-3, -2,)) # actual fftshift of rfft

return dft, volume.shape, pad_length
4 changes: 2 additions & 2 deletions src/libtilt/projection/project_real.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def project_real(volume: torch.Tensor, rotation_matrices: torch.Tensor) -> torch
torch_padding = einops.rearrange(torch_padding, 'whd pad -> (whd pad)')
volume = F.pad(volume, pad=tuple(torch_padding), mode='constant', value=0)
padded_volume_shape = (ps, ps, ps)
volume_coordinates = coordinate_grid(image_shape=padded_volume_shape)
volume_coordinates = coordinate_grid(image_shape=padded_volume_shape, device=volume.device)
volume_coordinates -= padded_sidelength // 2 # (d, h, w, zyx)
volume_coordinates = torch.flip(volume_coordinates, dims=(-1,)) # (d, h, w, zyx)
volume_coordinates = einops.rearrange(volume_coordinates, 'd h w zyx -> d h w zyx 1')
Expand All @@ -73,5 +73,5 @@ def _project_volume(rotation_matrix) -> torch.Tensor:

yl, yh = padding[1, 0], -padding[1, 1]
xl, xh = padding[2, 0], -padding[2, 1]
images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices]
images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices] #TODO: This can probabaly optimized using vmap
return torch.stack(images, dim=0)
2 changes: 1 addition & 1 deletion src/libtilt/shapes/soft_edge.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def _add_soft_edge_single_binary_image(
) -> torch.FloatTensor:
if smoothing_radius == 0:
return image.float()
distances = ndi.distance_transform_edt(torch.logical_not(image))
distances = ndi.distance_transform_edt(torch.logical_not(image)) #TODO: This breaks if the input device is cuda
distances = torch.as_tensor(distances, device=image.device).float()
idx = torch.logical_and(distances > 0, distances <= smoothing_radius)
output = torch.clone(image).float()
Expand Down