diff --git a/examples/coarse_alignment.py b/examples/coarse_alignment.py index c900d5c..7b0977c 100644 --- a/examples/coarse_alignment.py +++ b/examples/coarse_alignment.py @@ -3,6 +3,8 @@ import torch import torch.nn.functional as F import einops +from itertools import combinations +from torch_cubic_spline_grids import CubicBSplineGrid1d from libtilt.backprojection import backproject_fourier from libtilt.fft_utils import dft_center @@ -12,11 +14,12 @@ from libtilt.shift.shift_image import shift_2d from libtilt.transformations import Ry, Rz, T from libtilt.correlation import correlate_2d +from libtilt.projection import project_image_real IMAGE_FILE = 'data/tomo200528_100.st' IMAGE_PIXEL_SIZE = 1.724 STAGE_TILT_ANGLE_PRIORS = torch.arange(-51, 51, 3) -TILT_AXIS_ANGLE_PRIOR = -88.7 +TILT_AXIS_ANGLE_PRIOR = -30 # -88.7 according to mdoc, but I set it faulty to see if the optimization works ALIGNMENT_PIXEL_SIZE = 13.79 * 2 # set 0 degree tilt as reference REFERENCE_TILT = STAGE_TILT_ANGLE_PRIORS.abs().argmin() @@ -86,6 +89,59 @@ current_shift += shift coarse_shifts[i + 1] = current_shift +# create aligned stack for common lines; apply the mask here to prevent recalculation +coarse_aligned_masked = shift_2d(tilt_series, shifts=-coarse_shifts) * coarse_alignment_mask +# generate a weighting for the common line ROI by projecting the mask +mask_weights = project_image_real(coarse_alignment_mask, torch.eye(2).reshape(1, 2, 2)) +mask_weights /= mask_weights.max() # normalise to 0 and 1 + +# optimize tilt axis angle +grid_resolution = 1 +tilt_axis_grid = CubicBSplineGrid1d(resolution=grid_resolution, n_channels=1) +tilt_axis_grid.data = torch.tensor([TILT_AXIS_ANGLE_PRIOR, ] * grid_resolution, dtype=torch.float32) +interpolation_points = torch.linspace(0, 1, len(tilt_series)) + +common_lines_optimiser = torch.optim.Adam( + tilt_axis_grid.parameters(), + lr=1, +) + +for epoch in range(200): + # interpolate the grid + tilt_axis_angles = tilt_axis_grid(interpolation_points) + + # for common lines each 2d image is projected perpendicular to the tilt axis, thus add 90 degrees + R = Rz(tilt_axis_angles + 90, zyx=False)[:, :2, :2] + + projections = [] + for i in range(len(coarse_aligned_masked)): + projections.append( + project_image_real( + coarse_aligned_masked[i], + R[i:i+1] + ).squeeze() + ) + projections = torch.stack(projections) + projections = projections - einops.reduce(projections, 'tilt w -> tilt 1', reduction='mean') + projections = projections / torch.std(projections, dim=(-1), keepdim=True) + # weight the lines by the projected mask + projections = projections * mask_weights + + common_lines_optimiser.zero_grad() + squared_differences = (projections - einops.rearrange(projections, 'b d -> b 1 d')) ** 2 + loss = einops.reduce(squared_differences, 'b1 b2 d -> 1', reduction='sum') + loss.backward() + common_lines_optimiser.step() + + if not (epoch % 10): + print(epoch, loss.item(), tilt_axis_grid.data.mean()) + +tilt_axis_prediction = tilt_axis_grid(interpolation_points).clone().detach() +print('final tilt axis angle:', torch.unique(tilt_axis_prediction)) + +# create the aligned stack +coarse_aligned = shift_2d(tilt_series, shifts=-coarse_shifts) + tomogram_center = dft_center(tomogram_dimensions, rfft=False, fftshifted=True) tilt_image_center = dft_center(tilt_dimensions, rfft=False, fftshifted=True) @@ -96,9 +152,21 @@ M = s2 @ r1 @ r0 @ s0 # coarse reconstruction -coarse_aligned_tilt_series = shift_2d(tilt_series, shifts=-coarse_shifts) -coarse_aligned_reconstruction = backproject_fourier( - images=coarse_aligned_tilt_series, +shifts_only_reconstruction = backproject_fourier( + images=coarse_aligned, + rotation_matrices=torch.linalg.inv(M[:, :3, :3]), + rotation_matrix_zyx=True, +) + +s0 = T(-tomogram_center) +r0 = Ry(STAGE_TILT_ANGLE_PRIORS, zyx=True) +r1 = Rz(tilt_axis_prediction, zyx=True) +s2 = T(F.pad(tilt_image_center, pad=(1, 0), value=0)) +M = s2 @ r1 @ r0 @ s0 + +# coarse reconstruction +coarse_reconstruction = backproject_fourier( + images=coarse_aligned, rotation_matrices=torch.linalg.inv(M[:, :3, :3]), rotation_matrix_zyx=True, ) @@ -107,6 +175,7 @@ viewer = napari.Viewer() viewer.add_image(tilt_series.detach().numpy(), name='experimental') -viewer.add_image(coarse_aligned_tilt_series.detach().numpy(), name='coarse aligned') -viewer.add_image(coarse_aligned_reconstruction.detach().numpy(), name='coarse reconstruction') +viewer.add_image(coarse_aligned.detach().numpy(), name='coarse aligned') +viewer.add_image(shifts_only_reconstruction.detach().numpy(), name='shifts only reconstruction') +viewer.add_image(coarse_reconstruction.detach().numpy(), name='coarse reconstruction') napari.run() diff --git a/src/libtilt/projection/__init__.py b/src/libtilt/projection/__init__.py index 98e33cb..aa4bb22 100644 --- a/src/libtilt/projection/__init__.py +++ b/src/libtilt/projection/__init__.py @@ -1,2 +1,2 @@ from .project_fourier import project_fourier -from .project_real import project_real +from .project_real import project_volume_real, project_image_real diff --git a/src/libtilt/projection/project_real.py b/src/libtilt/projection/project_real.py index 36a0940..2604167 100644 --- a/src/libtilt/projection/project_real.py +++ b/src/libtilt/projection/project_real.py @@ -8,7 +8,7 @@ from libtilt.grids.coordinate_grid import coordinate_grid -def project_real(volume: torch.Tensor, rotation_matrices: torch.Tensor) -> torch.Tensor: +def project_volume_real(volume: torch.Tensor, rotation_matrices: torch.Tensor) -> torch.Tensor: """Make 2D projections of a 3D volume in specific orientations. Projections are made by @@ -58,7 +58,7 @@ def _project_volume(rotation_matrix) -> torch.Tensor: rotated_coordinates = rotation_matrix @ volume_coordinates rotated_coordinates += padded_sidelength // 2 rotated_coordinates = einops.rearrange(rotated_coordinates, 'd h w zyx 1 -> 1 d h w zyx') - rotated_coordinates = torch.flip(rotated_coordinates, dims=(-1,)) # zyx -> zyx + rotated_coordinates = torch.flip(rotated_coordinates, dims=(-1,)) # zyx -> xyz rotated_coordinates = array_to_grid_sample( rotated_coordinates, array_shape=padded_volume_shape ) @@ -75,3 +75,71 @@ def _project_volume(rotation_matrix) -> torch.Tensor: xl, xh = padding[2, 0], -padding[2, 1] images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices] return torch.stack(images, dim=0) + + +def project_image_real(image: torch.Tensor, rotation_matrices: torch.Tensor) -> torch.Tensor: + """Make 1D projections of a 2D image in specific orientations. + + Projections are made by + + 1. generating a line of coordinates sufficient to cover + the image in any orientation. + + 2. left-multiplying `rotation matrices` and coordinate grids to + produce rotated coordinates. + + 3. sampling `image` at rotated coordinates. + + 4. summing samples along 'h' dimension of a `(h, w)` image. + + The rotation center of `image` is taken to be `torch.tensor(image.shape) // 2`. + + Parameters + ---------- + image: torch.Tensor + `(h, w)` image from which projections will be made. + rotation_matrices: torch.Tensor + `(batch, 2, 2)` array of rotation matrices + + Returns + ------- + projection_lines: torch.Tensor + `(batch, w)` array of 1D projection lines sampled from `image`. + """ + image = torch.as_tensor(image) + rotation_matrices = torch.as_tensor(rotation_matrices, dtype=torch.float) + image_shape = torch.tensor(image.shape) + ps = padded_sidelength = int(3 ** 0.5 * torch.max(image_shape)) + shape_difference = torch.abs(padded_sidelength - image_shape) + padding = torch.empty(size=(2, 2), dtype=torch.int16) + padding[:, 0] = torch.div(shape_difference, 2, rounding_mode='floor') + padding[:, 1] = shape_difference - padding[:, 0] + torch_padding = torch.flip(padding, dims=(0,)) # hw -> wh for torch.nn.functional.pad + torch_padding = einops.rearrange(torch_padding, 'wh pad -> (wh pad)') + image = F.pad(image, pad=tuple(torch_padding), mode='constant', value=0) + padded_image_shape = (ps, ps) + image_coordinates = coordinate_grid(image_shape=padded_image_shape) + image_coordinates -= padded_sidelength // 2 # (h, w, yx) + image_coordinates = torch.flip(image_coordinates, dims=(-1,)) # (h, w, yx) + image_coordinates = einops.rearrange(image_coordinates, 'h w yx -> h w yx 1') + + def _project_image(rotation_matrix) -> torch.Tensor: + rotated_coordinates = rotation_matrix @ image_coordinates + rotated_coordinates += padded_sidelength // 2 + rotated_coordinates = einops.rearrange(rotated_coordinates, 'h w yx 1 -> 1 h w yx') + rotated_coordinates = torch.flip(rotated_coordinates, dims=(-1,)) # yx -> xy + rotated_coordinates = array_to_grid_sample( + rotated_coordinates, array_shape=padded_image_shape + ) + samples = F.grid_sample( + input=einops.rearrange(image, 'h w -> 1 1 h w'), # add batch and channel dims + grid=rotated_coordinates, + mode='bilinear', + padding_mode='zeros', + align_corners=True, + ) + return einops.reduce(samples, '1 1 h w -> w', reduction='sum') + + xl, xh = padding[1, 0], -padding[1, 1] + lines = [_project_image(matrix)[xl:xh] for matrix in rotation_matrices] + return torch.stack(lines, dim=0) diff --git a/src/libtilt/projection/tests/test_project_real.py b/src/libtilt/projection/tests/test_project_real.py index c5fab53..1bfe645 100644 --- a/src/libtilt/projection/tests/test_project_real.py +++ b/src/libtilt/projection/tests/test_project_real.py @@ -1,12 +1,19 @@ import torch -from libtilt.projection.project_real import project_real +from libtilt.projection.project_real import project_image_real, project_volume_real -def test_real_space_projection(): +def test_real_space_projection_3d(): volume_shape = (2, 10, 10) volume = torch.arange(2*10*10).reshape(volume_shape).float() rotation_matrix = torch.eye(3).reshape(1, 3, 3) - projection = project_real(volume, rotation_matrices=rotation_matrix) + projection = project_volume_real(volume, rotation_matrices=rotation_matrix) assert torch.allclose(projection.squeeze(), torch.sum(volume, dim=0)) + +def test_real_space_projection_2d(): + image_shape = (8, 12) + image = torch.arange(8 * 12).reshape(image_shape).float() + rotation_matrix = torch.eye(2).reshape(1, 2, 2) + projection = project_image_real(image, rotation_matrices=rotation_matrix) + assert torch.allclose(projection.squeeze(), torch.sum(image, dim=0))