Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update patch extraction API #60

Merged
merged 9 commits into from
Feb 9, 2024
47 changes: 47 additions & 0 deletions examples/on_the_fly_tomogram_amplitude_spectrum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import mrcfile
import numpy as np
import torch
import time

from libtilt.rotational_averaging import rotational_average_dft_3d
from libtilt.patch_extraction import extract_cubes

# https://zenodo.org/records/6504891
TOMOGRAM_FILE = '/Users/burta2/Downloads/01_10.00Apx.mrc'
N_CUBES = 20
SIDELENGTH = 128

tomogram = torch.tensor(mrcfile.read(TOMOGRAM_FILE), dtype=torch.float32)

# sample some points in the volume (could be smarter and sample from masked regions)
d, h, w = tomogram.shape
lower_bound = SIDELENGTH // 2
z = np.random.uniform(low=lower_bound, high=d - lower_bound, size=N_CUBES)
y = np.random.uniform(low=lower_bound, high=h - lower_bound, size=N_CUBES)
x = np.random.uniform(low=lower_bound, high=w - lower_bound, size=N_CUBES)

zyx = torch.tensor(np.stack([z, y, x], axis=-1), dtype=torch.float32)

# start timing here
t0 = time.time()

# extract cubes at those points
cubes = extract_cubes(image=tomogram, positions=zyx, sidelength=SIDELENGTH)

# calculate amplitude spectra and rotational average
cubes_amplitudes = torch.fft.rfftn(cubes, dim=(-3, -2, -1)).abs().pow(2)
raps, bins = rotational_average_dft_3d(cubes_amplitudes, rfft=True, fftshifted=False,
image_shape=(SIDELENGTH, SIDELENGTH, SIDELENGTH))
raps = torch.mean(raps, dim=0) # average over each of 10 cubes

# end timing here
t1 = time.time()

print(f"Elapsed time: {t1 - t0:.2f} seconds")

# plot
import matplotlib.pyplot as plt

