Skip to content

Commit

Permalink
Projection from 2D -> 1D for common lines (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
McHaillet authored Feb 29, 2024
1 parent 721efb5 commit 5d45d72
Show file tree
Hide file tree
Showing 4 changed files with 156 additions and 12 deletions.
81 changes: 75 additions & 6 deletions examples/coarse_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import torch
import torch.nn.functional as F
import einops
from itertools import combinations
from torch_cubic_spline_grids import CubicBSplineGrid1d

from libtilt.backprojection import backproject_fourier
from libtilt.fft_utils import dft_center
Expand All @@ -12,11 +14,12 @@
from libtilt.shift.shift_image import shift_2d
from libtilt.transformations import Ry, Rz, T
from libtilt.correlation import correlate_2d
from libtilt.projection import project_image_real

IMAGE_FILE = 'data/tomo200528_100.st'
IMAGE_PIXEL_SIZE = 1.724
STAGE_TILT_ANGLE_PRIORS = torch.arange(-51, 51, 3)
TILT_AXIS_ANGLE_PRIOR = -88.7
TILT_AXIS_ANGLE_PRIOR = -30 # -88.7 according to mdoc, but I set it faulty to see if the optimization works
ALIGNMENT_PIXEL_SIZE = 13.79 * 2
# set 0 degree tilt as reference
REFERENCE_TILT = STAGE_TILT_ANGLE_PRIORS.abs().argmin()
Expand Down Expand Up @@ -86,6 +89,59 @@
current_shift += shift
coarse_shifts[i + 1] = current_shift

# create aligned stack for common lines; apply the mask here to prevent recalculation
coarse_aligned_masked = shift_2d(tilt_series, shifts=-coarse_shifts) * coarse_alignment_mask
# generate a weighting for the common line ROI by projecting the mask
mask_weights = project_image_real(coarse_alignment_mask, torch.eye(2).reshape(1, 2, 2))
mask_weights /= mask_weights.max() # normalise to 0 and 1

# optimize tilt axis angle
grid_resolution = 1
tilt_axis_grid = CubicBSplineGrid1d(resolution=grid_resolution, n_channels=1)
tilt_axis_grid.data = torch.tensor([TILT_AXIS_ANGLE_PRIOR, ] * grid_resolution, dtype=torch.float32)
interpolation_points = torch.linspace(0, 1, len(tilt_series))

common_lines_optimiser = torch.optim.Adam(
tilt_axis_grid.parameters(),
lr=1,
)

for epoch in range(200):
# interpolate the grid
tilt_axis_angles = tilt_axis_grid(interpolation_points)

# for common lines each 2d image is projected perpendicular to the tilt axis, thus add 90 degrees
R = Rz(tilt_axis_angles + 90, zyx=False)[:, :2, :2]

projections = []
for i in range(len(coarse_aligned_masked)):
projections.append(
project_image_real(
coarse_aligned_masked[i],
R[i:i+1]
).squeeze()
)
projections = torch.stack(projections)
projections = projections - einops.reduce(projections, 'tilt w -> tilt 1', reduction='mean')
projections = projections / torch.std(projections, dim=(-1), keepdim=True)
# weight the lines by the projected mask
projections = projections * mask_weights

common_lines_optimiser.zero_grad()
squared_differences = (projections - einops.rearrange(projections, 'b d -> b 1 d')) ** 2
loss = einops.reduce(squared_differences, 'b1 b2 d -> 1', reduction='sum')
loss.backward()
common_lines_optimiser.step()

if not (epoch % 10):
print(epoch, loss.item(), tilt_axis_grid.data.mean())

tilt_axis_prediction = tilt_axis_grid(interpolation_points).clone().detach()
print('final tilt axis angle:', torch.unique(tilt_axis_prediction))

# create the aligned stack
coarse_aligned = shift_2d(tilt_series, shifts=-coarse_shifts)

tomogram_center = dft_center(tomogram_dimensions, rfft=False, fftshifted=True)
tilt_image_center = dft_center(tilt_dimensions, rfft=False, fftshifted=True)

Expand All @@ -96,9 +152,21 @@
M = s2 @ r1 @ r0 @ s0

# coarse reconstruction
coarse_aligned_tilt_series = shift_2d(tilt_series, shifts=-coarse_shifts)
coarse_aligned_reconstruction = backproject_fourier(
images=coarse_aligned_tilt_series,
shifts_only_reconstruction = backproject_fourier(
images=coarse_aligned,
rotation_matrices=torch.linalg.inv(M[:, :3, :3]),
rotation_matrix_zyx=True,
)

s0 = T(-tomogram_center)
r0 = Ry(STAGE_TILT_ANGLE_PRIORS, zyx=True)
r1 = Rz(tilt_axis_prediction, zyx=True)
s2 = T(F.pad(tilt_image_center, pad=(1, 0), value=0))
M = s2 @ r1 @ r0 @ s0

# coarse reconstruction
coarse_reconstruction = backproject_fourier(
images=coarse_aligned,
rotation_matrices=torch.linalg.inv(M[:, :3, :3]),
rotation_matrix_zyx=True,
)
Expand All @@ -107,6 +175,7 @@

viewer = napari.Viewer()
viewer.add_image(tilt_series.detach().numpy(), name='experimental')
viewer.add_image(coarse_aligned_tilt_series.detach().numpy(), name='coarse aligned')
viewer.add_image(coarse_aligned_reconstruction.detach().numpy(), name='coarse reconstruction')
viewer.add_image(coarse_aligned.detach().numpy(), name='coarse aligned')
viewer.add_image(shifts_only_reconstruction.detach().numpy(), name='shifts only reconstruction')
viewer.add_image(coarse_reconstruction.detach().numpy(), name='coarse reconstruction')
napari.run()
2 changes: 1 addition & 1 deletion src/libtilt/projection/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .project_fourier import project_fourier
from .project_real import project_real
from .project_real import project_volume_real, project_image_real
72 changes: 70 additions & 2 deletions src/libtilt/projection/project_real.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from libtilt.grids.coordinate_grid import coordinate_grid


def project_real(volume: torch.Tensor, rotation_matrices: torch.Tensor) -> torch.Tensor:
def project_volume_real(volume: torch.Tensor, rotation_matrices: torch.Tensor) -> torch.Tensor:
"""Make 2D projections of a 3D volume in specific orientations.
Projections are made by
Expand Down Expand Up @@ -58,7 +58,7 @@ def _project_volume(rotation_matrix) -> torch.Tensor:
rotated_coordinates = rotation_matrix @ volume_coordinates
rotated_coordinates += padded_sidelength // 2
rotated_coordinates = einops.rearrange(rotated_coordinates, 'd h w zyx 1 -> 1 d h w zyx')
rotated_coordinates = torch.flip(rotated_coordinates, dims=(-1,)) # zyx -> zyx
rotated_coordinates = torch.flip(rotated_coordinates, dims=(-1,)) # zyx -> xyz
rotated_coordinates = array_to_grid_sample(
rotated_coordinates, array_shape=padded_volume_shape
)
Expand All @@ -75,3 +75,71 @@ def _project_volume(rotation_matrix) -> torch.Tensor:
xl, xh = padding[2, 0], -padding[2, 1]
images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices]
return torch.stack(images, dim=0)


def project_image_real(image: torch.Tensor, rotation_matrices: torch.Tensor) -> torch.Tensor:
"""Make 1D projections of a 2D image in specific orientations.
Projections are made by
1. generating a line of coordinates sufficient to cover
the image in any orientation.
2. left-multiplying `rotation matrices` and coordinate grids to
produce rotated coordinates.
3. sampling `image` at rotated coordinates.
4. summing samples along 'h' dimension of a `(h, w)` image.
The rotation center of `image` is taken to be `torch.tensor(image.shape) // 2`.
Parameters
----------
image: torch.Tensor
`(h, w)` image from which projections will be made.
rotation_matrices: torch.Tensor
`(batch, 2, 2)` array of rotation matrices
Returns
-------
projection_lines: torch.Tensor
`(batch, w)` array of 1D projection lines sampled from `image`.
"""
image = torch.as_tensor(image)
rotation_matrices = torch.as_tensor(rotation_matrices, dtype=torch.float)
image_shape = torch.tensor(image.shape)
ps = padded_sidelength = int(3 ** 0.5 * torch.max(image_shape))
shape_difference = torch.abs(padded_sidelength - image_shape)
padding = torch.empty(size=(2, 2), dtype=torch.int16)
padding[:, 0] = torch.div(shape_difference, 2, rounding_mode='floor')
padding[:, 1] = shape_difference - padding[:, 0]
torch_padding = torch.flip(padding, dims=(0,)) # hw -> wh for torch.nn.functional.pad
torch_padding = einops.rearrange(torch_padding, 'wh pad -> (wh pad)')
image = F.pad(image, pad=tuple(torch_padding), mode='constant', value=0)
padded_image_shape = (ps, ps)
image_coordinates = coordinate_grid(image_shape=padded_image_shape)
image_coordinates -= padded_sidelength // 2 # (h, w, yx)
image_coordinates = torch.flip(image_coordinates, dims=(-1,)) # (h, w, yx)
image_coordinates = einops.rearrange(image_coordinates, 'h w yx -> h w yx 1')

def _project_image(rotation_matrix) -> torch.Tensor:
rotated_coordinates = rotation_matrix @ image_coordinates
rotated_coordinates += padded_sidelength // 2
rotated_coordinates = einops.rearrange(rotated_coordinates, 'h w yx 1 -> 1 h w yx')
rotated_coordinates = torch.flip(rotated_coordinates, dims=(-1,)) # yx -> xy
rotated_coordinates = array_to_grid_sample(
rotated_coordinates, array_shape=padded_image_shape
)
samples = F.grid_sample(
input=einops.rearrange(image, 'h w -> 1 1 h w'), # add batch and channel dims
grid=rotated_coordinates,
mode='bilinear',
padding_mode='zeros',
align_corners=True,
)
return einops.reduce(samples, '1 1 h w -> w', reduction='sum')

xl, xh = padding[1, 0], -padding[1, 1]
lines = [_project_image(matrix)[xl:xh] for matrix in rotation_matrices]
return torch.stack(lines, dim=0)
13 changes: 10 additions & 3 deletions src/libtilt/projection/tests/test_project_real.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,19 @@
import torch

from libtilt.projection.project_real import project_real
from libtilt.projection.project_real import project_image_real, project_volume_real


def test_real_space_projection():
def test_real_space_projection_3d():
volume_shape = (2, 10, 10)
volume = torch.arange(2*10*10).reshape(volume_shape).float()
rotation_matrix = torch.eye(3).reshape(1, 3, 3)
projection = project_real(volume, rotation_matrices=rotation_matrix)
projection = project_volume_real(volume, rotation_matrices=rotation_matrix)
assert torch.allclose(projection.squeeze(), torch.sum(volume, dim=0))


def test_real_space_projection_2d():
image_shape = (8, 12)
image = torch.arange(8 * 12).reshape(image_shape).float()
rotation_matrix = torch.eye(2).reshape(1, 2, 2)
projection = project_image_real(image, rotation_matrices=rotation_matrix)
assert torch.allclose(projection.squeeze(), torch.sum(image, dim=0))

0 comments on commit 5d45d72

Please sign in to comment.