From e93ba8310d1ee7e4244ce9595842a2aeee54862a Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 14:30:01 +0200 Subject: [PATCH 01/22] add device test decorator --- src/libtilt/alignment/tests/test_find_shift.py | 2 ++ src/libtilt/pytest_utils.py | 10 ++++++++++ 2 files changed, 12 insertions(+) create mode 100644 src/libtilt/pytest_utils.py diff --git a/src/libtilt/alignment/tests/test_find_shift.py b/src/libtilt/alignment/tests/test_find_shift.py index dce7484..91ca3e7 100644 --- a/src/libtilt/alignment/tests/test_find_shift.py +++ b/src/libtilt/alignment/tests/test_find_shift.py @@ -2,8 +2,10 @@ import pytest from libtilt.alignment import find_image_shift +from libtilt.pytest_utils import device_test +@device_test def test_find_image_shift(): a = torch.zeros((4, 4)) a[1, 1] = 1 diff --git a/src/libtilt/pytest_utils.py b/src/libtilt/pytest_utils.py new file mode 100644 index 0000000..9c6bca1 --- /dev/null +++ b/src/libtilt/pytest_utils.py @@ -0,0 +1,10 @@ +import torch + + +def device_test(test_func): + def wrapper(): + with torch.device('cpu'): + test_func() + with torch.device('meta'): + test_func() + return wrapper From baefa16f38aec893a7afe65e3a7afc94989ce6b3 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 15:18:40 +0200 Subject: [PATCH 02/22] just switch to use cpu and cuda (instead of meta) as it is most robust --- src/libtilt/pytest_utils.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/libtilt/pytest_utils.py b/src/libtilt/pytest_utils.py index 9c6bca1..eae892f 100644 --- a/src/libtilt/pytest_utils.py +++ b/src/libtilt/pytest_utils.py @@ -2,9 +2,9 @@ def device_test(test_func): - def wrapper(): + def decorator(): with torch.device('cpu'): test_func() - with torch.device('meta'): + with torch.device('cuda'): test_func() - return wrapper + return decorator From 1fa70618f54ef6316071621a0bb5a4643ec2a039 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 15:19:47 +0200 Subject: [PATCH 03/22] add test and fixes --- src/libtilt/alignment/find_shift.py | 5 ++--- src/libtilt/correlation/tests/test_correlate.py | 6 ++++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/src/libtilt/alignment/find_shift.py b/src/libtilt/alignment/find_shift.py index bb2089d..d9af489 100644 --- a/src/libtilt/alignment/find_shift.py +++ b/src/libtilt/alignment/find_shift.py @@ -1,5 +1,4 @@ import torch -import numpy as np import torch.nn.functional as F import einops @@ -47,7 +46,7 @@ def find_image_shift( normalize=True ) maximum_idx = torch.tensor( # explicitly put tensor on CPU in case input is on GPU - np.unravel_index(correlation.argmax().cpu(), shape=image_a.shape), + torch.unravel_index(correlation.argmax().cpu(), shape=image_a.shape), device=image_a.device ) shift = maximum_idx - center @@ -76,7 +75,7 @@ def find_image_shift( upsampled.shape, rfft=False, fftshifted=True, device=image_a.device ) upsampled_shift = torch.tensor( - np.unravel_index(upsampled.argmax().cpu(), shape=upsampled.shape), + torch.unravel_index(upsampled.argmax().cpu(), shape=upsampled.shape), device=image_a.device ) - upsampled_center full_shift = shift + upsampled_shift / upsampling_factor diff --git a/src/libtilt/correlation/tests/test_correlate.py b/src/libtilt/correlation/tests/test_correlate.py index 867bb6b..b45993e 100644 --- a/src/libtilt/correlation/tests/test_correlate.py +++ b/src/libtilt/correlation/tests/test_correlate.py @@ -1,17 +1,18 @@ import torch -import numpy as np from libtilt.correlation import correlate_2d, correlate_dft_2d from libtilt.fft_utils import fftshift_2d +from libtilt.pytest_utils import device_test +@device_test def test_correlate_2d(): a = torch.zeros((10, 10)) a[5, 5] = 1 b = torch.zeros((10, 10)) b[6, 6] = 1 cross_correlation = correlate_2d(a, b, normalize=True) - peak_position = np.unravel_index( + peak_position = torch.unravel_index( indices=torch.argmax(cross_correlation), shape=cross_correlation.shape ) shift = torch.as_tensor(peak_position) - torch.tensor([5, 5]) @@ -20,6 +21,7 @@ def test_correlate_2d(): assert torch.allclose(cross_correlation[peak_position], torch.tensor([1.])) +@device_test def test_correlate_dft_2d(): a = torch.zeros((10, 10)) a[5, 5] = 1 From 8159604f71f560dbe2917f6ba5f86ea82f052b91 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 15:22:59 +0200 Subject: [PATCH 04/22] add device tests to _tests --- src/libtilt/_tests/test_coordinate_tils.py | 7 +++++++ src/libtilt/_tests/test_fft_utils.py | 16 ++++++++++++++++ src/libtilt/_tests/test_transformations.py | 5 +++++ 3 files changed, 28 insertions(+) diff --git a/src/libtilt/_tests/test_coordinate_tils.py b/src/libtilt/_tests/test_coordinate_tils.py index 9ef74d1..5936c9d 100644 --- a/src/libtilt/_tests/test_coordinate_tils.py +++ b/src/libtilt/_tests/test_coordinate_tils.py @@ -9,8 +9,10 @@ homogenise_coordinates, ) from libtilt.grids.coordinate_grid import coordinate_grid +from libtilt.pytest_utils import device_test +@device_test def test_array_coordinates_to_grid_sample_coordinates_nd(): array_shape = z, y, x = (4, 8, 12) array_coordinates = einops.rearrange(torch.tensor(np.indices(array_shape)), @@ -27,6 +29,7 @@ def test_array_coordinates_to_grid_sample_coordinates_nd(): assert torch.allclose(grid_sample_coordinates[:, 0, 0, 2], expected_z) +@device_test def test_grid_sample_coordinates_to_array_coordinates_nd(): array_shape = (4, 8, 12) expected_array_coordinates = einops.rearrange( @@ -41,6 +44,7 @@ def test_grid_sample_coordinates_to_array_coordinates_nd(): assert torch.allclose(array_coordinates, expected_array_coordinates) +@device_test def test_add_implied_coordinate_from_dimension(): batch_of_stacked_2d_coords = torch.zeros(size=(1, 5, 2)) # (b, stack, 2) result = add_positional_coordinate(batch_of_stacked_2d_coords, dim=1) @@ -49,6 +53,7 @@ def test_add_implied_coordinate_from_dimension(): assert torch.allclose(result, expected) +@device_test def test_add_implied_coordinate_from_dimension_prepend(): batch_of_stacked_2d_coords = torch.zeros(size=(1, 5, 2)) # (b, stack, 2) result = add_positional_coordinate(batch_of_stacked_2d_coords, dim=1, @@ -58,6 +63,7 @@ def test_add_implied_coordinate_from_dimension_prepend(): assert torch.allclose(result, expected) +@device_test def test_get_grid_coordinates(): coords = coordinate_grid(image_shape=(3, 2)) assert coords.shape == (3, 2, 2) @@ -74,6 +80,7 @@ def test_get_grid_coordinates(): assert torch.allclose(coords, expected) +@device_test def test_homogenise_coordinates(): coords = torch.rand(size=(2, 3)) homogenised = homogenise_coordinates(coords) diff --git a/src/libtilt/_tests/test_fft_utils.py b/src/libtilt/_tests/test_fft_utils.py index 5daa3d2..3761c63 100644 --- a/src/libtilt/_tests/test_fft_utils.py +++ b/src/libtilt/_tests/test_fft_utils.py @@ -15,8 +15,10 @@ fftfreq_to_dft_coordinates, ) from libtilt.grids.fftfreq_grid import _construct_fftfreq_grid_2d +from libtilt.pytest_utils import device_test +@device_test def test_rfft_shape_from_signal_shape(): # even signal_shape = (2, 4, 8, 16) @@ -31,6 +33,7 @@ def test_rfft_shape_from_signal_shape(): assert rfft.shape == rfft_shape(signal_shape) +@device_test def test_construct_fftfreq_grid_2d(): image_shape = (10, 30) # no rfft @@ -46,6 +49,7 @@ def test_construct_fftfreq_grid_2d(): assert torch.allclose(grid[0, :, 1], torch.fft.rfftfreq(30)) +@device_test def test_rfft_to_symmetrised_dft_2d(): image = torch.rand((10, 10)) fft = torch.fft.fftshift(torch.fft.fftn(image, dim=(-2, -1)), dim=(-2, -1)) @@ -54,6 +58,7 @@ def test_rfft_to_symmetrised_dft_2d(): assert torch.allclose(fft, symmetrised_dft[:-1, :-1], atol=1e-7) +@device_test def test_rfft_to_symmetrised_dft_2d_batched(): image = torch.rand((2, 10, 10)) # (b, h, w) fft = torch.fft.fftshift(torch.fft.fftn(image, dim=(-2, -1)), dim=(-2, -1)) @@ -62,6 +67,7 @@ def test_rfft_to_symmetrised_dft_2d_batched(): assert torch.allclose(fft, symmetrised_dft[..., :-1, :-1], atol=1e-7) +@device_test def test_rfft_to_symmetrised_dft_3d(): image = torch.rand((10, 10, 10)) fft_dims = (-3, -2, -1) @@ -72,6 +78,7 @@ def test_rfft_to_symmetrised_dft_3d(): assert torch.allclose(fft, symmetrised_dft[:-1, :-1, :-1], atol=1e-5) +@device_test @pytest.mark.parametrize( "inplace", [(True,), (False,)] @@ -86,6 +93,7 @@ def test_symmetrised_dft_to_dft_2d(inplace: bool): assert torch.allclose(desymmetrised_dft, fft, atol=1e-6) +@device_test @pytest.mark.parametrize( "inplace", [(True,), (False,)] @@ -100,6 +108,7 @@ def test_symmetrised_dft_to_dft_2d_batched(inplace: bool): assert torch.allclose(desymmetrised_dft, fft, atol=1e-6) +@device_test @pytest.mark.parametrize( "inplace", [(True,), (False,)] @@ -113,6 +122,7 @@ def test_symmetrised_dft_to_rfft_2d(inplace: bool): assert torch.allclose(desymmetrised_rfft, rfft, atol=1e-6) +@device_test @pytest.mark.parametrize( "inplace", [(True,), (False,)] @@ -127,6 +137,7 @@ def test_symmetrised_dft_to_dft_2d_batched(inplace: bool): assert torch.allclose(desymmetrised_dft, fft, atol=1e-6) +@device_test @pytest.mark.parametrize( "inplace", [(True,), (False,)] @@ -143,6 +154,7 @@ def test_symmetrised_dft_to_dft_3d(inplace: bool): assert torch.allclose(desymmetrised_dft, fft, atol=1e-5) +@device_test @pytest.mark.parametrize( "inplace", [(True,), (False,)] @@ -158,6 +170,7 @@ def test_symmetrised_dft_to_dft_3d_batched(inplace: bool): assert torch.allclose(desymmetrised_dft, fft, atol=1e-5) +@device_test @pytest.mark.parametrize( "fftshifted, rfft, input, expected", [ @@ -176,6 +189,7 @@ def test_fft_center(fftshifted, rfft, input, expected): assert torch.allclose(result, expected.long()) +@device_test def test_fftfreq_to_spatial_frequency(): fftfreq = torch.fft.fftfreq(10) k = fftfreq_to_spatial_frequency(fftfreq, spacing=2) @@ -187,6 +201,7 @@ def test_fftfreq_to_spatial_frequency(): assert torch.allclose(k, expected) +@device_test def test_spatial_frequency_to_fftfreq(): k = torch.fft.fftfreq(10, d=2) fftfreq = spatial_frequency_to_fftfreq(k, spacing=2) @@ -199,6 +214,7 @@ def test_spatial_frequency_to_fftfreq(): assert torch.allclose(fftfreq, expected) +@device_test def test_fftfreq_to_dft_coords(): from libtilt.grids import fftfreq_grid, coordinate_grid diff --git a/src/libtilt/_tests/test_transformations.py b/src/libtilt/_tests/test_transformations.py index 6e7b244..a3c6c2b 100644 --- a/src/libtilt/_tests/test_transformations.py +++ b/src/libtilt/_tests/test_transformations.py @@ -1,8 +1,10 @@ import torch from libtilt.transformations import Rx, Ry, Rz, T +from libtilt.pytest_utils import device_test +@device_test def test_rotation_around_x(): """Rotation of y around x should become z.""" R = Rx(90) @@ -16,6 +18,7 @@ def test_rotation_around_x(): assert torch.allclose(R @ v, expected, atol=1e-6) +@device_test def test_rotation_around_y(): """Rotation of z around y should be x""" R = Ry(90) @@ -29,6 +32,7 @@ def test_rotation_around_y(): assert torch.allclose(R @ v, expected, atol=1e-6) +@device_test def test_rotation_around_z(): """Rotation of x around z should give y.""" R = Rz(90) @@ -42,6 +46,7 @@ def test_rotation_around_z(): assert torch.allclose(R @ v, expected, atol=1e-6) +@device_test def test_translation(): """Translations""" M = T([1, 1, 1]) From fc5d746ed782c0791c8effc4a5747898bc860fac Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 17:26:00 +0200 Subject: [PATCH 05/22] add correct wrapper to allow parameter passing --- src/libtilt/pytest_utils.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/libtilt/pytest_utils.py b/src/libtilt/pytest_utils.py index eae892f..f72f225 100644 --- a/src/libtilt/pytest_utils.py +++ b/src/libtilt/pytest_utils.py @@ -1,10 +1,11 @@ import torch +import functools def device_test(test_func): - def decorator(): - with torch.device('cpu'): - test_func() - with torch.device('cuda'): - test_func() - return decorator + @functools.wraps(test_func) + def run_devices(*args, **kwargs): + for device in ('cpu', 'cuda'): + with torch.device(device): + test_func(*args, **kwargs) + return run_devices From 87aaf778d2a1a1bfe6be49d23508991e84552b12 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 17:26:19 +0200 Subject: [PATCH 06/22] update fft util tests --- src/libtilt/_tests/test_fft_utils.py | 39 ++++++++++++++++------------ 1 file changed, 22 insertions(+), 17 deletions(-) diff --git a/src/libtilt/_tests/test_fft_utils.py b/src/libtilt/_tests/test_fft_utils.py index 3761c63..d4418ed 100644 --- a/src/libtilt/_tests/test_fft_utils.py +++ b/src/libtilt/_tests/test_fft_utils.py @@ -18,6 +18,10 @@ from libtilt.pytest_utils import device_test +# seed the random number generator to ensure tests are consistent +torch.manual_seed(0) + + @device_test def test_rfft_shape_from_signal_shape(): # even @@ -74,15 +78,14 @@ def test_rfft_to_symmetrised_dft_3d(): fft = torch.fft.fftshift(torch.fft.fftn(image, dim=fft_dims), dim=fft_dims) rfft = torch.fft.rfftn(image, dim=(-3, -2, -1)) symmetrised_dft = _rfft_to_symmetrised_dft_3d(rfft) - np.array(fft - symmetrised_dft[:-1, :-1, :-1]) assert torch.allclose(fft, symmetrised_dft[:-1, :-1, :-1], atol=1e-5) -@device_test @pytest.mark.parametrize( "inplace", [(True,), (False,)] ) +@device_test def test_symmetrised_dft_to_dft_2d(inplace: bool): image = torch.rand((10, 10)) rfft = torch.fft.rfftn(image, dim=(-2, -1)) @@ -93,11 +96,11 @@ def test_symmetrised_dft_to_dft_2d(inplace: bool): assert torch.allclose(desymmetrised_dft, fft, atol=1e-6) -@device_test @pytest.mark.parametrize( "inplace", [(True,), (False,)] ) +@device_test def test_symmetrised_dft_to_dft_2d_batched(inplace: bool): image = torch.rand((2, 10, 10)) rfft = torch.fft.rfftn(image, dim=(-2, -1)) @@ -108,11 +111,11 @@ def test_symmetrised_dft_to_dft_2d_batched(inplace: bool): assert torch.allclose(desymmetrised_dft, fft, atol=1e-6) -@device_test @pytest.mark.parametrize( "inplace", [(True,), (False,)] ) +@device_test def test_symmetrised_dft_to_rfft_2d(inplace: bool): image = torch.rand((10, 10)) rfft = torch.fft.rfftn(image, dim=(-2, -1)) @@ -122,11 +125,11 @@ def test_symmetrised_dft_to_rfft_2d(inplace: bool): assert torch.allclose(desymmetrised_rfft, rfft, atol=1e-6) -@device_test @pytest.mark.parametrize( "inplace", [(True,), (False,)] ) +@device_test def test_symmetrised_dft_to_dft_2d_batched(inplace: bool): image = torch.rand((2, 10, 10)) rfft = torch.fft.rfftn(image, dim=(-2, -1)) @@ -137,11 +140,11 @@ def test_symmetrised_dft_to_dft_2d_batched(inplace: bool): assert torch.allclose(desymmetrised_dft, fft, atol=1e-6) -@device_test @pytest.mark.parametrize( "inplace", [(True,), (False,)] ) +@device_test def test_symmetrised_dft_to_dft_3d(inplace: bool): image = torch.rand((10, 10, 10)) rfft = torch.fft.rfftn(image, dim=(-3, -2, -1)) @@ -154,11 +157,11 @@ def test_symmetrised_dft_to_dft_3d(inplace: bool): assert torch.allclose(desymmetrised_dft, fft, atol=1e-5) -@device_test @pytest.mark.parametrize( "inplace", [(True,), (False,)] ) +@device_test def test_symmetrised_dft_to_dft_3d_batched(inplace: bool): image = torch.rand((2, 10, 10, 10)) rfft = torch.fft.rfftn(image, dim=(-3, -2, -1)) @@ -170,23 +173,25 @@ def test_symmetrised_dft_to_dft_3d_batched(inplace: bool): assert torch.allclose(desymmetrised_dft, fft, atol=1e-5) -@device_test @pytest.mark.parametrize( "fftshifted, rfft, input, expected", [ - (False, False, (5, 5, 5), torch.tensor([0., 0., 0.])), - (False, True, (5, 5, 5), torch.tensor([0., 0., 0.])), - (True, False, (5, 5, 5), torch.tensor([2., 2., 2.])), - (True, True, (5, 5, 5), torch.tensor([2., 2., 0.])), - (False, False, (4, 4, 4), torch.tensor([0., 0., 0.])), - (False, True, (4, 4, 4), torch.tensor([0., 0., 0.])), - (True, False, (4, 4, 4), torch.tensor([2., 2., 2.])), - (True, True, (4, 4, 4), torch.tensor([2., 2., 0.])), + (False, False, (5, 5, 5), [0., 0., 0.]), + (False, True, (5, 5, 5), [0., 0., 0.]), + (True, False, (5, 5, 5), [2., 2., 2.]), + (True, True, (5, 5, 5), [2., 2., 0.]), + (False, False, (4, 4, 4), [0., 0., 0.]), + (False, True, (4, 4, 4), [0., 0., 0.]), + (True, False, (4, 4, 4), [2., 2., 2.]), + (True, True, (4, 4, 4), [2., 2., 0.]), ], ) +@device_test def test_fft_center(fftshifted, rfft, input, expected): result = dft_center(input, fftshifted=fftshifted, rfft=rfft) - assert torch.allclose(result, expected.long()) + assert torch.allclose(result, torch.tensor( + expected, device=result.device, dtype=result.dtype + )) @device_test From 6bd5c01d2b856cfd8d0987bf3e36c3c1932831a5 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 17:28:43 +0200 Subject: [PATCH 07/22] add the integration_tests --- .../test_fourier_slice_extraction_insertion.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/libtilt/_integration_tests/test_fourier_slice_extraction_insertion.py b/src/libtilt/_integration_tests/test_fourier_slice_extraction_insertion.py index f082fd6..361c264 100644 --- a/src/libtilt/_integration_tests/test_fourier_slice_extraction_insertion.py +++ b/src/libtilt/_integration_tests/test_fourier_slice_extraction_insertion.py @@ -5,8 +5,10 @@ insert_into_dft_3d from libtilt.grids import rotated_central_slice_grid from libtilt.fft_utils import fftfreq_to_dft_coordinates +from libtilt.pytest_utils import device_test +@device_test def test_fourier_slice_extraction_insertion_cycle_no_rotation(): sidelength = 64 volume = torch.zeros((sidelength, sidelength, sidelength), dtype=torch.complex64) @@ -36,6 +38,7 @@ def test_fourier_slice_extraction_insertion_cycle_no_rotation(): assert torch.allclose(input_slice, output_slice) +@device_test def test_fourier_slice_extraction_insertion_cycle_with_rotation(): sidelength = 64 From 8f11d8645b87df6acb22154c6008e4d13b2dcfc9 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 17:30:00 +0200 Subject: [PATCH 08/22] add atomic models tests --- src/libtilt/atomic_models/tests/test_coordinates_to_image.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/libtilt/atomic_models/tests/test_coordinates_to_image.py b/src/libtilt/atomic_models/tests/test_coordinates_to_image.py index 47174c2..cb95939 100644 --- a/src/libtilt/atomic_models/tests/test_coordinates_to_image.py +++ b/src/libtilt/atomic_models/tests/test_coordinates_to_image.py @@ -3,8 +3,10 @@ from libtilt.atomic_models.coordinates_to_image import ( coordinates_to_image_2d, coordinates_to_image_3d ) +from libtilt.pytest_utils import device_test +@device_test def test_coordinates_to_image_2d(): coordinates = torch.as_tensor([14, 14]) image = coordinates_to_image_2d(coordinates=coordinates, image_shape=(28, 28)) @@ -14,6 +16,7 @@ def test_coordinates_to_image_2d(): assert torch.allclose(image, expected) +@device_test def test_coordinates_to_image_3d(): coordinates = torch.as_tensor([14, 14, 14]) image = coordinates_to_image_3d(coordinates=coordinates, image_shape=(28, 28, 28)) From c213fd9366f9e78f75f570fcf026877df34b0702 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 17:34:05 +0200 Subject: [PATCH 09/22] add ctf --- src/libtilt/ctf/tests/test_ctf.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/libtilt/ctf/tests/test_ctf.py b/src/libtilt/ctf/tests/test_ctf.py index 8208c87..cee550f 100644 --- a/src/libtilt/ctf/tests/test_ctf.py +++ b/src/libtilt/ctf/tests/test_ctf.py @@ -3,8 +3,10 @@ from libtilt.ctf.ctf_1d import calculate_ctf as calculate_ctf_1d from libtilt.ctf.relativistic_wavelength import \ calculate_relativistic_electron_wavelength +from libtilt.pytest_utils import device_test +@device_test def test_1d_ctf_single(): result = calculate_ctf_1d( defocus=1.5, @@ -34,6 +36,7 @@ def test_1d_ctf_single(): assert torch.allclose(result[0], expected, atol=1e-4) +@device_test def test_1d_ctf_batch_defocus(): result = calculate_ctf_1d( defocus=[1.5, 2.5], @@ -56,6 +59,7 @@ def test_1d_ctf_batch_defocus(): assert torch.allclose(result, expected, atol=1e-4) +@device_test def test_calculate_relativistic_electron_wavelength(): """Check function matches expected value from literature. From 7f3d0e1353a3514080851a7038abcb94a20802b7 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 17:38:00 +0200 Subject: [PATCH 10/22] add filters and fixes for device --- src/libtilt/filters/bandpass.py | 3 ++- src/libtilt/filters/bfactors.py | 3 ++- src/libtilt/filters/tests/test_b.py | 3 +++ src/libtilt/filters/tests/test_bandpass.py | 2 ++ 4 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/libtilt/filters/bandpass.py b/src/libtilt/filters/bandpass.py index d0aca67..d75283a 100644 --- a/src/libtilt/filters/bandpass.py +++ b/src/libtilt/filters/bandpass.py @@ -56,6 +56,7 @@ def bandpass_dft( falloff=falloff, image_shape=image_shape, rfft=rfft, - fftshift=fftshifted + fftshift=fftshifted, + device=dft.device ) return dft * filter diff --git a/src/libtilt/filters/bfactors.py b/src/libtilt/filters/bfactors.py index 6440cb2..0f877eb 100644 --- a/src/libtilt/filters/bfactors.py +++ b/src/libtilt/filters/bfactors.py @@ -50,6 +50,7 @@ def bfactor_dft( image_shape=image_shape, pixel_size=pixel_size, rfft=rfft, - fftshift=fftshifted + fftshift=fftshifted, + device=dft.device ) return dft * b_env diff --git a/src/libtilt/filters/tests/test_b.py b/src/libtilt/filters/tests/test_b.py index 77c730e..8c61875 100644 --- a/src/libtilt/filters/tests/test_b.py +++ b/src/libtilt/filters/tests/test_b.py @@ -1,7 +1,10 @@ import torch from libtilt.filters import bfactor_2d +from libtilt.pytest_utils import device_test + +@device_test def test_bfactor_2d(): # Generate an image image = torch.zeros((4,4)) diff --git a/src/libtilt/filters/tests/test_bandpass.py b/src/libtilt/filters/tests/test_bandpass.py index de16207..d7c57d3 100644 --- a/src/libtilt/filters/tests/test_bandpass.py +++ b/src/libtilt/filters/tests/test_bandpass.py @@ -1,8 +1,10 @@ import torch from libtilt.filters.filters import bandpass_filter +from libtilt.pytest_utils import device_test +@device_test def test_bandpass_filter(): filter = bandpass_filter( low=0.2, From 142b9e5556b67c7dd6a14eb2c9e3a251025053ce Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 17:42:56 +0200 Subject: [PATCH 11/22] fix bandpass test --- src/libtilt/filters/tests/test_bandpass.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/src/libtilt/filters/tests/test_bandpass.py b/src/libtilt/filters/tests/test_bandpass.py index d7c57d3..1e92c7e 100644 --- a/src/libtilt/filters/tests/test_bandpass.py +++ b/src/libtilt/filters/tests/test_bandpass.py @@ -6,6 +6,13 @@ @device_test def test_bandpass_filter(): + freqs = torch.fft.fftfreq(20) + in_band_idx = torch.logical_and(freqs >= 0.2, freqs <= 0.4) + lower_idx = torch.logical_and(freqs >= 0.1, freqs <= 0.2) + upper_idx = freqs > 0.4 + lower_falloff = torch.cos((torch.pi / 2) * ((freqs[lower_idx] - 0.2) / 0.1)) + upper_falloff = torch.cos((torch.pi / 2) * ((freqs[upper_idx] - 0.4) / 0.1)) + filter = bandpass_filter( low=0.2, high=0.4, @@ -13,13 +20,9 @@ def test_bandpass_filter(): image_shape=(20, 1), rfft=False, fftshift=False, + device=freqs.device ) - freqs = torch.fft.fftfreq(20) - in_band_idx = torch.logical_and(freqs >= 0.2, freqs <= 0.4) - lower_idx = torch.logical_and(freqs >= 0.1, freqs <= 0.2) - upper_idx = freqs > 0.4 - lower_falloff = torch.cos((torch.pi / 2) * ((freqs[lower_idx] - 0.2) / 0.1)) - upper_falloff = torch.cos((torch.pi / 2) * ((freqs[upper_idx] - 0.4) / 0.1)) + assert torch.all(filter[in_band_idx] == 1) assert torch.allclose(filter[lower_idx], lower_falloff.view((-1, 1)), atol=1e-6) assert torch.allclose(filter[upper_idx], upper_falloff.view((-1, 1)), atol=1e-6) From 8ab4b55cd2c3e081040fa707a88caf61b07e277a Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 17:52:08 +0200 Subject: [PATCH 12/22] add fsc and fix for split idx --- src/libtilt/fsc/_tests/test_fsc.py | 5 +++++ src/libtilt/fsc/_tests/test_fsc_conical.py | 2 ++ src/libtilt/fsc/fsc.py | 3 ++- 3 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/libtilt/fsc/_tests/test_fsc.py b/src/libtilt/fsc/_tests/test_fsc.py index dcbd433..54ae86e 100644 --- a/src/libtilt/fsc/_tests/test_fsc.py +++ b/src/libtilt/fsc/_tests/test_fsc.py @@ -1,20 +1,24 @@ import torch from libtilt.fsc.fsc import fsc +from libtilt.pytest_utils import device_test +@device_test def test_fsc_identical_images(): a = torch.rand(size=(10, 10)) result = fsc(a, a) assert torch.allclose(result, torch.ones(6)) +@device_test def test_fsc_identical_volumes(): a = torch.rand(size=(10, 10, 10)) result = fsc(a, a) assert torch.allclose(result, torch.ones(6)) +@device_test def test_fsc_identical_images_with_index_subset(): a = torch.rand(size=(10, 10)) rfft_mask = torch.zeros(size=(10, 6), dtype=torch.bool) @@ -24,6 +28,7 @@ def test_fsc_identical_images_with_index_subset(): assert torch.allclose(result, torch.ones(6)) +@device_test def test_fsc_identical_volumes_with_index_subset(): a = torch.rand(size=(10, 10, 10)) rfft_mask = torch.zeros(size=(10, 10, 6), dtype=torch.bool) diff --git a/src/libtilt/fsc/_tests/test_fsc_conical.py b/src/libtilt/fsc/_tests/test_fsc_conical.py index 0970fd1..9fa401f 100644 --- a/src/libtilt/fsc/_tests/test_fsc_conical.py +++ b/src/libtilt/fsc/_tests/test_fsc_conical.py @@ -1,8 +1,10 @@ import torch from libtilt.fsc.fsc_conical import fsc_conical +from libtilt.pytest_utils import device_test +@device_test def test_fsc_conical(): a = torch.rand((10, 10, 10)) result = fsc_conical(a, a, cone_direction=(1, 0, 0), cone_aperture=30) diff --git a/src/libtilt/fsc/fsc.py b/src/libtilt/fsc/fsc.py index 7f6f3e7..6648988 100644 --- a/src/libtilt/fsc/fsc.py +++ b/src/libtilt/fsc/fsc.py @@ -47,7 +47,8 @@ def fsc( # find indices of all components in each shell sorted_frequencies, sort_idx = torch.sort(frequencies, descending=False) split_idx = torch.searchsorted(sorted_frequencies, split_points) - shell_idx = torch.tensor_split(sort_idx, split_idx)[:-1] + # tensor_split requires the split_idx to live on cpu + shell_idx = torch.tensor_split(sort_idx, split_idx.to('cpu'))[:-1] # calculate normalised cross correlation in each shell fsc = [ From 3eb93c4664638922e92b8665922d6bf6e8572d97 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 17:53:36 +0200 Subject: [PATCH 13/22] add grids tests --- src/libtilt/grids/tests/test_central_slice.py | 2 ++ src/libtilt/grids/tests/test_coordinate_grid.py | 6 ++++++ src/libtilt/grids/tests/test_fftfreq_grid.py | 2 ++ 3 files changed, 10 insertions(+) diff --git a/src/libtilt/grids/tests/test_central_slice.py b/src/libtilt/grids/tests/test_central_slice.py index 8806f9f..552135d 100644 --- a/src/libtilt/grids/tests/test_central_slice.py +++ b/src/libtilt/grids/tests/test_central_slice.py @@ -1,8 +1,10 @@ import torch from libtilt.grids import central_slice_grid +from libtilt.pytest_utils import device_test +@device_test def test_central_slice_grid(): input_shape = (6, 6, 6) diff --git a/src/libtilt/grids/tests/test_coordinate_grid.py b/src/libtilt/grids/tests/test_coordinate_grid.py index f2d1980..6d8778f 100644 --- a/src/libtilt/grids/tests/test_coordinate_grid.py +++ b/src/libtilt/grids/tests/test_coordinate_grid.py @@ -2,8 +2,10 @@ import torch from libtilt.grids import coordinate_grid +from libtilt.pytest_utils import device_test +@device_test def test_coordinate_grid_simple(): image_shape = (5, 3, 2) result = coordinate_grid( @@ -15,6 +17,7 @@ def test_coordinate_grid_simple(): assert torch.allclose(result[4, 2, 1], torch.tensor([4, 2, 1], dtype=torch.float)) +@device_test def test_coordinate_grid_centered(): image_shape = (28, 28) result = coordinate_grid( @@ -25,6 +28,7 @@ def test_coordinate_grid_centered(): assert torch.allclose(result[0, 0], torch.tensor([-14, -14], dtype=torch.float)) +@device_test def test_coordinate_grid_centered_batched(): image_shape = (28, 28) centers = [[0, 0], [14, 14]] @@ -38,6 +42,7 @@ def test_coordinate_grid_centered_batched(): torch.as_tensor([-14, -14], dtype=torch.float)) +@device_test def test_coordinate_grid_centered_stacked(): image_shape = (28, 28) centers = [[0, 0], [14, 14]] @@ -51,6 +56,7 @@ def test_coordinate_grid_centered_stacked(): assert torch.allclose(result[1, 0, 0, 0, 0], torch.as_tensor([-14, -14]).float()) +@device_test def test_coordinate_with_norm(): image_shape = (5, 5) result = coordinate_grid( diff --git a/src/libtilt/grids/tests/test_fftfreq_grid.py b/src/libtilt/grids/tests/test_fftfreq_grid.py index d1d2dfa..4dfe7b1 100644 --- a/src/libtilt/grids/tests/test_fftfreq_grid.py +++ b/src/libtilt/grids/tests/test_fftfreq_grid.py @@ -1,8 +1,10 @@ import torch from libtilt.grids import fftfreq_grid, central_slice_grid +from libtilt.pytest_utils import device_test +@device_test def test_fftfreq_grid_2d(): input_shape = (6, 6) grid = fftfreq_grid( From 3d0bb1cb06e6a55cb0fbf0af9ef51eb8d3672edd Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 17:54:34 +0200 Subject: [PATCH 14/22] add interpolation --- src/libtilt/interpolation/tests/test_interpolate_dft_3d.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/libtilt/interpolation/tests/test_interpolate_dft_3d.py b/src/libtilt/interpolation/tests/test_interpolate_dft_3d.py index aa412dd..49836a3 100644 --- a/src/libtilt/interpolation/tests/test_interpolate_dft_3d.py +++ b/src/libtilt/interpolation/tests/test_interpolate_dft_3d.py @@ -3,8 +3,10 @@ import torch.nn.functional as F from libtilt.interpolation.interpolate_dft_3d import sample_dft_3d +from libtilt.pytest_utils import device_test +@device_test def test_extract_slices(): volume = torch.zeros(4, 4, 4, dtype=torch.complex64) From 8173fc9cf3f91e7493482ed8553b256999947b46 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 17:55:58 +0200 Subject: [PATCH 15/22] add patch-extraction --- .../patch_extraction/tests/test_patch_extraction_2d.py | 4 ++++ .../patch_extraction/tests/test_patch_extraction_3d.py | 3 +++ 2 files changed, 7 insertions(+) diff --git a/src/libtilt/patch_extraction/tests/test_patch_extraction_2d.py b/src/libtilt/patch_extraction/tests/test_patch_extraction_2d.py index 18c1766..598bc72 100644 --- a/src/libtilt/patch_extraction/tests/test_patch_extraction_2d.py +++ b/src/libtilt/patch_extraction/tests/test_patch_extraction_2d.py @@ -2,8 +2,10 @@ from libtilt.patch_extraction.subpixel_square_patch_extraction import extract_squares, \ _extract_square_patches_from_single_2d_image +from libtilt.pytest_utils import device_test +@device_test def test_single_square_patch_from_single_image(): """Test square patch extraction from single image.""" img = torch.zeros((28, 28)) @@ -17,6 +19,7 @@ def test_single_square_patch_from_single_image(): assert torch.allclose(patches, expected_image, atol=1e-6) +@device_test def test_extract_square_patches_single(): """Test extracting patches from a stack of images.""" img = torch.zeros((2, 28, 28)) @@ -34,6 +37,7 @@ def test_extract_square_patches_single(): assert torch.allclose(patches[0, 1], expected_image_1, atol=1e-6) +@device_test def test_extract_square_patches_batched(): """Test batched particle extraction from single image.""" img = torch.zeros((28, 28)) diff --git a/src/libtilt/patch_extraction/tests/test_patch_extraction_3d.py b/src/libtilt/patch_extraction/tests/test_patch_extraction_3d.py index faf951f..fa831c3 100644 --- a/src/libtilt/patch_extraction/tests/test_patch_extraction_3d.py +++ b/src/libtilt/patch_extraction/tests/test_patch_extraction_3d.py @@ -2,8 +2,10 @@ from libtilt.patch_extraction.subpixel_cubic_patch_extraction import extract_cubes, \ _extract_cubic_patches_from_single_3d_image +from libtilt.pytest_utils import device_test +@device_test def test_single_cubic_patch_from_single_image(): """Test cubic patch extraction from single 3D image.""" img = torch.zeros((28, 28, 28)) @@ -17,6 +19,7 @@ def test_single_cubic_patch_from_single_image(): assert torch.allclose(patches, expected_image, atol=1e-6) +@device_test def test_extract_cubic_patches(): """Test extracting cubic patches from a 3D image.""" img = torch.zeros((28, 28, 28)) From a66337955d3ff8f2662df391d969dc5b12b9eb28 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 17:57:12 +0200 Subject: [PATCH 16/22] add projection --- src/libtilt/projection/tests/test_project_fourier.py | 3 +++ src/libtilt/projection/tests/test_project_real.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/src/libtilt/projection/tests/test_project_fourier.py b/src/libtilt/projection/tests/test_project_fourier.py index 144e56e..917b6b3 100644 --- a/src/libtilt/projection/tests/test_project_fourier.py +++ b/src/libtilt/projection/tests/test_project_fourier.py @@ -2,8 +2,10 @@ from scipy.spatial.transform import Rotation as R from libtilt.projection.project_fourier import project_fourier +from libtilt.pytest_utils import device_test +@device_test def test_project_no_rotation(): volume = torch.zeros((10, 10, 10)) volume[5, 5, 5] = 1 @@ -15,6 +17,7 @@ def test_project_no_rotation(): assert torch.allclose(projection, expected) +@device_test def test_project_with_rotation(): volume = torch.zeros((10, 10, 10)) volume[5, 5, 5] = 1 diff --git a/src/libtilt/projection/tests/test_project_real.py b/src/libtilt/projection/tests/test_project_real.py index 1bfe645..304809c 100644 --- a/src/libtilt/projection/tests/test_project_real.py +++ b/src/libtilt/projection/tests/test_project_real.py @@ -1,8 +1,10 @@ import torch from libtilt.projection.project_real import project_image_real, project_volume_real +from libtilt.pytest_utils import device_test +@device_test def test_real_space_projection_3d(): volume_shape = (2, 10, 10) volume = torch.arange(2*10*10).reshape(volume_shape).float() @@ -11,6 +13,7 @@ def test_real_space_projection_3d(): assert torch.allclose(projection.squeeze(), torch.sum(volume, dim=0)) +@device_test def test_real_space_projection_2d(): image_shape = (8, 12) image = torch.arange(8 * 12).reshape(image_shape).float() From 671f06bdd2fbe7f95959d4492b657a7ed17f5db8 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 17:59:57 +0200 Subject: [PATCH 17/22] add rotational averaging + fix for tensor split to cpu --- .../rotational_averaging/rotational_average_dft.py | 9 ++++++--- .../tests/test_rotational_average_dft.py | 6 ++++++ 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/src/libtilt/rotational_averaging/rotational_average_dft.py b/src/libtilt/rotational_averaging/rotational_average_dft.py index 1df5be4..1a6ea40 100644 --- a/src/libtilt/rotational_averaging/rotational_average_dft.py +++ b/src/libtilt/rotational_averaging/rotational_average_dft.py @@ -95,7 +95,8 @@ def _find_shell_indices_1d( """Find indices which index to give values either side of split points.""" sorted, sort_idx = torch.sort(values, descending=False) split_idx = torch.searchsorted(sorted, split_values) - return torch.tensor_split(sort_idx, split_idx) + # tensor_split requires the split_idx to live on cpu + return torch.tensor_split(sort_idx, split_idx.to('cpu')) def _find_shell_indices_2d( @@ -107,7 +108,8 @@ def _find_shell_indices_2d( idx_2d = einops.rearrange(idx_2d, 'h w idx -> (h w) idx') sorted, sort_idx = torch.sort(values, descending=False) split_idx = torch.searchsorted(sorted, split_values) - return torch.tensor_split(idx_2d[sort_idx], split_idx) + # tensor_split requires the split_idx to live on cpu + return torch.tensor_split(idx_2d[sort_idx], split_idx.to('cpu')) def _find_shell_indices_3d( @@ -119,7 +121,8 @@ def _find_shell_indices_3d( idx_3d = einops.rearrange(idx_3d, 'd h w idx -> (d h w) idx') sorted, sort_idx = torch.sort(values, descending=False) split_idx = torch.searchsorted(sorted, split_values) - return torch.tensor_split(idx_3d[sort_idx], split_idx) + # tensor_split requires the split_idx to live on cpu + return torch.tensor_split(idx_3d[sort_idx], split_idx.to('cpu')) def _split_into_frequency_bins_2d( diff --git a/src/libtilt/rotational_averaging/tests/test_rotational_average_dft.py b/src/libtilt/rotational_averaging/tests/test_rotational_average_dft.py index 28209e2..861be4d 100644 --- a/src/libtilt/rotational_averaging/tests/test_rotational_average_dft.py +++ b/src/libtilt/rotational_averaging/tests/test_rotational_average_dft.py @@ -8,8 +8,10 @@ rotational_average_dft_2d, rotational_average_dft_3d, ) +from libtilt.pytest_utils import device_test +@device_test def test_split_into_frequency_bins_2d(): # no rfft, fftshifted frequencies = fftfreq_grid( @@ -55,6 +57,7 @@ def test_split_into_frequency_bins_2d(): assert torch.allclose(shells[0], frequencies[:, 14, 0].reshape((2, 1))) +@device_test def test_split_into_shells_3d(): # no rfft, fftshifted frequencies = fftfreq_grid( @@ -100,6 +103,7 @@ def test_split_into_shells_3d(): assert torch.allclose(shells[0], frequencies[:, 14, 14, 0].reshape((2, 1))) +@device_test def test_rotational_average_dft_2d(): # single image dft = fftfreq_grid(image_shape=(28, 28), rfft=False, fftshift=True, norm=True) @@ -136,6 +140,7 @@ def test_rotational_average_dft_2d(): assert rotational_average.shape == expected_shape +@device_test def test_rotational_average_return_2d(): # no batching image = fftfreq_grid(image_shape=(28, 28), rfft=False, fftshift=True, norm=True) @@ -159,6 +164,7 @@ def test_rotational_average_return_2d(): atol=1e-2) +@device_test def test_rotational_average_dft_3d(): # single image dft = fftfreq_grid(image_shape=(28, 28, 28), rfft=False, fftshift=True, norm=True) From 6f5c84b5284873a817f21b355506c10f0e33df69 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 18:28:11 +0200 Subject: [PATCH 18/22] feat: add soft edge device tests --- .pre-commit-config.yaml | 16 +++++++------- pyproject.toml | 25 +++++++++++++++------- src/libtilt/shapes/soft_edge.py | 19 ++++++++-------- src/libtilt/shapes/tests/test_soft_edge.py | 14 +++++++++--- 4 files changed, 45 insertions(+), 29 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 672d9a6..03c4566 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,42 +7,42 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/compilerla/conventional-pre-commit - rev: v1.3.0 + rev: v3.2.0 hooks: - id: conventional-pre-commit stages: [commit-msg] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.6.0 hooks: - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/myint/autoflake - rev: v1.4 + rev: v2.3.1 hooks: - id: autoflake args: ["--in-place", "--remove-all-unused-imports"] - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.13.2 hooks: - id: isort - repo: https://github.com/asottile/pyupgrade - rev: v2.34.0 + rev: v3.16.0 hooks: - id: pyupgrade args: [--py38-plus, --keep-runtime-typing] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 24.4.2 hooks: - id: black - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 7.1.0 hooks: - id: flake8 additional_dependencies: @@ -52,7 +52,7 @@ repos: - flake8-typing-imports - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.961 + rev: v1.10.0 hooks: - id: mypy files: "^source_spacing/" diff --git a/pyproject.toml b/pyproject.toml index 12d3c06..9da5163 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,16 +94,25 @@ src_paths = ["src/libtilt", "tests"] [tool.flake8] exclude = "docs,.eggs,examples,_version.py" max-line-length = 88 -ignore = "E203" +ignore = ["E203",] min-python-version = "3.8.0" docstring-convention = "all" # use numpy convention, while allowing D417 -extend-ignore = """ -E203 # whitespace before ':' -D107,D203,D212,D213,D402,D413,D415,D416 # numpy -D100 # missing docstring in public module -D401 # imperative mood -W503 # line break before binary operator -""" +extend-ignore = [ + "E203", # whitespace before ':' + "D107", + "D203", + "D212", + "D213", + "D402", + "D413", + "D415", + "D416", # numpy, + "E203", + "D100", # missing docstring in public module + "D401", # imperative mood + "W503", # line break before binary operator + "D103", # skip missin docstrings TODO update docstrings and remove +] per-file-ignores = [ "tests/*:D", ] diff --git a/src/libtilt/shapes/soft_edge.py b/src/libtilt/shapes/soft_edge.py index ef7d456..5b42063 100644 --- a/src/libtilt/shapes/soft_edge.py +++ b/src/libtilt/shapes/soft_edge.py @@ -8,7 +8,8 @@ def _add_soft_edge_single_binary_image( ) -> torch.FloatTensor: if smoothing_radius == 0: return image.float() - distances = ndi.distance_transform_edt(torch.logical_not(image)) + # move explicitly to cpu for scipy + distances = ndi.distance_transform_edt(torch.logical_not(image).to("cpu")) distances = torch.as_tensor(distances, device=image.device).float() idx = torch.logical_and(distances > 0, distances <= smoothing_radius) output = torch.clone(image).float() @@ -19,7 +20,7 @@ def _add_soft_edge_single_binary_image( def add_soft_edge_2d( image: torch.Tensor, smoothing_radius: torch.Tensor | float ) -> torch.Tensor: - image_packed, ps = einops.pack([image], '* h w') + image_packed, ps = einops.pack([image], "* h w") b = image_packed.shape[0] if isinstance(smoothing_radius, float | int): @@ -30,18 +31,17 @@ def add_soft_edge_2d( results = [ _add_soft_edge_single_binary_image(_image, smoothing_radius=_smoothing_radius) - for _image, _smoothing_radius - in zip(image_packed, smoothing_radius) + for _image, _smoothing_radius in zip(image_packed, smoothing_radius) ] results = torch.stack(results, dim=0) - [results] = einops.unpack(results, pattern='* h w', packed_shapes=ps) + [results] = einops.unpack(results, pattern="* h w", packed_shapes=ps) return results def add_soft_edge_3d( image: torch.Tensor, smoothing_radius: torch.Tensor | float ) -> torch.Tensor: - image_packed, ps = einops.pack([image], '* d h w') + image_packed, ps = einops.pack([image], "* d h w") b = image_packed.shape[0] if isinstance(smoothing_radius, float | int): smoothing_radius = torch.as_tensor( @@ -50,9 +50,8 @@ def add_soft_edge_3d( smoothing_radius = torch.broadcast_to(smoothing_radius, (b,)) results = [ _add_soft_edge_single_binary_image(_image, smoothing_radius=_smoothing_radius) - for _image, _smoothing_radius - in zip(image_packed, smoothing_radius) + for _image, _smoothing_radius in zip(image_packed, smoothing_radius) ] results = torch.stack(results, dim=0) - [results] = einops.unpack(results, pattern='* d h w', packed_shapes=ps) - return results \ No newline at end of file + [results] = einops.unpack(results, pattern="* d h w", packed_shapes=ps) + return results diff --git a/src/libtilt/shapes/tests/test_soft_edge.py b/src/libtilt/shapes/tests/test_soft_edge.py index 261d326..e105812 100644 --- a/src/libtilt/shapes/tests/test_soft_edge.py +++ b/src/libtilt/shapes/tests/test_soft_edge.py @@ -1,20 +1,28 @@ import torch -from libtilt.shapes.soft_edge import _add_soft_edge_single_binary_image, add_soft_edge_2d +from libtilt.pytest_utils import device_test +from libtilt.shapes.soft_edge import ( + _add_soft_edge_single_binary_image, + add_soft_edge_2d, +) +@device_test def test_add_soft_edge_single_binary_image(): dim_length = 5 smoothing_radius = 4 image = torch.zeros(size=(dim_length, 1)) image[0, 0] = 1 - smoothed = _add_soft_edge_single_binary_image(image, smoothing_radius=smoothing_radius) + smoothed = _add_soft_edge_single_binary_image( + image, smoothing_radius=smoothing_radius + ) # cosine falloff, 1 to zero over smooothing radius expected = torch.cos((torch.pi / 2) * torch.arange(5) / smoothing_radius) expected = expected.view((dim_length, 1)) assert torch.allclose(smoothed, expected) +@device_test def test_add_soft_edge_2d(): # single image image = torch.zeros(size=(5, 5)) @@ -29,4 +37,4 @@ def test_add_soft_edge_2d(): images[..., 0, 0] = 1 results = add_soft_edge_2d(images, smoothing_radius=smoothing_radius) assert results.shape == images.shape - assert torch.allclose(results[0], results[1]) \ No newline at end of file + assert torch.allclose(results[0], results[1]) From 1bc31036ed9efddc1d193e79b9648780ef6646ca Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 18:29:43 +0200 Subject: [PATCH 19/22] feat: add device tests for dft shifts --- src/libtilt/shift/tests/test_phase_shift_2d.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/src/libtilt/shift/tests/test_phase_shift_2d.py b/src/libtilt/shift/tests/test_phase_shift_2d.py index c5db0f2..8b5577c 100644 --- a/src/libtilt/shift/tests/test_phase_shift_2d.py +++ b/src/libtilt/shift/tests/test_phase_shift_2d.py @@ -1,9 +1,11 @@ import torch -from libtilt.shift.shift_image import shift_2d +from libtilt.pytest_utils import device_test from libtilt.shift.phase_shift_dft import get_phase_shifts_2d +from libtilt.shift.shift_image import shift_2d +@device_test def test_get_phase_shifts_2d_full_fft(): shifts = torch.zeros(size=(1, 2)) phase_shifts = get_phase_shifts_2d(shifts, image_shape=(2, 2), rfft=False) @@ -11,11 +13,13 @@ def test_get_phase_shifts_2d_full_fft(): shifts = torch.tensor([[1, 2]]) phase_shifts = get_phase_shifts_2d(shifts, image_shape=(2, 2), rfft=False) - expected = torch.tensor([[[1 + 0.0000e+00j, 1 + 1.7485e-07j], - [-1 - 8.7423e-08j, -1 - 2.3850e-08j]]]) + expected = torch.tensor( + [[[1 + 0.0000e00j, 1 + 1.7485e-07j], [-1 - 8.7423e-08j, -1 - 2.3850e-08j]]] + ) assert torch.allclose(phase_shifts, expected) +@device_test def test_get_phase_shifts_2d_rfft(): shifts = torch.zeros(size=(1, 2)) phase_shifts = get_phase_shifts_2d(shifts, image_shape=(2, 2), rfft=True) @@ -25,11 +29,13 @@ def test_get_phase_shifts_2d_rfft(): shifts = torch.tensor([[1, 2]]) phase_shifts = get_phase_shifts_2d(shifts, image_shape=(2, 2), rfft=False) - expected = torch.tensor([[[1 + 0.0000e+00j, 1 + 1.7485e-07j], - [-1 - 8.7423e-08j, -1 - 2.3850e-08j]]]) + expected = torch.tensor( + [[[1 + 0.0000e00j, 1 + 1.7485e-07j], [-1 - 8.7423e-08j, -1 - 2.3850e-08j]]] + ) assert torch.allclose(phase_shifts, expected) +@device_test def test_phase_shift_images_2d(): image = torch.zeros((4, 4)) image[2, 2] = 1 From 4f99bc7b0125a35fdbb6c1c69621b5cb3f83eb73 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 18:37:50 +0200 Subject: [PATCH 20/22] docs: update readme with dev install info --- README.md | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/README.md b/README.md index 43d6d7e..12658f4 100644 --- a/README.md +++ b/README.md @@ -7,3 +7,26 @@ [![codecov](https://codecov.io/gh/alisterburt/libtilt/branch/main/graph/badge.svg)](https://codecov.io/gh/alisterburt/libtilt) Image processing for cryo-electron tomography in PyTorch + + +## For developers + +We advise to fork the repository and make a local clone of your fork. After setting +up an environment (e.g. via miniconda), the development version can be installed with +(`-e` runs an editable install): + +```commandline +python -m pip install -e '.[dev,test]' +``` + +Then ready pre-commits for automated code checks and styling: + +```commandline +pre-commit install +``` + +Before making any pull request please make sure all the unittests pass: + +```commandline +python -m pytest +``` From 3e497098d61aaf9f30cb9ff1e0302394db3c637c Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 19:09:44 +0200 Subject: [PATCH 21/22] feat: scan decorator for available devices --- src/libtilt/pytest_utils.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/libtilt/pytest_utils.py b/src/libtilt/pytest_utils.py index f72f225..8a4c477 100644 --- a/src/libtilt/pytest_utils.py +++ b/src/libtilt/pytest_utils.py @@ -1,11 +1,19 @@ -import torch import functools +import torch + +AVAILABLE_DEVICES = ["cpu"] +if torch.backends.mps.is_available(): + AVAILABLE_DEVICES.append("mps") +if torch.cuda.is_available(): + AVAILABLE_DEVICES.append("cuda") + def device_test(test_func): @functools.wraps(test_func) def run_devices(*args, **kwargs): - for device in ('cpu', 'cuda'): + for device in AVAILABLE_DEVICES: with torch.device(device): test_func(*args, **kwargs) + return run_devices From 4cabb72137ef3725a335eaa625343c93b5328682 Mon Sep 17 00:00:00 2001 From: McHaillet Date: Mon, 17 Jun 2024 19:13:33 +0200 Subject: [PATCH 22/22] docs: update readme about cpu/gpu testing --- README.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/README.md b/README.md index 12658f4..62a9a7f 100644 --- a/README.md +++ b/README.md @@ -30,3 +30,8 @@ Before making any pull request please make sure all the unittests pass: ```commandline python -m pytest ``` + +The libtilt `device_test` decorator should be added to each test and tries to run +the code on GPU. This can either be via CUDA or via the 'mps' backend for M1 chips. +Please keep in mind that without a GPU available on your system, the code is not +explicitly tested to run on GPU.