From 53dc350d4f8c2ceff22c707f020fc22b56534b51 Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Mon, 16 Oct 2023 13:50:19 +0100 Subject: [PATCH 01/20] add device project_real --- src/libtilt/projection/project_real.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtilt/projection/project_real.py b/src/libtilt/projection/project_real.py index 36a0940..6ec7cb0 100644 --- a/src/libtilt/projection/project_real.py +++ b/src/libtilt/projection/project_real.py @@ -49,7 +49,7 @@ def project_real(volume: torch.Tensor, rotation_matrices: torch.Tensor) -> torch torch_padding = einops.rearrange(torch_padding, 'whd pad -> (whd pad)') volume = F.pad(volume, pad=tuple(torch_padding), mode='constant', value=0) padded_volume_shape = (ps, ps, ps) - volume_coordinates = coordinate_grid(image_shape=padded_volume_shape) + volume_coordinates = coordinate_grid(image_shape=padded_volume_shape, device=volume.device) volume_coordinates -= padded_sidelength // 2 # (d, h, w, zyx) volume_coordinates = torch.flip(volume_coordinates, dims=(-1,)) # (d, h, w, zyx) volume_coordinates = einops.rearrange(volume_coordinates, 'd h w zyx -> d h w zyx 1') From dfb7adb290aa2cb3040a0abcfebe9762a1146ed2 Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Mon, 16 Oct 2023 16:23:15 +0100 Subject: [PATCH 02/20] potential bug noted --- src/libtilt/shapes/soft_edge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtilt/shapes/soft_edge.py b/src/libtilt/shapes/soft_edge.py index ef7d456..e125de0 100644 --- a/src/libtilt/shapes/soft_edge.py +++ b/src/libtilt/shapes/soft_edge.py @@ -8,7 +8,7 @@ def _add_soft_edge_single_binary_image( ) -> torch.FloatTensor: if smoothing_radius == 0: return image.float() - distances = ndi.distance_transform_edt(torch.logical_not(image)) + distances = ndi.distance_transform_edt(torch.logical_not(image)) #TODO: This breaks if the input device is cuda distances = torch.as_tensor(distances, device=image.device).float() idx = torch.logical_and(distances > 0, distances <= smoothing_radius) output = torch.clone(image).float() From fc5cd73cea9594b1ab3cbe2d34c96758f0ded0c4 Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Mon, 16 Oct 2023 18:07:30 +0100 Subject: [PATCH 03/20] performance comments --- src/libtilt/grids/central_slice_grid.py | 2 +- src/libtilt/projection/project_real.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libtilt/grids/central_slice_grid.py b/src/libtilt/grids/central_slice_grid.py index c380abd..07400c9 100644 --- a/src/libtilt/grids/central_slice_grid.py +++ b/src/libtilt/grids/central_slice_grid.py @@ -43,7 +43,7 @@ def rotated_central_slice_grid( device=device, ) # (h, w, 3) if rotation_matrix_zyx is False: - grid = torch.flip(grid, dims=(-1,)) + grid = torch.flip(grid, dims=(-1,)) #TODO: This operation is slow since it is copying the full tensor rotation_matrices = einops.rearrange(rotation_matrices, '... i j -> ... 1 1 i j') grid = einops.rearrange(grid, 'h w coords -> h w coords 1') grid = rotation_matrices @ grid diff --git a/src/libtilt/projection/project_real.py b/src/libtilt/projection/project_real.py index 6ec7cb0..63dbc27 100644 --- a/src/libtilt/projection/project_real.py +++ b/src/libtilt/projection/project_real.py @@ -73,5 +73,5 @@ def _project_volume(rotation_matrix) -> torch.Tensor: yl, yh = padding[1, 0], -padding[1, 1] xl, xh = padding[2, 0], -padding[2, 1] - images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices] + images = [_project_volume(matrix)[yl:yh, xl:xh] for matrix in rotation_matrices] #TODO: This can probabaly optimized using vmap return torch.stack(images, dim=0) From 77ba205b9acf33418e6e941b4169e141518a30e0 Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Mon, 16 Oct 2023 18:08:51 +0100 Subject: [PATCH 04/20] device in fft_utils --- src/libtilt/fft_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtilt/fft_utils.py b/src/libtilt/fft_utils.py index d02bfe2..202e298 100644 --- a/src/libtilt/fft_utils.py +++ b/src/libtilt/fft_utils.py @@ -25,7 +25,7 @@ def dft_center( fft_center = torch.zeros(size=(len(image_shape),), device=device) image_shape = torch.as_tensor(image_shape).float() if rfft is True: - image_shape = torch.tensor(rfft_shape(image_shape)) + image_shape = torch.tensor(rfft_shape(image_shape), device=device) if fftshifted is True: fft_center = torch.divide(image_shape, 2, rounding_mode='floor') if rfft is True: From 94b5b9b4781cda53bc80b2f0db2dc50cbc9e686a Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Fri, 3 Nov 2023 15:01:52 +0000 Subject: [PATCH 05/20] changes for compiling --- src/libtilt/ctf/ctf_2d.py | 1 + src/libtilt/fft_utils.py | 11 ++- .../interpolation/interpolate_dft_3d.py | 3 +- src/libtilt/projection/project_fourier.py | 73 +++++++++++++------ 4 files changed, 62 insertions(+), 26 deletions(-) diff --git a/src/libtilt/ctf/ctf_2d.py b/src/libtilt/ctf/ctf_2d.py index 2bb8507..c574b3f 100644 --- a/src/libtilt/ctf/ctf_2d.py +++ b/src/libtilt/ctf/ctf_2d.py @@ -56,6 +56,7 @@ def calculate_ctf( Whether to apply fftshift on the resulting CTF images. """ # to torch.Tensor and unit conversions + assert bool(rfft) + bool(fftshift) <= 1, "Error, only one of `rfft` and `fftshift` may be `True`." defocus = torch.atleast_1d(torch.as_tensor(defocus, dtype=torch.float)) defocus *= 1e4 # micrometers -> angstroms astigmatism = torch.atleast_1d(torch.as_tensor(astigmatism, dtype=torch.float)) diff --git a/src/libtilt/fft_utils.py b/src/libtilt/fft_utils.py index 202e298..1c4d8e1 100644 --- a/src/libtilt/fft_utils.py +++ b/src/libtilt/fft_utils.py @@ -22,10 +22,11 @@ def dft_center( device: torch.device | None = None, ) -> torch.LongTensor: """Return the position of the DFT center for a given input shape.""" + _rfft_shape = rfft_shape(image_shape) fft_center = torch.zeros(size=(len(image_shape),), device=device) image_shape = torch.as_tensor(image_shape).float() if rfft is True: - image_shape = torch.tensor(rfft_shape(image_shape), device=device) + image_shape = torch.tensor(_rfft_shape, device=device) if fftshifted is True: fft_center = torch.divide(image_shape, 2, rounding_mode='floor') if rfft is True: @@ -438,11 +439,13 @@ def fftfreq_to_dft_coordinates( coordinates: torch.Tensor `(..., d)` array of coordinates into a fftshifted DFT. """ + _image_shape = image_shape image_shape = torch.as_tensor( - image_shape, device=frequencies.device, dtype=frequencies.dtype + _image_shape, device=frequencies.device, dtype=frequencies.dtype ) + _rfft_shape = rfft_shape(_image_shape) _rfft_shape = torch.as_tensor( - rfft_shape(image_shape), device=frequencies.device, dtype=frequencies.dtype + _rfft_shape, device=frequencies.device, dtype=frequencies.dtype ) coordinates = torch.empty_like(frequencies) coordinates[..., :-1] = frequencies[..., :-1] * image_shape[:-1] @@ -450,5 +453,5 @@ def fftfreq_to_dft_coordinates( coordinates[..., -1] = frequencies[..., -1] * 2 * (_rfft_shape[-1] - 1) else: coordinates[..., -1] = frequencies[..., -1] * image_shape[-1] - dc = dft_center(image_shape, rfft=rfft, fftshifted=True, device=frequencies.device) + dc = dft_center(_image_shape, rfft=rfft, fftshifted=True, device=frequencies.device) return coordinates + dc diff --git a/src/libtilt/interpolation/interpolate_dft_3d.py b/src/libtilt/interpolation/interpolate_dft_3d.py index e754d13..e4854c2 100644 --- a/src/libtilt/interpolation/interpolate_dft_3d.py +++ b/src/libtilt/interpolation/interpolate_dft_3d.py @@ -49,7 +49,8 @@ def sample_dft_3d( samples = torch.view_as_complex(samples.contiguous()) # (b, ) # pack data back up and return - [samples] = einops.unpack(samples, pattern='*', packed_shapes=ps) + # [samples] = einops.unpack(samples, pattern='*', packed_shapes=ps) + samples = samples.reshape(*ps) #Ask Alister if this will work in any situation return samples # (...) diff --git a/src/libtilt/projection/project_fourier.py b/src/libtilt/projection/project_fourier.py index be14b3c..322ce49 100644 --- a/src/libtilt/projection/project_fourier.py +++ b/src/libtilt/projection/project_fourier.py @@ -1,3 +1,5 @@ +from typing import Tuple + import torch import torch.nn.functional as F import einops @@ -33,30 +35,12 @@ def project_fourier( projections: torch.Tensor `(..., d, d)` array of projection images. """ - # padding - if pad is True: - pad_length = volume.shape[-1] // 2 - volume = F.pad(volume, pad=[pad_length] * 6, mode='constant', value=0) - - # premultiply by sinc2 - grid = fftfreq_grid( - image_shape=volume.shape, - rfft=False, - fftshift=True, - norm=True, - device=volume.device - ) - volume = volume * torch.sinc(grid) ** 2 - - # calculate DFT - dft = torch.fft.fftshift(volume, dim=(-3, -2, -1)) # volume center to array origin - dft = torch.fft.rfftn(dft, dim=(-3, -2, -1)) - dft = torch.fft.fftshift(dft, dim=(-3, -2,)) # actual fftshift of rfft + dft, vol_shape, pad_length = compute_vol_dtf(volume, pad) # make projections by taking central slices projections = extract_central_slices_rfft( dft=dft, - image_shape=volume.shape, + image_shape=vol_shape, rotation_matrices=rotation_matrices, rotation_matrix_zyx=rotation_matrix_zyx ) # (..., h, w) rfft @@ -92,7 +76,8 @@ def extract_central_slices_rfft( # flip coordinates in redundant half transform conjugate_mask = grid[..., 2] < 0 - conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3') + # conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3') + conjugate_mask.unsqueeze(-1).repeat(1, 1, 1, 3) grid[conjugate_mask] *= -1 conjugate_mask = conjugate_mask[..., 0] # un-repeat @@ -107,3 +92,49 @@ def extract_central_slices_rfft( # take complex conjugate of values from redundant half transform projections[conjugate_mask] = torch.conj(projections[conjugate_mask]) return projections + +def compute_vol_dtf( #TODO: Is this the best place to have this? + volume: torch.Tensor, + pad: bool = True, + pad_length: int | None = None +) -> Tuple[torch.Tensor, Tuple[int,int,int], int]: + """Project a cubic volume by sampling a central slice through its DFT. + + Parameters + ---------- + volume: torch.Tensor + `(d, d, d)` volume. + pad: bool + Whether to pad the volume with zeros to increase sampling in the DFT. + pad_length: bool + The lenght used for padding. If None, volume.shape[-1] // 2 is used instead + + Returns + ------- + projections: Tuple[torch.Tensor, torch.Tensor, int] + `(..., d, d, d)` dft of the volume. fftshifted rfft + Tuple[int,int,int] the shape of the volume after padding + int with the padding length + """ + # padding + if pad is True: + if pad_length is None: + pad_length = volume.shape[-1] // 2 + volume = F.pad(volume, pad=[pad_length] * 6, mode='constant', value=0) + + # premultiply by sinc2 + grid = fftfreq_grid( + image_shape=volume.shape, + rfft=False, + fftshift=True, + norm=True, + device=volume.device + ) + volume = volume * torch.sinc(grid) ** 2 + + # calculate DFT + dft = torch.fft.fftshift(volume, dim=(-3, -2, -1)) # volume center to array origin + dft = torch.fft.rfftn(dft, dim=(-3, -2, -1)) + dft = torch.fft.fftshift(dft, dim=(-3, -2,)) # actual fftshift of rfft + + return dft, volume.shape, pad_length \ No newline at end of file From b74ed3a0dbe216d3f7094127d293ed82f61a7435 Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Fri, 12 Jan 2024 19:42:33 +0000 Subject: [PATCH 06/20] add device --- src/libtilt/ctf/ctf_2d.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/src/libtilt/ctf/ctf_2d.py b/src/libtilt/ctf/ctf_2d.py index c574b3f..51f6b89 100644 --- a/src/libtilt/ctf/ctf_2d.py +++ b/src/libtilt/ctf/ctf_2d.py @@ -22,6 +22,7 @@ def calculate_ctf( image_shape: Tuple[int, int], rfft: bool, fftshift: bool, + device ): """ @@ -57,20 +58,20 @@ def calculate_ctf( """ # to torch.Tensor and unit conversions assert bool(rfft) + bool(fftshift) <= 1, "Error, only one of `rfft` and `fftshift` may be `True`." - defocus = torch.atleast_1d(torch.as_tensor(defocus, dtype=torch.float)) + defocus = torch.atleast_1d(torch.as_tensor(defocus, dtype=torch.float, device=device)) defocus *= 1e4 # micrometers -> angstroms - astigmatism = torch.atleast_1d(torch.as_tensor(astigmatism, dtype=torch.float)) + astigmatism = torch.atleast_1d(torch.as_tensor(astigmatism, dtype=torch.float, device=device)) astigmatism *= 1e4 # micrometers -> angstroms - astigmatism_angle = torch.atleast_1d(torch.as_tensor(astigmatism_angle, dtype=torch.float)) + astigmatism_angle = torch.atleast_1d(torch.as_tensor(astigmatism_angle, dtype=torch.float, device=device)) astigmatism_angle *= (C.pi / 180) # degrees -> radians - pixel_size = torch.atleast_1d(torch.as_tensor(pixel_size)) - voltage = torch.atleast_1d(torch.as_tensor(voltage, dtype=torch.float)) + pixel_size = torch.atleast_1d(torch.as_tensor(pixel_size, device=device)) + voltage = torch.atleast_1d(torch.as_tensor(voltage, dtype=torch.float, device=device)) voltage *= 1e3 # kV -> V spherical_aberration = torch.atleast_1d( - torch.as_tensor(spherical_aberration, dtype=torch.float) + torch.as_tensor(spherical_aberration, dtype=torch.float, device=device) ) spherical_aberration *= 1e7 # mm -> angstroms - image_shape = torch.as_tensor(image_shape) + image_shape = torch.as_tensor(image_shape, device=device) # derived quantities used in CTF calculation defocus_u = defocus + astigmatism @@ -80,10 +81,10 @@ def calculate_ctf( k2 = C.pi / 2 * spherical_aberration * _lambda ** 3 k3 = torch.tensor(np.deg2rad(phase_shift)) k4 = -b_factor / 4 - k5 = np.arctan(amplitude_contrast / np.sqrt(1 - amplitude_contrast ** 2)) + k5 = torch.arctan(amplitude_contrast / torch.sqrt(1 - amplitude_contrast ** 2)) # construct 2D frequency grids and rescale cycles / px -> cycles / Å - fftfreq_grid = _construct_fftfreq_grid_2d(image_shape=image_shape, rfft=rfft) # (h, w, 2) + fftfreq_grid = _construct_fftfreq_grid_2d(image_shape=image_shape, rfft=rfft, device=device) # (h, w, 2) fftfreq_grid = fftfreq_grid / einops.rearrange(pixel_size, 'b -> b 1 1 1') fftfreq_grid_squared = fftfreq_grid ** 2 From ef5f04aa65153172997962e8db5e2da6ebc5a188 Mon Sep 17 00:00:00 2001 From: alisterburt Date: Mon, 15 Jan 2024 04:13:27 -0800 Subject: [PATCH 07/20] Update src/libtilt/ctf/ctf_2d.py --- src/libtilt/ctf/ctf_2d.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/libtilt/ctf/ctf_2d.py b/src/libtilt/ctf/ctf_2d.py index 51f6b89..661cbe9 100644 --- a/src/libtilt/ctf/ctf_2d.py +++ b/src/libtilt/ctf/ctf_2d.py @@ -57,7 +57,8 @@ def calculate_ctf( Whether to apply fftshift on the resulting CTF images. """ # to torch.Tensor and unit conversions - assert bool(rfft) + bool(fftshift) <= 1, "Error, only one of `rfft` and `fftshift` may be `True`." + if bool(rfft) + bool(fftshift) > 1: + raise ValueError("Only one of `rfft` and `fftshift` may be `True`.") defocus = torch.atleast_1d(torch.as_tensor(defocus, dtype=torch.float, device=device)) defocus *= 1e4 # micrometers -> angstroms astigmatism = torch.atleast_1d(torch.as_tensor(astigmatism, dtype=torch.float, device=device)) From 0a1a8d0432626fbb98725955961f89707ee915ab Mon Sep 17 00:00:00 2001 From: alisterburt Date: Mon, 15 Jan 2024 04:13:32 -0800 Subject: [PATCH 08/20] Update src/libtilt/ctf/ctf_2d.py --- src/libtilt/ctf/ctf_2d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtilt/ctf/ctf_2d.py b/src/libtilt/ctf/ctf_2d.py index 661cbe9..72fe015 100644 --- a/src/libtilt/ctf/ctf_2d.py +++ b/src/libtilt/ctf/ctf_2d.py @@ -22,7 +22,7 @@ def calculate_ctf( image_shape: Tuple[int, int], rfft: bool, fftshift: bool, - device + device: torch.device | None = None ): """ From 3a0feed334fba0a3696556d6872fa5d731648476 Mon Sep 17 00:00:00 2001 From: alisterburt Date: Mon, 15 Jan 2024 04:16:53 -0800 Subject: [PATCH 09/20] Update src/libtilt/grids/central_slice_grid.py --- src/libtilt/grids/central_slice_grid.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtilt/grids/central_slice_grid.py b/src/libtilt/grids/central_slice_grid.py index 07400c9..c380abd 100644 --- a/src/libtilt/grids/central_slice_grid.py +++ b/src/libtilt/grids/central_slice_grid.py @@ -43,7 +43,7 @@ def rotated_central_slice_grid( device=device, ) # (h, w, 3) if rotation_matrix_zyx is False: - grid = torch.flip(grid, dims=(-1,)) #TODO: This operation is slow since it is copying the full tensor + grid = torch.flip(grid, dims=(-1,)) rotation_matrices = einops.rearrange(rotation_matrices, '... i j -> ... 1 1 i j') grid = einops.rearrange(grid, 'h w coords -> h w coords 1') grid = rotation_matrices @ grid From a9f6ad054a3ce32af7ae2b8463e46885b8b87dc6 Mon Sep 17 00:00:00 2001 From: alisterburt Date: Mon, 15 Jan 2024 04:17:29 -0800 Subject: [PATCH 10/20] Update src/libtilt/interpolation/interpolate_dft_3d.py --- src/libtilt/interpolation/interpolate_dft_3d.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtilt/interpolation/interpolate_dft_3d.py b/src/libtilt/interpolation/interpolate_dft_3d.py index e4854c2..ec0a588 100644 --- a/src/libtilt/interpolation/interpolate_dft_3d.py +++ b/src/libtilt/interpolation/interpolate_dft_3d.py @@ -50,7 +50,7 @@ def sample_dft_3d( # pack data back up and return # [samples] = einops.unpack(samples, pattern='*', packed_shapes=ps) - samples = samples.reshape(*ps) #Ask Alister if this will work in any situation + samples = samples.reshape(*ps) # replaces commented line above, for performance return samples # (...) From 32e35f425692b6a11313cc4cb3d59b5faf567705 Mon Sep 17 00:00:00 2001 From: alisterburt Date: Mon, 15 Jan 2024 04:24:25 -0800 Subject: [PATCH 11/20] Update src/libtilt/projection/project_fourier.py --- src/libtilt/projection/project_fourier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtilt/projection/project_fourier.py b/src/libtilt/projection/project_fourier.py index 322ce49..5b20b74 100644 --- a/src/libtilt/projection/project_fourier.py +++ b/src/libtilt/projection/project_fourier.py @@ -107,7 +107,7 @@ def compute_vol_dtf( #TODO: Is this the best place to have this? pad: bool Whether to pad the volume with zeros to increase sampling in the DFT. pad_length: bool - The lenght used for padding. If None, volume.shape[-1] // 2 is used instead + The length used for padding each side of each dimension. If pad_length=None, and pad=True then volume.shape[-1] // 2 is used instead Returns ------- From 71fe0656c0fcd4f45b210b54c532f7718b64ad40 Mon Sep 17 00:00:00 2001 From: alisterburt Date: Mon, 15 Jan 2024 04:24:35 -0800 Subject: [PATCH 12/20] Update src/libtilt/projection/project_fourier.py --- src/libtilt/projection/project_fourier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtilt/projection/project_fourier.py b/src/libtilt/projection/project_fourier.py index 5b20b74..c535479 100644 --- a/src/libtilt/projection/project_fourier.py +++ b/src/libtilt/projection/project_fourier.py @@ -137,4 +137,4 @@ def compute_vol_dtf( #TODO: Is this the best place to have this? dft = torch.fft.rfftn(dft, dim=(-3, -2, -1)) dft = torch.fft.fftshift(dft, dim=(-3, -2,)) # actual fftshift of rfft - return dft, volume.shape, pad_length \ No newline at end of file + return dft, volume.shape, pad_length From 18cf67a8e2848df76a49ee886e203a03b735e3e2 Mon Sep 17 00:00:00 2001 From: alisterburt Date: Mon, 15 Jan 2024 04:24:53 -0800 Subject: [PATCH 13/20] Update src/libtilt/projection/project_fourier.py --- src/libtilt/projection/project_fourier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtilt/projection/project_fourier.py b/src/libtilt/projection/project_fourier.py index c535479..1ac83a3 100644 --- a/src/libtilt/projection/project_fourier.py +++ b/src/libtilt/projection/project_fourier.py @@ -106,7 +106,7 @@ def compute_vol_dtf( #TODO: Is this the best place to have this? `(d, d, d)` volume. pad: bool Whether to pad the volume with zeros to increase sampling in the DFT. - pad_length: bool + pad_length: int | None The length used for padding each side of each dimension. If pad_length=None, and pad=True then volume.shape[-1] // 2 is used instead Returns From 64a5b20dccaea6e4aa1418bace6cb16e13480ece Mon Sep 17 00:00:00 2001 From: alisterburt Date: Mon, 15 Jan 2024 04:26:24 -0800 Subject: [PATCH 14/20] Update src/libtilt/shapes/soft_edge.py --- src/libtilt/shapes/soft_edge.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtilt/shapes/soft_edge.py b/src/libtilt/shapes/soft_edge.py index e125de0..ef7d456 100644 --- a/src/libtilt/shapes/soft_edge.py +++ b/src/libtilt/shapes/soft_edge.py @@ -8,7 +8,7 @@ def _add_soft_edge_single_binary_image( ) -> torch.FloatTensor: if smoothing_radius == 0: return image.float() - distances = ndi.distance_transform_edt(torch.logical_not(image)) #TODO: This breaks if the input device is cuda + distances = ndi.distance_transform_edt(torch.logical_not(image)) distances = torch.as_tensor(distances, device=image.device).float() idx = torch.logical_and(distances > 0, distances <= smoothing_radius) output = torch.clone(image).float() From 71137ea13f05326d91a6203538f95e13772f5017 Mon Sep 17 00:00:00 2001 From: alisterburt Date: Mon, 15 Jan 2024 04:27:00 -0800 Subject: [PATCH 15/20] Update src/libtilt/projection/project_fourier.py --- src/libtilt/projection/project_fourier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtilt/projection/project_fourier.py b/src/libtilt/projection/project_fourier.py index 1ac83a3..b357663 100644 --- a/src/libtilt/projection/project_fourier.py +++ b/src/libtilt/projection/project_fourier.py @@ -93,7 +93,7 @@ def extract_central_slices_rfft( projections[conjugate_mask] = torch.conj(projections[conjugate_mask]) return projections -def compute_vol_dtf( #TODO: Is this the best place to have this? +def _compute_dtf( volume: torch.Tensor, pad: bool = True, pad_length: int | None = None From 261b7f8c5ba09eb0e21563ee7c8c172bf313ecd5 Mon Sep 17 00:00:00 2001 From: alisterburt Date: Mon, 15 Jan 2024 04:27:05 -0800 Subject: [PATCH 16/20] Update src/libtilt/projection/project_fourier.py --- src/libtilt/projection/project_fourier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/libtilt/projection/project_fourier.py b/src/libtilt/projection/project_fourier.py index b357663..b4bad18 100644 --- a/src/libtilt/projection/project_fourier.py +++ b/src/libtilt/projection/project_fourier.py @@ -35,7 +35,7 @@ def project_fourier( projections: torch.Tensor `(..., d, d)` array of projection images. """ - dft, vol_shape, pad_length = compute_vol_dtf(volume, pad) + dft, vol_shape, pad_length = _compute_dft(volume, pad) # make projections by taking central slices projections = extract_central_slices_rfft( From 8aa9d8075b9da521fc8ce694f6b21d7740e1cf7c Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Fri, 19 Jan 2024 10:57:40 +0000 Subject: [PATCH 17/20] ignore .idea --- .gitignore | 1 + 1 file changed, 1 insertion(+) diff --git a/.gitignore b/.gitignore index 32720c5..68638a6 100644 --- a/.gitignore +++ b/.gitignore @@ -103,6 +103,7 @@ ENV/ # IDE settings .vscode/ +.idea/ #For PyCharm libtilt/_version.py src/libtilt/_version.py From 8cb66b6a796781fc478eed1116b171a633633d9b Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Fri, 19 Jan 2024 11:32:02 +0000 Subject: [PATCH 18/20] fixing expand --- src/libtilt/projection/project_fourier.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libtilt/projection/project_fourier.py b/src/libtilt/projection/project_fourier.py index 322ce49..5b64cc3 100644 --- a/src/libtilt/projection/project_fourier.py +++ b/src/libtilt/projection/project_fourier.py @@ -76,8 +76,8 @@ def extract_central_slices_rfft( # flip coordinates in redundant half transform conjugate_mask = grid[..., 2] < 0 - # conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3') - conjugate_mask.unsqueeze(-1).repeat(1, 1, 1, 3) + # conjugate_mask = einops.repeat(conjugate_mask, '... -> ... 3') #This operation does not compile + conjugate_mask = conjugate_mask.unsqueeze(-1).expand(*[-1] * len(conjugate_mask.shape), 3) #This does grid[conjugate_mask] *= -1 conjugate_mask = conjugate_mask[..., 0] # un-repeat From d6c308ad17d7e136ec2caf68c2b2bda86af3a23e Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Fri, 19 Jan 2024 11:38:09 +0000 Subject: [PATCH 19/20] dft typo --- src/libtilt/projection/project_fourier.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/libtilt/projection/project_fourier.py b/src/libtilt/projection/project_fourier.py index 456910d..b72a4bc 100644 --- a/src/libtilt/projection/project_fourier.py +++ b/src/libtilt/projection/project_fourier.py @@ -93,12 +93,12 @@ def extract_central_slices_rfft( projections[conjugate_mask] = torch.conj(projections[conjugate_mask]) return projections -def _compute_dtf( +def _compute_dft( volume: torch.Tensor, pad: bool = True, pad_length: int | None = None ) -> Tuple[torch.Tensor, Tuple[int,int,int], int]: - """Project a cubic volume by sampling a central slice through its DFT. + """Computes the DFT of a volume. Intended to be used as a preprocessing before using extract_central_slices_rfft. Parameters ---------- From e23b48be34ac36ae8d16c7ce800ce232bf99c688 Mon Sep 17 00:00:00 2001 From: Ruben Sanchez Garcia Date: Fri, 19 Jan 2024 11:38:49 +0000 Subject: [PATCH 20/20] gitignore --- .gitignore | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index 68638a6..461cc11 100644 --- a/.gitignore +++ b/.gitignore @@ -103,7 +103,7 @@ ENV/ # IDE settings .vscode/ -.idea/ #For PyCharm +.idea/ libtilt/_version.py src/libtilt/_version.py