-
Notifications
You must be signed in to change notification settings - Fork 5
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* rename 2d patch extraction to be more explicitly 2d and spp -> subpixel precision * improve API of 2D patch extraction * first pass at cubic patch extraction from 3D volumes * update API * add 3d tests * improve patch extract on 2d/3d grids API and move into patch extraction subpackage * further improve organisation and naming * fix minor errors in patch extraction and tests * add amplitude spectrum calculation example
- Loading branch information
1 parent
958e580
commit 90f09b7
Showing
14 changed files
with
246 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import mrcfile | ||
import numpy as np | ||
import torch | ||
import time | ||
|
||
from libtilt.rotational_averaging import rotational_average_dft_3d | ||
from libtilt.patch_extraction import extract_cubes | ||
|
||
# https://zenodo.org/records/6504891 | ||
TOMOGRAM_FILE = '/Users/burta2/Downloads/01_10.00Apx.mrc' | ||
N_CUBES = 20 | ||
SIDELENGTH = 128 | ||
|
||
tomogram = torch.tensor(mrcfile.read(TOMOGRAM_FILE), dtype=torch.float32) | ||
|
||
# sample some points in the volume (could be smarter and sample from masked regions) | ||
d, h, w = tomogram.shape | ||
lower_bound = SIDELENGTH // 2 | ||
z = np.random.uniform(low=lower_bound, high=d - lower_bound, size=N_CUBES) | ||
y = np.random.uniform(low=lower_bound, high=h - lower_bound, size=N_CUBES) | ||
x = np.random.uniform(low=lower_bound, high=w - lower_bound, size=N_CUBES) | ||
|
||
zyx = torch.tensor(np.stack([z, y, x], axis=-1), dtype=torch.float32) | ||
|
||
# start timing here | ||
t0 = time.time() | ||
|
||
# extract cubes at those points | ||
cubes = extract_cubes(image=tomogram, positions=zyx, sidelength=SIDELENGTH) | ||
|
||
# calculate amplitude spectra and rotational average | ||
cubes_amplitudes = torch.fft.rfftn(cubes, dim=(-3, -2, -1)).abs().pow(2) | ||
raps, bins = rotational_average_dft_3d(cubes_amplitudes, rfft=True, fftshifted=False, | ||
image_shape=(SIDELENGTH, SIDELENGTH, SIDELENGTH)) | ||
raps = torch.mean(raps, dim=0) # average over each of 10 cubes | ||
|
||
# end timing here | ||
t1 = time.time() | ||
|
||
print(f"Elapsed time: {t1 - t0:.2f} seconds") | ||
|
||
# plot | ||
import matplotlib.pyplot as plt | ||
|
||
fig, ax = plt.subplots() | ||
ax.plot(np.log(raps.numpy())) | ||
plt.show() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,12 @@ | ||
from .coordinate_grid import coordinate_grid | ||
from .fftfreq_grid import fftfreq_grid | ||
from .central_slice_grid import central_slice_grid, rotated_central_slice_grid | ||
from .patch_grid import patch_grid | ||
from libtilt.patch_extraction.patch_extraction_on_grid import extract_patches_on_grid | ||
|
||
__all__ = [ | ||
'coordinate_grid', | ||
'fftfreq_grid', | ||
'central_slice_grid', | ||
'rotated_central_slice_grid', | ||
'patch_grid', | ||
'extract_patches_on_grid', | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,3 @@ | ||
from .patch_extraction_spp import extract_patches | ||
from .subpixel_square_patch_extraction import extract_squares | ||
from .subpixel_cubic_patch_extraction import extract_cubes | ||
from .patch_extraction_on_grid import extract_patches_on_grid |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
88 changes: 88 additions & 0 deletions
88
src/libtilt/patch_extraction/subpixel_cubic_patch_extraction.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import einops | ||
import torch | ||
from torch.nn import functional as F | ||
|
||
from libtilt.grids import coordinate_grid | ||
from libtilt.shift.shift_image import shift_3d | ||
from libtilt.coordinate_utils import array_to_grid_sample | ||
from libtilt.fft_utils import dft_center | ||
|
||
|
||
def extract_cubes( | ||
image: torch.Tensor, positions: torch.Tensor, sidelength: int, | ||
): | ||
"""Extract cubic patches from a 3D image at positions with subpixel precision. | ||
Patches are extracted at the nearest integer coordinates then phase shifted | ||
such that the requested position is at the center of the patch. | ||
Parameters | ||
---------- | ||
image: torch.Tensor | ||
`(d, h, w)` array containing a 3D image. | ||
positions: torch.Tensor | ||
`(..., 3)` | ||
sidelength: int | ||
Sidelength of cubic patches extracted from `image`. | ||
Returns | ||
------- | ||
patches: torch.Tensor | ||
`(..., sidelength, sidelength, sidelength)` array of cubic patches from `image` | ||
with their centers at `positions`. | ||
""" | ||
# pack arbitrary dimensions up into one new batch dim 'b' | ||
positions, ps = einops.pack([positions], pattern='* zyx') | ||
|
||
# extract cubic patches from 3D image | ||
patches = _extract_cubic_patches_from_single_3d_image( | ||
image=image, positions=positions, sidelength=sidelength | ||
) | ||
|
||
# reassemble patches into arbitrary dimensional stacks | ||
[patches] = einops.unpack(patches, pattern='* d h w', packed_shapes=ps) | ||
return patches | ||
|
||
|
||
def _extract_cubic_patches_from_single_3d_image( | ||
image: torch.Tensor, # (h, w) | ||
positions: torch.Tensor, # (b, 3) zyx | ||
sidelength: int, | ||
) -> torch.Tensor: | ||
d, h, w = image.shape | ||
b, _ = positions.shape | ||
|
||
# find integer positions and shifts to be applied | ||
integer_positions = torch.round(positions) | ||
shifts = integer_positions - positions | ||
|
||
# generate coordinate grids for sampling around each integer position | ||
# add 1px border to leave space for subpixel phase shifting | ||
pd, ph, pw = (sidelength + 2, sidelength + 2, sidelength + 2) | ||
coordinates = coordinate_grid( | ||
image_shape=(pd, ph, pw), | ||
center=dft_center( | ||
image_shape=(pd, ph, pw), | ||
rfft=False, fftshifted=True, | ||
device=image.device | ||
), | ||
device=image.device | ||
) # (d, h, w, 3) | ||
broadcastable_positions = einops.rearrange(integer_positions, 'b zyx -> b 1 1 1 zyx') | ||
grid = coordinates + broadcastable_positions # (b, d, h, w, 3) | ||
|
||
# extract patches, grid sample handles boundaries | ||
patches = F.grid_sample( | ||
input=einops.repeat(image, 'd h w -> b 1 d h w', b=b), | ||
grid=array_to_grid_sample(grid, array_shape=(d, h, w)), | ||
mode='nearest', | ||
padding_mode='zeros', | ||
align_corners=True | ||
) | ||
patches = einops.rearrange(patches, 'b 1 d h w -> b d h w') | ||
|
||
# phase shift to center images then remove border | ||
patches = shift_3d(images=patches, shifts=shifts) | ||
patches = F.pad(patches, pad=(-1, -1, -1, -1, -1, -1)) | ||
return patches |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.