From 5d45d72263a192c8fbed1d000a8d19b023724705 Mon Sep 17 00:00:00 2001
From: Marten <58044494+McHaillet@users.noreply.github.com>
Date: Thu, 29 Feb 2024 16:13:15 +0100
Subject: [PATCH] Projection from 2D -> 1D for common lines (#67)

---
 examples/coarse_alignment.py                  | 81 +++++++++++++++++--
 src/libtilt/projection/__init__.py            |  2 +-
 src/libtilt/projection/project_real.py        | 72 ++++++++++++++++-
 .../projection/tests/test_project_real.py     | 13 ++-
 4 files changed, 156 insertions(+), 12 deletions(-)

diff --git a/examples/coarse_alignment.py b/examples/coarse_alignment.py
index c900d5c..7b0977c 100644
--- a/examples/coarse_alignment.py
+++ b/examples/coarse_alignment.py
@@ -3,6 +3,8 @@
 import torch
 import torch.nn.functional as F
 import einops
+from itertools import combinations
+from torch_cubic_spline_grids import CubicBSplineGrid1d
 
 from libtilt.backprojection import backproject_fourier
 from libtilt.fft_utils import dft_center
@@ -12,11 +14,12 @@
 from libtilt.shift.shift_image import shift_2d
 from libtilt.transformations import Ry, Rz, T
 from libtilt.correlation import correlate_2d
+from libtilt.projection import project_image_real
 
 IMAGE_FILE = 'data/tomo200528_100.st'
 IMAGE_PIXEL_SIZE = 1.724
 STAGE_TILT_ANGLE_PRIORS = torch.arange(-51, 51, 3)
-TILT_AXIS_ANGLE_PRIOR = -88.7
+TILT_AXIS_ANGLE_PRIOR = -30  # -88.7 according to the mdoc, but deliberately set wrong to check that the optimization recovers it
 ALIGNMENT_PIXEL_SIZE = 13.79 * 2
 # set 0 degree tilt as reference
 REFERENCE_TILT = STAGE_TILT_ANGLE_PRIORS.abs().argmin()
@@ -86,6 +89,59 @@
     current_shift += shift
     coarse_shifts[i + 1] = current_shift
 
+# create the aligned, masked stack for common lines; apply the mask here so it is not recomputed every epoch
+coarse_aligned_masked = shift_2d(tilt_series, shifts=-coarse_shifts) * coarse_alignment_mask
+# generate a weighting for the common line ROI by projecting the mask
+mask_weights = project_image_real(coarse_alignment_mask, torch.eye(2).reshape(1, 2, 2))
+mask_weights /= mask_weights.max()  # normalise to the range [0, 1]
+
+# optimize tilt axis angle
+grid_resolution = 1
+tilt_axis_grid = CubicBSplineGrid1d(resolution=grid_resolution, n_channels=1)
+tilt_axis_grid.data = torch.tensor([TILT_AXIS_ANGLE_PRIOR, ] * grid_resolution, dtype=torch.float32)
+interpolation_points = torch.linspace(0, 1, len(tilt_series))
+
+common_lines_optimiser = torch.optim.Adam(
+    tilt_axis_grid.parameters(),
+    lr=1,
+)
+
+for epoch in range(200):
+    # interpolate the grid
+    tilt_axis_angles = tilt_axis_grid(interpolation_points)
+
+    # for common lines, each 2D image is projected perpendicular to the tilt axis, hence the extra 90 degrees
+    R = Rz(tilt_axis_angles + 90, zyx=False)[:, :2, :2]
+
+    projections = []
+    for i in range(len(coarse_aligned_masked)):
+        projections.append(
+            project_image_real(
+                coarse_aligned_masked[i],
+                R[i:i+1]
+            ).squeeze()
+        )
+    projections = torch.stack(projections)
+    projections = projections - einops.reduce(projections, 'tilt w -> tilt 1', reduction='mean')
+    projections = projections / torch.std(projections, dim=(-1), keepdim=True)
+    # weight the lines by the projected mask
+    projections = projections * mask_weights
+
+    common_lines_optimiser.zero_grad()
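+    # pairwise squared differences between all projected lines; self-pairs are zero and each pair is counted twice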
+    squared_differences = (projections - einops.rearrange(projections, 'b d -> b 1 d')) ** 2
+    loss = einops.reduce(squared_differences, 'b1 b2 d -> 1', reduction='sum')
+    loss.backward()
+    common_lines_optimiser.step()
+
+    if not (epoch % 10):
+        print(epoch, loss.item(), tilt_axis_grid.data.mean())
+
+tilt_axis_prediction = tilt_axis_grid(interpolation_points).clone().detach()
+print('final tilt axis angle:', torch.unique(tilt_axis_prediction))
+
+# create the aligned (unmasked) stack for the reconstructions
+coarse_aligned = shift_2d(tilt_series, shifts=-coarse_shifts)
+
 tomogram_center = dft_center(tomogram_dimensions, rfft=False, fftshifted=True)
 tilt_image_center = dft_center(tilt_dimensions, rfft=False, fftshifted=True)
 
@@ -96,9 +152,21 @@
 M = s2 @ r1 @ r0 @ s0
 
-# coarse reconstruction
-coarse_aligned_tilt_series = shift_2d(tilt_series, shifts=-coarse_shifts)
-coarse_aligned_reconstruction = backproject_fourier(
-    images=coarse_aligned_tilt_series,
+# reconstruction using only the coarse shifts (tilt axis angle prior unchanged)
+shifts_only_reconstruction = backproject_fourier(
+    images=coarse_aligned,
+    rotation_matrices=torch.linalg.inv(M[:, :3, :3]),
+    rotation_matrix_zyx=True,
+)
+
+s0 = T(-tomogram_center)
+r0 = Ry(STAGE_TILT_ANGLE_PRIORS, zyx=True)
+r1 = Rz(tilt_axis_prediction, zyx=True)
+s2 = T(F.pad(tilt_image_center, pad=(1, 0), value=0))
+M = s2 @ r1 @ r0 @ s0
+
+# coarse reconstruction with the optimised tilt axis angle
+coarse_reconstruction = backproject_fourier(
+    images=coarse_aligned,
     rotation_matrices=torch.linalg.inv(M[:, :3, :3]),
     rotation_matrix_zyx=True,
 )
@@ -107,6 +175,7 @@
 
 viewer = napari.Viewer()
 viewer.add_image(tilt_series.detach().numpy(), name='experimental')
-viewer.add_image(coarse_aligned_tilt_series.detach().numpy(), name='coarse aligned')
-viewer.add_image(coarse_aligned_reconstruction.detach().numpy(), name='coarse reconstruction')
+viewer.add_image(coarse_aligned.detach().numpy(), name='coarse aligned')
+viewer.add_image(shifts_only_reconstruction.detach().numpy(), name='shifts only reconstruction')
+viewer.add_image(coarse_reconstruction.detach().numpy(), name='coarse reconstruction')
 napari.run()
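
For reference, the `(batch, 2, 2)` matrices fed to `project_image_real` in the optimisation loop above are plain in-plane rotations. Below is a minimal sketch of what `Rz(tilt_axis_angles + 90, zyx=False)[:, :2, :2]` amounts to, written with plain torch; the sign convention is assumed to match libtilt's `Rz` and the angles are illustrative only:

    import torch

    def in_plane_rotation(angles_degrees: torch.Tensor) -> torch.Tensor:
        """Build (batch, 2, 2) rotation matrices acting on xy coordinates."""
        theta = torch.deg2rad(angles_degrees)
        cos, sin = torch.cos(theta), torch.sin(theta)
        return torch.stack(
            [torch.stack([cos, -sin], dim=-1),
             torch.stack([sin, cos], dim=-1)],
            dim=-2,
        )

    # rotate by (tilt axis angle + 90) so each image is projected perpendicular
    # to the tilt axis, as in the optimisation loop above
    tilt_axis_angles = torch.tensor([-30.0, -30.0, -30.0])
    R = in_plane_rotation(tilt_axis_angles + 90)  # (3, 2, 2)
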
diff --git a/src/libtilt/projection/__init__.py b/src/libtilt/projection/__init__.py
index 98e33cb..aa4bb22 100644
--- a/src/libtilt/projection/__init__.py
+++ b/src/libtilt/projection/__init__.py
@@ -1,2 +1,2 @@
 from .project_fourier import project_fourier
-from .project_real import project_real
+from .project_real import project_volume_real, project_image_real
diff --git a/src/libtilt/projection/project_real.py b/src/libtilt/projection/project_real.py
index 36a0940..2604167 100644
--- a/src/libtilt/projection/project_real.py
+++ b/src/libtilt/projection/project_real.py
@@ -8,7 +8,7 @@
 from libtilt.grids.coordinate_grid import coordinate_grid
 
 
-def project_real(volume: torch.Tensor, rotation_matrices: torch.Tensor) -> torch.Tensor:
+def project_volume_real(volume: torch.Tensor, rotation_matrices: torch.Tensor) -> torch.Tensor:
     """Make 2D projections of a 3D volume in specific orientations.
 
     Projections are made by
@@ -58,7 +58,7 @@ def _project_volume(rotation_matrix) -> torch.Tensor:
         rotated_coordinates = rotation_matrix @ volume_coordinates
         rotated_coordinates += padded_sidelength // 2
         rotated_coordinates = einops.rearrange(rotated_coordinates, 'd h w zyx 1 -> 1 d h w zyx')
-        rotated_coordinates = torch.flip(rotated_coordinates, dims=(-1,))  # zyx -> zyx
+        rotated_coordinates = torch.flip(rotated_coordinates, dims=(-1,))  # xyz -> zyx for array_to_grid_sample
         rotated_coordinates = array_to_grid_sample(
             rotated_coordinates, array_shape=padded_volume_shape
         )
@@ -75,3 +75,71 @@ def _project_volume(rotation_matrix) -> torch.Tensor:
     xl, xh = padding[2, 0], -padding[2, 1]
     images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices]
     return torch.stack(images, dim=0)
+
+
+def project_image_real(image: torch.Tensor, rotation_matrices: torch.Tensor) -> torch.Tensor:
+    """Make 1D projections of a 2D image in specific orientations.
+
+    Projections are made by
+
+    1. generating a 2D grid of coordinates sufficient to cover
+       the image in any orientation.
+
+    2. left-multiplying the coordinate grid by `rotation_matrices` to
+       produce rotated coordinates.
+
+    3. sampling `image` at the rotated coordinates.
+
+    4. summing the samples along the 'h' dimension of the `(h, w)` grid.
+
+    The rotation center of `image` is taken to be `torch.tensor(image.shape) // 2`.
+
+    Parameters
+    ----------
+    image: torch.Tensor
+        `(h, w)` image from which projections will be made.
+    rotation_matrices: torch.Tensor
+        `(batch, 2, 2)` array of rotation matrices which rotate xy coordinates.
+
+    Returns
+    -------
+    projection_lines: torch.Tensor
+        `(batch, w)` array of 1D projection lines sampled from `image`.
+    """
+    image = torch.as_tensor(image)
+    rotation_matrices = torch.as_tensor(rotation_matrices, dtype=torch.float)
+    image_shape = torch.tensor(image.shape)
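+    # zero-pad the image to a square large enough to contain it at any rotation (sqrt(3) * longest side)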
+    ps = padded_sidelength = int(3 ** 0.5 * torch.max(image_shape))
+    shape_difference = torch.abs(padded_sidelength - image_shape)
+    padding = torch.empty(size=(2, 2), dtype=torch.int16)
+    padding[:, 0] = torch.div(shape_difference, 2, rounding_mode='floor')
+    padding[:, 1] = shape_difference - padding[:, 0]
+    torch_padding = torch.flip(padding, dims=(0,))  # hw -> wh for torch.nn.functional.pad
+    torch_padding = einops.rearrange(torch_padding, 'wh pad -> (wh pad)')
+    image = F.pad(image, pad=tuple(torch_padding), mode='constant', value=0)
+    padded_image_shape = (ps, ps)
+    image_coordinates = coordinate_grid(image_shape=padded_image_shape)
+    image_coordinates -= padded_sidelength // 2  # (h, w, yx)
+    image_coordinates = torch.flip(image_coordinates, dims=(-1,))  # yx -> xy
+    image_coordinates = einops.rearrange(image_coordinates, 'h w yx -> h w yx 1')
+
+    def _project_image(rotation_matrix) -> torch.Tensor:
+        rotated_coordinates = rotation_matrix @ image_coordinates
+        rotated_coordinates += padded_sidelength // 2
+        rotated_coordinates = einops.rearrange(rotated_coordinates, 'h w yx 1 -> 1 h w yx')
+        rotated_coordinates = torch.flip(rotated_coordinates, dims=(-1,))  # xy -> yx for array_to_grid_sample
+        rotated_coordinates = array_to_grid_sample(
+            rotated_coordinates, array_shape=padded_image_shape
+        )
+        samples = F.grid_sample(
+            input=einops.rearrange(image, 'h w -> 1 1 h w'),  # add batch and channel dims
+            grid=rotated_coordinates,
+            mode='bilinear',
+            padding_mode='zeros',
+            align_corners=True,
+        )
+        return einops.reduce(samples, '1 1 h w -> w', reduction='sum')
+
+    xl, xh = padding[1, 0], -padding[1, 1]
+    lines = [_project_image(matrix)[xl:xh] for matrix in rotation_matrices]
+    return torch.stack(lines, dim=0)
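
A minimal usage sketch of `project_image_real` as it is used for common lines in the example script; the image and angle below are made up for illustration:

    import torch

    from libtilt.projection import project_image_real
    from libtilt.transformations import Rz

    image = torch.rand(128, 128)  # stand-in for one coarsely aligned, masked tilt image
    tilt_axis_angle = torch.tensor([-30.0])

    # project perpendicular to the tilt axis, as in examples/coarse_alignment.py
    R = Rz(tilt_axis_angle + 90, zyx=False)[:, :2, :2]
    line = project_image_real(image, R)
    print(line.shape)  # (1, 128): one 1D common-line profile per rotation matrix
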
diff --git a/src/libtilt/projection/tests/test_project_real.py b/src/libtilt/projection/tests/test_project_real.py
index c5fab53..1bfe645 100644
--- a/src/libtilt/projection/tests/test_project_real.py
+++ b/src/libtilt/projection/tests/test_project_real.py
@@ -1,12 +1,19 @@
 import torch
 
-from libtilt.projection.project_real import project_real
+from libtilt.projection.project_real import project_image_real, project_volume_real
 
 
-def test_real_space_projection():
+def test_real_space_projection_3d():
     volume_shape = (2, 10, 10)
     volume = torch.arange(2*10*10).reshape(volume_shape).float()
     rotation_matrix = torch.eye(3).reshape(1, 3, 3)
-    projection = project_real(volume, rotation_matrices=rotation_matrix)
+    projection = project_volume_real(volume, rotation_matrices=rotation_matrix)
     assert torch.allclose(projection.squeeze(), torch.sum(volume, dim=0))
 
+
+def test_real_space_projection_2d():
+    image_shape = (8, 12)
+    image = torch.arange(8 * 12).reshape(image_shape).float()
+    rotation_matrix = torch.eye(2).reshape(1, 2, 2)
+    projection = project_image_real(image, rotation_matrices=rotation_matrix)
+    assert torch.allclose(projection.squeeze(), torch.sum(image, dim=0))
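
An additional property-style check that could complement the tests above; it relies only on the documented `(batch, w)` output shape and is a sketch rather than part of the patch:

    import torch

    from libtilt.projection.project_real import project_image_real
    from libtilt.transformations import Rz


    def test_real_space_projection_2d_rotated_shape():
        image_shape = (8, 12)
        image = torch.arange(8 * 12).reshape(image_shape).float()
        # a batch of in-plane rotations at arbitrary angles
        rotation_matrices = Rz(torch.tensor([0.0, 45.0, 90.0]), zyx=False)[:, :2, :2]
        projections = project_image_real(image, rotation_matrices=rotation_matrices)
        assert projections.shape == (3, image_shape[-1])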