From 90f09b7b144eabcccc7ba0900a3a681b6fc25a4f Mon Sep 17 00:00:00 2001 From: alisterburt Date: Fri, 9 Feb 2024 13:34:31 -0800 Subject: [PATCH] Update patch extraction API (#60) * rename 2d patch extraction to be more explicitly 2d and spp -> subpixel precision * improve API of 2D patch extraction * first pass at cubic patch extraction from 3D volumes * update API * add 3d tests * improve patch extract on 2d/3d grids API and move into patch extraction subpackage * further improve organisation and naming * fix minor errors in patch extraction and tests * add amplitude spectrum calculation example --- .../on_the_fly_tomogram_amplitude_spectrum.py | 47 ++++++++++ examples/tsa_real_data_multiregion.py | 10 +-- examples/tsa_real_data_prior_shifts.py | 6 +- examples/virtual_tomogram.py | 6 +- src/libtilt/grids/__init__.py | 4 +- src/libtilt/patch_extraction/__init__.py | 4 +- .../_patch_grid_utils/__init__.py | 0 .../_patch_grid_utils/_patch_grid_centers.py | 0 .../_patch_grid_utils/_patch_grid_indices.py | 2 +- .../patch_extraction_on_grid.py} | 12 +-- .../subpixel_cubic_patch_extraction.py | 88 +++++++++++++++++++ ...py => subpixel_square_patch_extraction.py} | 63 +++++++------ ...raction.py => test_patch_extraction_2d.py} | 36 ++++---- .../tests/test_patch_extraction_3d.py | 34 +++++++ 14 files changed, 246 insertions(+), 66 deletions(-) create mode 100644 examples/on_the_fly_tomogram_amplitude_spectrum.py rename src/libtilt/{grids => patch_extraction}/_patch_grid_utils/__init__.py (100%) rename src/libtilt/{grids => patch_extraction}/_patch_grid_utils/_patch_grid_centers.py (100%) rename src/libtilt/{grids => patch_extraction}/_patch_grid_utils/_patch_grid_indices.py (97%) rename src/libtilt/{grids/patch_grid.py => patch_extraction/patch_extraction_on_grid.py} (92%) create mode 100644 src/libtilt/patch_extraction/subpixel_cubic_patch_extraction.py rename src/libtilt/patch_extraction/{patch_extraction_spp.py => subpixel_square_patch_extraction.py} (56%) rename src/libtilt/patch_extraction/tests/{test_patch_extraction.py => test_patch_extraction_2d.py} (56%) create mode 100644 src/libtilt/patch_extraction/tests/test_patch_extraction_3d.py diff --git a/examples/on_the_fly_tomogram_amplitude_spectrum.py b/examples/on_the_fly_tomogram_amplitude_spectrum.py new file mode 100644 index 0000000..13a049c --- /dev/null +++ b/examples/on_the_fly_tomogram_amplitude_spectrum.py @@ -0,0 +1,47 @@ +import mrcfile +import numpy as np +import torch +import time + +from libtilt.rotational_averaging import rotational_average_dft_3d +from libtilt.patch_extraction import extract_cubes + +# https://zenodo.org/records/6504891 +TOMOGRAM_FILE = '/Users/burta2/Downloads/01_10.00Apx.mrc' +N_CUBES = 20 +SIDELENGTH = 128 + +tomogram = torch.tensor(mrcfile.read(TOMOGRAM_FILE), dtype=torch.float32) + +# sample some points in the volume (could be smarter and sample from masked regions) +d, h, w = tomogram.shape +lower_bound = SIDELENGTH // 2 +z = np.random.uniform(low=lower_bound, high=d - lower_bound, size=N_CUBES) +y = np.random.uniform(low=lower_bound, high=h - lower_bound, size=N_CUBES) +x = np.random.uniform(low=lower_bound, high=w - lower_bound, size=N_CUBES) + +zyx = torch.tensor(np.stack([z, y, x], axis=-1), dtype=torch.float32) + +# start timing here +t0 = time.time() + +# extract cubes at those points +cubes = extract_cubes(image=tomogram, positions=zyx, sidelength=SIDELENGTH) + +# calculate amplitude spectra and rotational average +cubes_amplitudes = torch.fft.rfftn(cubes, dim=(-3, -2, -1)).abs().pow(2) +raps, bins = rotational_average_dft_3d(cubes_amplitudes, rfft=True, fftshifted=False, + image_shape=(SIDELENGTH, SIDELENGTH, SIDELENGTH)) +raps = torch.mean(raps, dim=0) # average over each of 10 cubes + +# end timing here +t1 = time.time() + +print(f"Elapsed time: {t1 - t0:.2f} seconds") + +# plot +import matplotlib.pyplot as plt + +fig, ax = plt.subplots() +ax.plot(np.log(raps.numpy())) +plt.show() diff --git a/examples/tsa_real_data_multiregion.py b/examples/tsa_real_data_multiregion.py index 27e71d8..ac64125 100644 --- a/examples/tsa_real_data_multiregion.py +++ b/examples/tsa_real_data_multiregion.py @@ -7,7 +7,7 @@ from libtilt.backprojection import backproject_fourier from libtilt.coordinate_utils import homogenise_coordinates from libtilt.fft_utils import dft_center -from libtilt.patch_extraction import extract_patches +from libtilt.patch_extraction import extract_squares from libtilt.projection import project_fourier from libtilt.rescaling.rescale_fourier import rescale_2d from libtilt.shapes import circle @@ -34,8 +34,8 @@ n_tilts, h, w = tilt_series.shape center = dft_center((h, w), rfft=False, fftshifted=True) center = einops.repeat(center, 'yx -> b yx', b=len(tilt_series)) -tilt_series = extract_patches( - images=tilt_series, +tilt_series = extract_squares( + image=tilt_series, positions=center, sidelength=min(h, w), ) @@ -86,8 +86,8 @@ projected_yx = Mproj @ positions_homogenous.view((-1, 1, 4, 1)) projected_yx = projected_yx.view((8, -1, 2)) - local_ts = extract_patches( - images=tilt_series, + local_ts = extract_squares( + image=tilt_series, positions=projected_yx, sidelength=s ) diff --git a/examples/tsa_real_data_prior_shifts.py b/examples/tsa_real_data_prior_shifts.py index a9598a6..df1cd7c 100644 --- a/examples/tsa_real_data_prior_shifts.py +++ b/examples/tsa_real_data_prior_shifts.py @@ -5,7 +5,7 @@ from libtilt.backprojection import backproject_fourier from libtilt.fft_utils import dft_center -from libtilt.patch_extraction import extract_patches +from libtilt.patch_extraction import extract_squares from libtilt.projection import project_fourier from libtilt.rescaling.rescale_fourier import rescale_2d from libtilt.shapes import circle @@ -21,8 +21,8 @@ tilt_series = torch.as_tensor(mrcfile.read(IMAGE_FILE)) df = pd.read_csv(FID_CENTERS_FILE) -fiducial_tilt_series = extract_patches( - images=tilt_series, +fiducial_tilt_series = extract_squares( + image=tilt_series, positions=torch.tensor(df[['axis-1', 'axis-2']].to_numpy()).float(), sidelength=256 ) diff --git a/examples/virtual_tomogram.py b/examples/virtual_tomogram.py index 8990111..9f60aba 100644 --- a/examples/virtual_tomogram.py +++ b/examples/virtual_tomogram.py @@ -11,7 +11,7 @@ from libtilt.transformations import Rx, Ry, Rz, T, S from libtilt.coordinate_utils import homogenise_coordinates -from libtilt.patch_extraction.patch_extraction_spp import extract_patches +from libtilt.patch_extraction.subpixel_square_patch_extraction import extract_squares from libtilt.rescaling import rescale_2d from libtilt.backprojection import backproject_fourier from libtilt.fft_utils import dft_center @@ -111,8 +111,8 @@ def extract_local_tilt_series( self, target_pixel_size=self.target_pixel_size ) projected_positions = self.calculate_projected_positions(position_in_tomogram) - particle_tilt_series = extract_patches( - images=rescaled_tilt_series, + particle_tilt_series = extract_squares( + image=rescaled_tilt_series, positions=projected_positions, sidelength=sidelength, ) diff --git a/src/libtilt/grids/__init__.py b/src/libtilt/grids/__init__.py index 82dc00c..494165c 100644 --- a/src/libtilt/grids/__init__.py +++ b/src/libtilt/grids/__init__.py @@ -1,12 +1,12 @@ from .coordinate_grid import coordinate_grid from .fftfreq_grid import fftfreq_grid from .central_slice_grid import central_slice_grid, rotated_central_slice_grid -from .patch_grid import patch_grid +from libtilt.patch_extraction.patch_extraction_on_grid import extract_patches_on_grid __all__ = [ 'coordinate_grid', 'fftfreq_grid', 'central_slice_grid', 'rotated_central_slice_grid', - 'patch_grid', + 'extract_patches_on_grid', ] diff --git a/src/libtilt/patch_extraction/__init__.py b/src/libtilt/patch_extraction/__init__.py index 3c192b6..1c48823 100644 --- a/src/libtilt/patch_extraction/__init__.py +++ b/src/libtilt/patch_extraction/__init__.py @@ -1 +1,3 @@ -from .patch_extraction_spp import extract_patches +from .subpixel_square_patch_extraction import extract_squares +from .subpixel_cubic_patch_extraction import extract_cubes +from .patch_extraction_on_grid import extract_patches_on_grid diff --git a/src/libtilt/grids/_patch_grid_utils/__init__.py b/src/libtilt/patch_extraction/_patch_grid_utils/__init__.py similarity index 100% rename from src/libtilt/grids/_patch_grid_utils/__init__.py rename to src/libtilt/patch_extraction/_patch_grid_utils/__init__.py diff --git a/src/libtilt/grids/_patch_grid_utils/_patch_grid_centers.py b/src/libtilt/patch_extraction/_patch_grid_utils/_patch_grid_centers.py similarity index 100% rename from src/libtilt/grids/_patch_grid_utils/_patch_grid_centers.py rename to src/libtilt/patch_extraction/_patch_grid_utils/_patch_grid_centers.py diff --git a/src/libtilt/grids/_patch_grid_utils/_patch_grid_indices.py b/src/libtilt/patch_extraction/_patch_grid_utils/_patch_grid_indices.py similarity index 97% rename from src/libtilt/grids/_patch_grid_utils/_patch_grid_indices.py rename to src/libtilt/patch_extraction/_patch_grid_utils/_patch_grid_indices.py index ad37eed..6389a19 100644 --- a/src/libtilt/grids/_patch_grid_utils/_patch_grid_indices.py +++ b/src/libtilt/patch_extraction/_patch_grid_utils/_patch_grid_indices.py @@ -3,7 +3,7 @@ import einops import torch -from libtilt.grids._patch_grid_utils._patch_grid_centers import _patch_centers_1d +from libtilt.patch_extraction._patch_grid_utils._patch_grid_centers import _patch_centers_1d def patch_grid_indices( diff --git a/src/libtilt/grids/patch_grid.py b/src/libtilt/patch_extraction/patch_extraction_on_grid.py similarity index 92% rename from src/libtilt/grids/patch_grid.py rename to src/libtilt/patch_extraction/patch_extraction_on_grid.py index 58d910b..f4fb537 100644 --- a/src/libtilt/grids/patch_grid.py +++ b/src/libtilt/patch_extraction/patch_extraction_on_grid.py @@ -1,9 +1,9 @@ import torch -from libtilt.grids._patch_grid_utils import patch_grid_centers, patch_grid_indices +from ._patch_grid_utils import patch_grid_centers, patch_grid_indices -def patch_grid( +def extract_patches_on_grid( images: torch.Tensor, patch_shape: tuple[int, int] | tuple[int, int, int], patch_step: tuple[int, int] | tuple[int, int, int], @@ -13,14 +13,14 @@ def patch_grid( raise ValueError('patch shape and step must have the same number of dimensions.') ndim = len(patch_shape) if ndim == 2: - patches, patch_centers = _patch_grid_2d( + patches, patch_centers = _extract_2d_patches_on_2d_grid( images=images, patch_shape=patch_shape, patch_step=patch_step, distribute_patches=distribute_patches, ) elif ndim == 3: - patches, patch_centers = _patch_grid_3d( + patches, patch_centers = _extract_3d_patches_on_3d_grid( images=images, patch_shape=patch_shape, patch_step=patch_step, @@ -31,7 +31,7 @@ def patch_grid( return patches, patch_centers -def _patch_grid_2d( +def _extract_2d_patches_on_2d_grid( images: torch.Tensor, patch_shape: tuple[int, int], patch_step: tuple[int, int], @@ -76,7 +76,7 @@ def _patch_grid_2d( return patches, patch_centers -def _patch_grid_3d( +def _extract_3d_patches_on_3d_grid( images: torch.Tensor, patch_shape: tuple[int, int, int], patch_step: tuple[int, int, int], diff --git a/src/libtilt/patch_extraction/subpixel_cubic_patch_extraction.py b/src/libtilt/patch_extraction/subpixel_cubic_patch_extraction.py new file mode 100644 index 0000000..8d44040 --- /dev/null +++ b/src/libtilt/patch_extraction/subpixel_cubic_patch_extraction.py @@ -0,0 +1,88 @@ +import einops +import torch +from torch.nn import functional as F + +from libtilt.grids import coordinate_grid +from libtilt.shift.shift_image import shift_3d +from libtilt.coordinate_utils import array_to_grid_sample +from libtilt.fft_utils import dft_center + + +def extract_cubes( + image: torch.Tensor, positions: torch.Tensor, sidelength: int, +): + """Extract cubic patches from a 3D image at positions with subpixel precision. + + Patches are extracted at the nearest integer coordinates then phase shifted + such that the requested position is at the center of the patch. + + Parameters + ---------- + image: torch.Tensor + `(d, h, w)` array containing a 3D image. + positions: torch.Tensor + `(..., 3)` + sidelength: int + Sidelength of cubic patches extracted from `image`. + + + Returns + ------- + patches: torch.Tensor + `(..., sidelength, sidelength, sidelength)` array of cubic patches from `image` + with their centers at `positions`. + """ + # pack arbitrary dimensions up into one new batch dim 'b' + positions, ps = einops.pack([positions], pattern='* zyx') + + # extract cubic patches from 3D image + patches = _extract_cubic_patches_from_single_3d_image( + image=image, positions=positions, sidelength=sidelength + ) + + # reassemble patches into arbitrary dimensional stacks + [patches] = einops.unpack(patches, pattern='* d h w', packed_shapes=ps) + return patches + + +def _extract_cubic_patches_from_single_3d_image( + image: torch.Tensor, # (h, w) + positions: torch.Tensor, # (b, 3) zyx + sidelength: int, +) -> torch.Tensor: + d, h, w = image.shape + b, _ = positions.shape + + # find integer positions and shifts to be applied + integer_positions = torch.round(positions) + shifts = integer_positions - positions + + # generate coordinate grids for sampling around each integer position + # add 1px border to leave space for subpixel phase shifting + pd, ph, pw = (sidelength + 2, sidelength + 2, sidelength + 2) + coordinates = coordinate_grid( + image_shape=(pd, ph, pw), + center=dft_center( + image_shape=(pd, ph, pw), + rfft=False, fftshifted=True, + device=image.device + ), + device=image.device + ) # (d, h, w, 3) + broadcastable_positions = einops.rearrange(integer_positions, 'b zyx -> b 1 1 1 zyx') + grid = coordinates + broadcastable_positions # (b, d, h, w, 3) + + # extract patches, grid sample handles boundaries + patches = F.grid_sample( + input=einops.repeat(image, 'd h w -> b 1 d h w', b=b), + grid=array_to_grid_sample(grid, array_shape=(d, h, w)), + mode='nearest', + padding_mode='zeros', + align_corners=True + ) + patches = einops.rearrange(patches, 'b 1 d h w -> b d h w') + + # phase shift to center images then remove border + patches = shift_3d(images=patches, shifts=shifts) + patches = F.pad(patches, pad=(-1, -1, -1, -1, -1, -1)) + return patches diff --git a/src/libtilt/patch_extraction/patch_extraction_spp.py b/src/libtilt/patch_extraction/subpixel_square_patch_extraction.py similarity index 56% rename from src/libtilt/patch_extraction/patch_extraction_spp.py rename to src/libtilt/patch_extraction/subpixel_square_patch_extraction.py index 3c318c5..2e427ef 100644 --- a/src/libtilt/patch_extraction/patch_extraction_spp.py +++ b/src/libtilt/patch_extraction/subpixel_square_patch_extraction.py @@ -8,52 +8,61 @@ from libtilt.fft_utils import dft_center -def extract_patches( - images: torch.Tensor, positions: torch.Tensor, sidelength: int, +def extract_squares( + image: torch.Tensor, positions: torch.Tensor, sidelength: int, ): - """Extract patches from 2D images at positions with subpixel precision. + """Extract square patches from 2D images at positions with subpixel precision. Patches are extracted at the nearest integer coordinates then phase shifted such that the requested position is at the center of the patch. Parameters ---------- - images: torch.Tensor - `(t, h, w)` or `(h, w)` array of 2D images. + image: torch.Tensor + `(h, w)` or `(b, h, w)` array containing a 2D image or 2D images. positions: torch.Tensor - `(..., t, 2)` or `(..., 2)` array of coordinates for patch centers. + `(..., 2)` or `(..., b, 2)` array of coordinates for patch centers. sidelength: int - Sidelength of square 2D patches extracted from `images`. + Sidelength of square patches extracted from `images`. Returns ------- patches: torch.Tensor - `(..., t, sidelength, sidelength)` array of patches from `images` with their - centers at `positions`. + `(..., sidelength, sidelength)` or `(..., b, sidelength, sidelength)` + array of patches from `images` with their centers at `positions`. """ - if images.ndim == 2: - images = einops.rearrange(images, 'h w -> 1 h w') + images_had_batch_dim = True + if image.ndim == 2: # add empty batch dim + images_had_batch_dim = False + image = einops.rearrange(image, 'h w -> 1 h w') positions = einops.rearrange(positions, '... yx -> ... 1 yx') - positions, ps = einops.pack([positions], pattern='* t yx') - positions = einops.rearrange(positions, 'b t yx -> t b yx') - patches = einops.rearrange( - [ - _extract_patches_from_single_image( - image=_image, - positions=_positions, - output_image_sidelength=sidelength - ) - for _image, _positions - in zip(images, positions) - ], - pattern='t b h w -> b t h w' - ) - [patches] = einops.unpack(patches, pattern='* t h w', packed_shapes=ps) + + # pack arbitrary dimensions up into one new batch dim 'b1' + positions, ps = einops.pack([positions], pattern='* b2 yx') + positions = einops.rearrange(positions, 'b1 b2 yx -> b2 b1 yx') + + # extract patches from each 2D image + patches = [ + _extract_square_patches_from_single_2d_image( + image=_image, + positions=_positions, + output_image_sidelength=sidelength + ) + for _image, _positions + in zip(image, positions) + ] + + # reassemble patches into arbitrary dimensional stacks + patches = einops.rearrange(patches, pattern='b2 b1 h w -> b1 b2 h w') + [patches] = einops.unpack(patches, pattern='* b2 h w', packed_shapes=ps) + + if images_had_batch_dim is False: + patches = einops.rearrange(patches, pattern='... 1 h w -> ... h w') return patches -def _extract_patches_from_single_image( +def _extract_square_patches_from_single_2d_image( image: torch.Tensor, # (h, w) positions: torch.Tensor, # (b, 2) yx output_image_sidelength: int, diff --git a/src/libtilt/patch_extraction/tests/test_patch_extraction.py b/src/libtilt/patch_extraction/tests/test_patch_extraction_2d.py similarity index 56% rename from src/libtilt/patch_extraction/tests/test_patch_extraction.py rename to src/libtilt/patch_extraction/tests/test_patch_extraction_2d.py index c136bcd..18c1766 100644 --- a/src/libtilt/patch_extraction/tests/test_patch_extraction.py +++ b/src/libtilt/patch_extraction/tests/test_patch_extraction_2d.py @@ -1,15 +1,15 @@ import torch -from libtilt.patch_extraction.patch_extraction_spp import extract_patches, \ - _extract_patches_from_single_image +from libtilt.patch_extraction.subpixel_square_patch_extraction import extract_squares, \ + _extract_square_patches_from_single_2d_image -def test_extract_patches_from_single_image(): - """Test particle_extraction from single image.""" +def test_single_square_patch_from_single_image(): + """Test square patch extraction from single image.""" img = torch.zeros((28, 28)) img[::2, ::2] = 1 positions = torch.tensor([14., 14.]).reshape((1, 2)) - patches = _extract_patches_from_single_image( + patches = _extract_square_patches_from_single_2d_image( image=img, positions=positions, output_image_sidelength=4 ) assert patches.shape == (1, 4, 4) @@ -17,16 +17,16 @@ def test_extract_patches_from_single_image(): assert torch.allclose(patches, expected_image, atol=1e-6) -def test_extract_patches(): +def test_extract_square_patches_single(): """Test extracting patches from a stack of images.""" img = torch.zeros((2, 28, 28)) img[:, ::2, ::2] = 1 positions = torch.tensor([[14., 14.], [15., 15.]]).reshape((1, 2, 2)) - patches = extract_patches( - images=img, # (t, h, w) - positions=positions, # (b, t, 2) + patches = extract_squares( + image=img, # (b2, h, w) + positions=positions, # (b1, b2, 2) sidelength=4 - ) # -> (b, t, 4, 4) + ) # -> (b1, b2, 4, 4) assert patches.shape == (1, 2, 4, 4) expected_image_0 = img[0, 12:16, 12:16] expected_image_1 = img[1, 13:17, 13:17] @@ -34,18 +34,18 @@ def test_extract_patches(): assert torch.allclose(patches[0, 1], expected_image_1, atol=1e-6) -def test_extract_patches_single_image(): - """Test particle_extraction from image stack.""" +def test_extract_square_patches_batched(): + """Test batched particle extraction from single image.""" img = torch.zeros((28, 28)) img[::2, ::2] = 1 positions = torch.tensor([[14., 14.], [15., 15.]]) - patches = extract_patches( - images=img, # (h, w) + patches = extract_squares( + image=img, # (h, w) positions=positions, # (b, 2) sidelength=4 - ) # -> (b, 1, 4, 4) - assert patches.shape == (2, 1, 4, 4) + ) # -> (b, 4, 4) + assert patches.shape == (2, 4, 4) expected_image_0 = img[12:16, 12:16] expected_image_1 = img[13:17, 13:17] - assert torch.allclose(patches[0, 0], expected_image_0, atol=1e-6) - assert torch.allclose(patches[1, 0], expected_image_1, atol=1e-6) + assert torch.allclose(patches[0], expected_image_0, atol=1e-6) + assert torch.allclose(patches[1], expected_image_1, atol=1e-6) diff --git a/src/libtilt/patch_extraction/tests/test_patch_extraction_3d.py b/src/libtilt/patch_extraction/tests/test_patch_extraction_3d.py new file mode 100644 index 0000000..faf951f --- /dev/null +++ b/src/libtilt/patch_extraction/tests/test_patch_extraction_3d.py @@ -0,0 +1,34 @@ +import torch + +from libtilt.patch_extraction.subpixel_cubic_patch_extraction import extract_cubes, \ + _extract_cubic_patches_from_single_3d_image + + +def test_single_cubic_patch_from_single_image(): + """Test cubic patch extraction from single 3D image.""" + img = torch.zeros((28, 28, 28)) + img[::2, ::2, ::2] = 1 + positions = torch.tensor([14., 14., 14.]).reshape((1, 3)) + patches = _extract_cubic_patches_from_single_3d_image( + image=img, positions=positions, sidelength=4 + ) + assert patches.shape == (1, 4, 4, 4) + expected_image = img[12:16, 12:16, 12:16] + assert torch.allclose(patches, expected_image, atol=1e-6) + + +def test_extract_cubic_patches(): + """Test extracting cubic patches from a 3D image.""" + img = torch.zeros((28, 28, 28)) + img[::2, ::2, ::2] = 1 + positions = torch.tensor([[14., 14., 14.], [15., 15., 15.]]).reshape((2, 3)) + patches = extract_cubes( + image=img, # (d, h, w) + positions=positions, # (b, 3) + sidelength=4 + ) # -> (b, 4, 4, 4) + assert patches.shape == (2, 4, 4, 4) + expected_image_0 = img[12:16, 12:16, 12:16] + expected_image_1 = img[13:17, 13:17, 13:17] + assert torch.allclose(patches[0], expected_image_0, atol=1e-6) + assert torch.allclose(patches[1], expected_image_1, atol=1e-6)