fig, ax = plt.subplots()
ax.plot(np.log(raps.numpy()))
plt.show()
10 changes: 5 additions & 5 deletions examples/tsa_real_data_multiregion.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from libtilt.backprojection import backproject_fourier
from libtilt.coordinate_utils import homogenise_coordinates
from libtilt.fft_utils import dft_center
from libtilt.patch_extraction import extract_patches
from libtilt.patch_extraction import extract_squares
from libtilt.projection import project_fourier
from libtilt.rescaling.rescale_fourier import rescale_2d
from libtilt.shapes import circle
Expand All @@ -34,8 +34,8 @@
n_tilts, h, w = tilt_series.shape
center = dft_center((h, w), rfft=False, fftshifted=True)
center = einops.repeat(center, 'yx -> b yx', b=len(tilt_series))
tilt_series = extract_patches(
images=tilt_series,
tilt_series = extract_squares(
image=tilt_series,
positions=center,
sidelength=min(h, w),
)
Expand Down Expand Up @@ -86,8 +86,8 @@
projected_yx = Mproj @ positions_homogenous.view((-1, 1, 4, 1))
projected_yx = projected_yx.view((8, -1, 2))

local_ts = extract_patches(
images=tilt_series,
local_ts = extract_squares(
image=tilt_series,
positions=projected_yx,
sidelength=s
)
Expand Down
6 changes: 3 additions & 3 deletions examples/tsa_real_data_prior_shifts.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from libtilt.backprojection import backproject_fourier
from libtilt.fft_utils import dft_center
from libtilt.patch_extraction import extract_patches
from libtilt.patch_extraction import extract_squares
from libtilt.projection import project_fourier
from libtilt.rescaling.rescale_fourier import rescale_2d
from libtilt.shapes import circle
Expand All @@ -21,8 +21,8 @@

tilt_series = torch.as_tensor(mrcfile.read(IMAGE_FILE))
df = pd.read_csv(FID_CENTERS_FILE)
fiducial_tilt_series = extract_patches(
images=tilt_series,
fiducial_tilt_series = extract_squares(
image=tilt_series,
positions=torch.tensor(df[['axis-1', 'axis-2']].to_numpy()).float(),
sidelength=256
)
Expand Down
6 changes: 3 additions & 3 deletions examples/virtual_tomogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from libtilt.transformations import Rx, Ry, Rz, T, S
from libtilt.coordinate_utils import homogenise_coordinates
from libtilt.patch_extraction.patch_extraction_spp import extract_patches
from libtilt.patch_extraction.subpixel_square_patch_extraction import extract_squares
from libtilt.rescaling import rescale_2d
from libtilt.backprojection import backproject_fourier
from libtilt.fft_utils import dft_center
Expand Down Expand Up @@ -111,8 +111,8 @@ def extract_local_tilt_series(
self, target_pixel_size=self.target_pixel_size
)
projected_positions = self.calculate_projected_positions(position_in_tomogram)
particle_tilt_series = extract_patches(
images=rescaled_tilt_series,
particle_tilt_series = extract_squares(
image=rescaled_tilt_series,
positions=projected_positions,
sidelength=sidelength,
)
Expand Down
4 changes: 2 additions & 2 deletions src/libtilt/grids/__init__.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
from .coordinate_grid import coordinate_grid
from .fftfreq_grid import fftfreq_grid
from .central_slice_grid import central_slice_grid, rotated_central_slice_grid
from .patch_grid import patch_grid
from libtilt.patch_extraction.patch_extraction_on_grid import extract_patches_on_grid

__all__ = [
'coordinate_grid',
'fftfreq_grid',
'central_slice_grid',
'rotated_central_slice_grid',
'patch_grid',
'extract_patches_on_grid',
]
4 changes: 3 additions & 1 deletion src/libtilt/patch_extraction/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .patch_extraction_spp import extract_patches
from .subpixel_square_patch_extraction import extract_squares
from .subpixel_cubic_patch_extraction import extract_cubes
from .patch_extraction_on_grid import extract_patches_on_grid
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import einops
import torch

from libtilt.grids._patch_grid_utils._patch_grid_centers import _patch_centers_1d
from libtilt.patch_extraction._patch_grid_utils._patch_grid_centers import _patch_centers_1d


def patch_grid_indices(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import torch

from libtilt.grids._patch_grid_utils import patch_grid_centers, patch_grid_indices
from ._patch_grid_utils import patch_grid_centers, patch_grid_indices


def patch_grid(
def extract_patches_on_grid(
images: torch.Tensor,
patch_shape: tuple[int, int] | tuple[int, int, int],
patch_step: tuple[int, int] | tuple[int, int, int],
Expand All @@ -13,14 +13,14 @@ def patch_grid(
raise ValueError('patch shape and step must have the same number of dimensions.')
ndim = len(patch_shape)
if ndim == 2:
patches, patch_centers = _patch_grid_2d(
patches, patch_centers = _extract_2d_patches_on_2d_grid(
images=images,
patch_shape=patch_shape,
patch_step=patch_step,
distribute_patches=distribute_patches,
)
elif ndim == 3:
patches, patch_centers = _patch_grid_3d(
patches, patch_centers = _extract_3d_patches_on_3d_grid(
images=images,
patch_shape=patch_shape,
patch_step=patch_step,
Expand All @@ -31,7 +31,7 @@ def patch_grid(
return patches, patch_centers


def _patch_grid_2d(
def _extract_2d_patches_on_2d_grid(
images: torch.Tensor,
patch_shape: tuple[int, int],
patch_step: tuple[int, int],
Expand Down Expand Up @@ -76,7 +76,7 @@ def _patch_grid_2d(
return patches, patch_centers


def _patch_grid_3d(
def _extract_3d_patches_on_3d_grid(
images: torch.Tensor,
patch_shape: tuple[int, int, int],
patch_step: tuple[int, int, int],
Expand Down
88 changes: 88 additions & 0 deletions src/libtilt/patch_extraction/subpixel_cubic_patch_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import einops
import torch
from torch.nn import functional as F

from libtilt.grids import coordinate_grid
from libtilt.shift.shift_image import shift_3d
from libtilt.coordinate_utils import array_to_grid_sample
from libtilt.fft_utils import dft_center


def extract_cubes(
image: torch.Tensor, positions: torch.Tensor, sidelength: int,
):
"""Extract cubic patches from a 3D image at positions with subpixel precision.

Patches are extracted at the nearest integer coordinates then phase shifted
such that the requested position is at the center of the patch.

Parameters
----------
image: torch.Tensor
`(d, h, w)` array containing a 3D image.
positions: torch.Tensor
`(..., 3)`
sidelength: int
Sidelength of cubic patches extracted from `image`.


Returns
-------
patches: torch.Tensor
`(..., sidelength, sidelength, sidelength)` array of cubic patches from `image`
with their centers at `positions`.
"""
# pack arbitrary dimensions up into one new batch dim 'b'
positions, ps = einops.pack([positions], pattern='* zyx')

# extract cubic patches from 3D image
patches = _extract_cubic_patches_from_single_3d_image(
image=image, positions=positions, sidelength=sidelength
)

# reassemble patches into arbitrary dimensional stacks
[patches] = einops.unpack(patches, pattern='* d h w', packed_shapes=ps)
return patches


def _extract_cubic_patches_from_single_3d_image(
image: torch.Tensor, # (h, w)
positions: torch.Tensor, # (b, 3) zyx
sidelength: int,
) -> torch.Tensor:
d, h, w = image.shape
b, _ = positions.shape

# find integer positions and shifts to be applied
integer_positions = torch.round(positions)
shifts = integer_positions - positions

# generate coordinate grids for sampling around each integer position
# add 1px border to leave space for subpixel phase shifting
pd, ph, pw = (sidelength + 2, sidelength + 2, sidelength + 2)
coordinates = coordinate_grid(
image_shape=(pd, ph, pw),
center=dft_center(
image_shape=(pd, ph, pw),
rfft=False, fftshifted=True,
device=image.device
),
device=image.device
) # (d, h, w, 3)
broadcastable_positions = einops.rearrange(integer_positions, 'b zyx -> b 1 1 1 zyx')
grid = coordinates + broadcastable_positions # (b, d, h, w, 3)

# extract patches, grid sample handles boundaries
patches = F.grid_sample(
input=einops.repeat(image, 'd h w -> b 1 d h w', b=b),
grid=array_to_grid_sample(grid, array_shape=(d, h, w)),
mode='nearest',
padding_mode='zeros',
align_corners=True
)
patches = einops.rearrange(patches, 'b 1 d h w -> b d h w')

# phase shift to center images then remove border
patches = shift_3d(images=patches, shifts=shifts)
patches = F.pad(patches, pad=(-1, -1, -1, -1, -1, -1))
return patches
Original file line number Diff line number Diff line change
Expand Up @@ -8,52 +8,61 @@
from libtilt.fft_utils import dft_center


def extract_patches(
images: torch.Tensor, positions: torch.Tensor, sidelength: int,
def extract_squares(
image: torch.Tensor, positions: torch.Tensor, sidelength: int,
):
"""Extract patches from 2D images at positions with subpixel precision.
"""Extract square patches from 2D images at positions with subpixel precision.

Patches are extracted at the nearest integer coordinates then phase shifted
such that the requested position is at the center of the patch.

Parameters
----------
images: torch.Tensor
`(t, h, w)` or `(h, w)` array of 2D images.
image: torch.Tensor
`(h, w)` or `(b, h, w)` array containing a 2D image or 2D images.
positions: torch.Tensor
`(..., t, 2)` or `(..., 2)` array of coordinates for patch centers.
`(..., 2)` or `(..., b, 2)` array of coordinates for patch centers.
sidelength: int
Sidelength of square 2D patches extracted from `images`.
Sidelength of square patches extracted from `images`.


Returns
-------
patches: torch.Tensor
`(..., t, sidelength, sidelength)` array of patches from `images` with their
centers at `positions`.
`(..., sidelength, sidelength)` or `(..., b, sidelength, sidelength)`
array of patches from `images` with their centers at `positions`.
"""
if images.ndim == 2:
images = einops.rearrange(images, 'h w -> 1 h w')
images_had_batch_dim = True
if image.ndim == 2: # add empty batch dim
images_had_batch_dim = False
image = einops.rearrange(image, 'h w -> 1 h w')
positions = einops.rearrange(positions, '... yx -> ... 1 yx')
positions, ps = einops.pack([positions], pattern='* t yx')
positions = einops.rearrange(positions, 'b t yx -> t b yx')
patches = einops.rearrange(
[
_extract_patches_from_single_image(
image=_image,
positions=_positions,
output_image_sidelength=sidelength
)
for _image, _positions
in zip(images, positions)
],
pattern='t b h w -> b t h w'
)
[patches] = einops.unpack(patches, pattern='* t h w', packed_shapes=ps)

# pack arbitrary dimensions up into one new batch dim 'b1'
positions, ps = einops.pack([positions], pattern='* b2 yx')
positions = einops.rearrange(positions, 'b1 b2 yx -> b2 b1 yx')

# extract patches from each 2D image
patches = [
_extract_square_patches_from_single_2d_image(
image=_image,
positions=_positions,
output_image_sidelength=sidelength
)
for _image, _positions
in zip(image, positions)
]

# reassemble patches into arbitrary dimensional stacks
patches = einops.rearrange(patches, pattern='b2 b1 h w -> b1 b2 h w')
[patches] = einops.unpack(patches, pattern='* b2 h w', packed_shapes=ps)

if images_had_batch_dim is False:
patches = einops.rearrange(patches, pattern='... 1 h w -> ... h w')
return patches


def _extract_patches_from_single_image(
def _extract_square_patches_from_single_2d_image(
image: torch.Tensor, # (h, w)
positions: torch.Tensor, # (b, 2) yx
output_image_sidelength: int,
Expand Down
Loading