Projection from 2D -> 1D for common lines #67

Merged · 13 commits · Feb 29, 2024
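In brief, this PR renames `project_real` to `project_volume_real` and adds `project_image_real`, which projects a 2D `(h, w)` image onto 1D lines for a batch of `(2, 2)` rotation matrices; the example script then uses these lines to optimise the tilt-axis angle via common lines. A minimal usage sketch (the image size and values here are illustrative, not from the PR):

```python
import torch
from libtilt.projection import project_image_real

# illustrative image; any (h, w) float tensor works
image = torch.rand(64, 64)

# identity rotation -> each output entry is (approximately) the sum over
# the image height at that column, as checked in the new test
rotation_matrices = torch.eye(2).reshape(1, 2, 2)

lines = project_image_real(image, rotation_matrices)
print(lines.shape)  # (1, 64): one 1D line per rotation matrix
```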
81 changes: 75 additions & 6 deletions examples/coarse_alignment.py
@@ -3,6 +3,8 @@
import torch
import torch.nn.functional as F
import einops
from itertools import combinations
from torch_cubic_spline_grids import CubicBSplineGrid1d

from libtilt.backprojection import backproject_fourier
from libtilt.fft_utils import dft_center
@@ -12,11 +14,12 @@
from libtilt.shift.shift_image import shift_2d
from libtilt.transformations import Ry, Rz, T
from libtilt.correlation import correlate_2d
from libtilt.projection import project_image_real

IMAGE_FILE = 'data/tomo200528_100.st'
IMAGE_PIXEL_SIZE = 1.724
STAGE_TILT_ANGLE_PRIORS = torch.arange(-51, 51, 3)
TILT_AXIS_ANGLE_PRIOR = -88.7
TILT_AXIS_ANGLE_PRIOR = -30  # -88.7 according to the mdoc, but deliberately set wrong here to check that the optimisation recovers it
ALIGNMENT_PIXEL_SIZE = 13.79 * 2
# set 0 degree tilt as reference
REFERENCE_TILT = STAGE_TILT_ANGLE_PRIORS.abs().argmin()
@@ -86,6 +89,59 @@
current_shift += shift
coarse_shifts[i + 1] = current_shift

# create the aligned stack for common lines; apply the mask once here so it is not recomputed every epoch
coarse_aligned_masked = shift_2d(tilt_series, shifts=-coarse_shifts) * coarse_alignment_mask
# generate a weighting for the common line ROI by projecting the mask
mask_weights = project_image_real(coarse_alignment_mask, torch.eye(2).reshape(1, 2, 2))
mask_weights /= mask_weights.max() # normalise to the range [0, 1]

# optimize tilt axis angle
grid_resolution = 1
tilt_axis_grid = CubicBSplineGrid1d(resolution=grid_resolution, n_channels=1)
tilt_axis_grid.data = torch.tensor([TILT_AXIS_ANGLE_PRIOR, ] * grid_resolution, dtype=torch.float32)
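# resolution=1 gives a single control point, i.e. one tilt-axis angle shared by all tilt images;
# a higher resolution would let the angle vary over the tilt series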
interpolation_points = torch.linspace(0, 1, len(tilt_series))

common_lines_optimiser = torch.optim.Adam(
tilt_axis_grid.parameters(),
lr=1,
)

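# common-lines principle: after projection perpendicular to the tilt axis, every tilt image
# should give (approximately) the same 1D line, so the pairwise differences between the
# projected lines are minimised with respect to the tilt-axis angle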
for epoch in range(200):
# interpolate the grid
tilt_axis_angles = tilt_axis_grid(interpolation_points)

# for common lines, each 2D image is projected perpendicular to the tilt axis, hence the extra 90 degrees
R = Rz(tilt_axis_angles + 90, zyx=False)[:, :2, :2]

projections = []
for i in range(len(coarse_aligned_masked)):
projections.append(
project_image_real(
coarse_aligned_masked[i],
R[i:i+1]
).squeeze()
)
projections = torch.stack(projections)
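# normalise each projected line to zero mean and unit variance before comparison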
projections = projections - einops.reduce(projections, 'tilt w -> tilt 1', reduction='mean')
projections = projections / torch.std(projections, dim=(-1), keepdim=True)
# weight the lines by the projected mask
projections = projections * mask_weights

common_lines_optimiser.zero_grad()
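# loss: sum of squared differences between every pair of projected 1D lines (broadcast over tilt pairs)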
squared_differences = (projections - einops.rearrange(projections, 'b d -> b 1 d')) ** 2
loss = einops.reduce(squared_differences, 'b1 b2 d -> 1', reduction='sum')
loss.backward()
common_lines_optimiser.step()

if not (epoch % 10):
print(epoch, loss.item(), tilt_axis_grid.data.mean())

tilt_axis_prediction = tilt_axis_grid(interpolation_points).clone().detach()
print('final tilt axis angle:', torch.unique(tilt_axis_prediction))

# create the aligned stack
coarse_aligned = shift_2d(tilt_series, shifts=-coarse_shifts)

tomogram_center = dft_center(tomogram_dimensions, rfft=False, fftshifted=True)
tilt_image_center = dft_center(tilt_dimensions, rfft=False, fftshifted=True)

@@ -96,9 +152,21 @@
M = s2 @ r1 @ r0 @ s0

# reconstruction from the coarse shifts alone (prior tilt-axis angle)
shifts_only_reconstruction = backproject_fourier(
images=coarse_aligned,
rotation_matrices=torch.linalg.inv(M[:, :3, :3]),
rotation_matrix_zyx=True,
)

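# rebuild the projection geometry with the optimised tilt-axis angles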
s0 = T(-tomogram_center)
r0 = Ry(STAGE_TILT_ANGLE_PRIORS, zyx=True)
r1 = Rz(tilt_axis_prediction, zyx=True)
s2 = T(F.pad(tilt_image_center, pad=(1, 0), value=0))
M = s2 @ r1 @ r0 @ s0

# coarse reconstruction
coarse_reconstruction = backproject_fourier(
images=coarse_aligned,
rotation_matrices=torch.linalg.inv(M[:, :3, :3]),
rotation_matrix_zyx=True,
)
@@ -107,6 +175,7 @@

viewer = napari.Viewer()
viewer.add_image(tilt_series.detach().numpy(), name='experimental')
viewer.add_image(coarse_aligned.detach().numpy(), name='coarse aligned')
viewer.add_image(shifts_only_reconstruction.detach().numpy(), name='shifts only reconstruction')
viewer.add_image(coarse_reconstruction.detach().numpy(), name='coarse reconstruction')
napari.run()
2 changes: 1 addition & 1 deletion src/libtilt/projection/__init__.py
@@ -1,2 +1,2 @@
from .project_fourier import project_fourier
from .project_real import project_volume_real, project_image_real
72 changes: 70 additions & 2 deletions src/libtilt/projection/project_real.py
@@ -8,7 +8,7 @@
from libtilt.grids.coordinate_grid import coordinate_grid


def project_volume_real(volume: torch.Tensor, rotation_matrices: torch.Tensor) -> torch.Tensor:
"""Make 2D projections of a 3D volume in specific orientations.

Projections are made by
@@ -58,7 +58,7 @@ def _project_volume(rotation_matrix) -> torch.Tensor:
rotated_coordinates = rotation_matrix @ volume_coordinates
rotated_coordinates += padded_sidelength // 2
rotated_coordinates = einops.rearrange(rotated_coordinates, 'd h w zyx 1 -> 1 d h w zyx')
rotated_coordinates = torch.flip(rotated_coordinates, dims=(-1,)) # zyx -> xyz
rotated_coordinates = array_to_grid_sample(
rotated_coordinates, array_shape=padded_volume_shape
)
@@ -75,3 +75,71 @@ def _project_volume(rotation_matrix) -> torch.Tensor:
xl, xh = padding[2, 0], -padding[2, 1]
images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices]
return torch.stack(images, dim=0)


def project_image_real(image: torch.Tensor, rotation_matrices: torch.Tensor) -> torch.Tensor:
"""Make 1D projections of a 2D image in specific orientations.

Projections are made by

1. generating a line of coordinates sufficient to cover
the image in any orientation.

2. left-multiplying coordinate grids by `rotation_matrices` to produce
rotated coordinates.

3. sampling `image` at rotated coordinates.

4. summing samples along the 'h' dimension of the `(h, w)` image.

The rotation center of `image` is taken to be `torch.tensor(image.shape) // 2`.

Parameters
----------
image: torch.Tensor
`(h, w)` image from which projections will be made.
rotation_matrices: torch.Tensor
`(batch, 2, 2)` array of rotation matrices

Returns
-------
projection_lines: torch.Tensor
`(batch, w)` array of 1D projection lines sampled from `image`.
"""
image = torch.as_tensor(image)
rotation_matrices = torch.as_tensor(rotation_matrices, dtype=torch.float)
image_shape = torch.tensor(image.shape)
ps = padded_sidelength = int(3 ** 0.5 * torch.max(image_shape))
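# note: 2 ** 0.5 * max(image_shape) already covers any in-plane rotation; 3 ** 0.5 mirrors the 3D version and simply over-pads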
shape_difference = torch.abs(padded_sidelength - image_shape)
padding = torch.empty(size=(2, 2), dtype=torch.int16)
padding[:, 0] = torch.div(shape_difference, 2, rounding_mode='floor')
padding[:, 1] = shape_difference - padding[:, 0]
torch_padding = torch.flip(padding, dims=(0,)) # hw -> wh for torch.nn.functional.pad
torch_padding = einops.rearrange(torch_padding, 'wh pad -> (wh pad)')
image = F.pad(image, pad=tuple(torch_padding), mode='constant', value=0)
padded_image_shape = (ps, ps)
image_coordinates = coordinate_grid(image_shape=padded_image_shape)
image_coordinates -= padded_sidelength // 2 # (h, w, yx)
image_coordinates = torch.flip(image_coordinates, dims=(-1,)) # (h, w, yx) -> (h, w, xy)
image_coordinates = einops.rearrange(image_coordinates, 'h w yx -> h w yx 1')

def _project_image(rotation_matrix) -> torch.Tensor:
rotated_coordinates = rotation_matrix @ image_coordinates
rotated_coordinates += padded_sidelength // 2
rotated_coordinates = einops.rearrange(rotated_coordinates, 'h w yx 1 -> 1 h w yx')
rotated_coordinates = torch.flip(rotated_coordinates, dims=(-1,)) # yx -> xy
rotated_coordinates = array_to_grid_sample(
rotated_coordinates, array_shape=padded_image_shape
)
samples = F.grid_sample(
input=einops.rearrange(image, 'h w -> 1 1 h w'), # add batch and channel dims
grid=rotated_coordinates,
mode='bilinear',
padding_mode='zeros',
align_corners=True,
)
return einops.reduce(samples, '1 1 h w -> w', reduction='sum')

xl, xh = padding[1, 0], -padding[1, 1]
lines = [_project_image(matrix)[xl:xh] for matrix in rotation_matrices]
return torch.stack(lines, dim=0)
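Beyond the identity case in the test below, a 1D projection along an arbitrary direction only needs a plain `(1, 2, 2)` rotation matrix; a short sketch (the 30-degree angle and random image are arbitrary choices, not from the PR):

```python
import math
import torch
from libtilt.projection import project_image_real

theta = math.radians(30.0)  # arbitrary example angle
R = torch.tensor([[math.cos(theta), -math.sin(theta)],
                  [math.sin(theta),  math.cos(theta)]]).reshape(1, 2, 2)
lines = project_image_real(torch.rand(64, 64), R)  # -> (1, 64) projection line
```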
13 changes: 10 additions & 3 deletions src/libtilt/projection/tests/test_project_real.py
@@ -1,12 +1,19 @@
import torch

from libtilt.projection.project_real import project_image_real, project_volume_real


def test_real_space_projection_3d():
volume_shape = (2, 10, 10)
volume = torch.arange(2*10*10).reshape(volume_shape).float()
rotation_matrix = torch.eye(3).reshape(1, 3, 3)
projection = project_volume_real(volume, rotation_matrices=rotation_matrix)
assert torch.allclose(projection.squeeze(), torch.sum(volume, dim=0))


def test_real_space_projection_2d():
image_shape = (8, 12)
image = torch.arange(8 * 12).reshape(image_shape).float()
rotation_matrix = torch.eye(2).reshape(1, 2, 2)
projection = project_image_real(image, rotation_matrices=rotation_matrix)
assert torch.allclose(projection.squeeze(), torch.sum(image, dim=0))