diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 672d9a6..03c4566 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -7,42 +7,42 @@ default_install_hook_types: [pre-commit, commit-msg] repos: - repo: https://github.com/compilerla/conventional-pre-commit - rev: v1.3.0 + rev: v3.2.0 hooks: - id: conventional-pre-commit stages: [commit-msg] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.3.0 + rev: v4.6.0 hooks: - id: check-docstring-first - id: end-of-file-fixer - id: trailing-whitespace - repo: https://github.com/myint/autoflake - rev: v1.4 + rev: v2.3.1 hooks: - id: autoflake args: ["--in-place", "--remove-all-unused-imports"] - repo: https://github.com/PyCQA/isort - rev: 5.10.1 + rev: 5.13.2 hooks: - id: isort - repo: https://github.com/asottile/pyupgrade - rev: v2.34.0 + rev: v3.16.0 hooks: - id: pyupgrade args: [--py38-plus, --keep-runtime-typing] - repo: https://github.com/psf/black - rev: 22.3.0 + rev: 24.4.2 hooks: - id: black - repo: https://github.com/PyCQA/flake8 - rev: 4.0.1 + rev: 7.1.0 hooks: - id: flake8 additional_dependencies: @@ -52,7 +52,7 @@ repos: - flake8-typing-imports - repo: https://github.com/pre-commit/mirrors-mypy - rev: v0.961 + rev: v1.10.0 hooks: - id: mypy files: "^source_spacing/" diff --git a/README.md b/README.md index 43d6d7e..62a9a7f 100644 --- a/README.md +++ b/README.md @@ -7,3 +7,31 @@ [![codecov](https://codecov.io/gh/alisterburt/libtilt/branch/main/graph/badge.svg)](https://codecov.io/gh/alisterburt/libtilt) Image processing for cryo-electron tomography in PyTorch + + +## For developers + +We advise to fork the repository and make a local clone of your fork. After setting +up an environment (e.g. via miniconda), the development version can be installed with +(`-e` runs an editable install): + +```commandline +python -m pip install -e '.[dev,test]' +``` + +Then ready pre-commits for automated code checks and styling: + +```commandline +pre-commit install +``` + +Before making any pull request please make sure all the unittests pass: + +```commandline +python -m pytest +``` + +The libtilt `device_test` decorator should be added to each test and tries to run +the code on GPU. This can either be via CUDA or via the 'mps' backend for M1 chips. +Please keep in mind that without a GPU available on your system, the code is not +explicitly tested to run on GPU. diff --git a/pyproject.toml b/pyproject.toml index 12d3c06..9da5163 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -94,16 +94,25 @@ src_paths = ["src/libtilt", "tests"] [tool.flake8] exclude = "docs,.eggs,examples,_version.py" max-line-length = 88 -ignore = "E203" +ignore = ["E203",] min-python-version = "3.8.0" docstring-convention = "all" # use numpy convention, while allowing D417 -extend-ignore = """ -E203 # whitespace before ':' -D107,D203,D212,D213,D402,D413,D415,D416 # numpy -D100 # missing docstring in public module -D401 # imperative mood -W503 # line break before binary operator -""" +extend-ignore = [ + "E203", # whitespace before ':' + "D107", + "D203", + "D212", + "D213", + "D402", + "D413", + "D415", + "D416", # numpy, + "E203", + "D100", # missing docstring in public module + "D401", # imperative mood + "W503", # line break before binary operator + "D103", # skip missin docstrings TODO update docstrings and remove +] per-file-ignores = [ "tests/*:D", ] diff --git a/src/libtilt/_integration_tests/test_fourier_slice_extraction_insertion.py b/src/libtilt/_integration_tests/test_fourier_slice_extraction_insertion.py index f082fd6..361c264 100644 --- a/src/libtilt/_integration_tests/test_fourier_slice_extraction_insertion.py +++ b/src/libtilt/_integration_tests/test_fourier_slice_extraction_insertion.py @@ -5,8 +5,10 @@ insert_into_dft_3d from libtilt.grids import rotated_central_slice_grid from libtilt.fft_utils import fftfreq_to_dft_coordinates +from libtilt.pytest_utils import device_test +@device_test def test_fourier_slice_extraction_insertion_cycle_no_rotation(): sidelength = 64 volume = torch.zeros((sidelength, sidelength, sidelength), dtype=torch.complex64) @@ -36,6 +38,7 @@ def test_fourier_slice_extraction_insertion_cycle_no_rotation(): assert torch.allclose(input_slice, output_slice) +@device_test def test_fourier_slice_extraction_insertion_cycle_with_rotation(): sidelength = 64 diff --git a/src/libtilt/_tests/test_coordinate_tils.py b/src/libtilt/_tests/test_coordinate_tils.py index 9ef74d1..5936c9d 100644 --- a/src/libtilt/_tests/test_coordinate_tils.py +++ b/src/libtilt/_tests/test_coordinate_tils.py @@ -9,8 +9,10 @@ homogenise_coordinates, ) from libtilt.grids.coordinate_grid import coordinate_grid +from libtilt.pytest_utils import device_test +@device_test def test_array_coordinates_to_grid_sample_coordinates_nd(): array_shape = z, y, x = (4, 8, 12) array_coordinates = einops.rearrange(torch.tensor(np.indices(array_shape)), @@ -27,6 +29,7 @@ def test_array_coordinates_to_grid_sample_coordinates_nd(): assert torch.allclose(grid_sample_coordinates[:, 0, 0, 2], expected_z) +@device_test def test_grid_sample_coordinates_to_array_coordinates_nd(): array_shape = (4, 8, 12) expected_array_coordinates = einops.rearrange( @@ -41,6 +44,7 @@ def test_grid_sample_coordinates_to_array_coordinates_nd(): assert torch.allclose(array_coordinates, expected_array_coordinates) +@device_test def test_add_implied_coordinate_from_dimension(): batch_of_stacked_2d_coords = torch.zeros(size=(1, 5, 2)) # (b, stack, 2) result = add_positional_coordinate(batch_of_stacked_2d_coords, dim=1) @@ -49,6 +53,7 @@ def test_add_implied_coordinate_from_dimension(): assert torch.allclose(result, expected) +@device_test def test_add_implied_coordinate_from_dimension_prepend(): batch_of_stacked_2d_coords = torch.zeros(size=(1, 5, 2)) # (b, stack, 2) result = add_positional_coordinate(batch_of_stacked_2d_coords, dim=1, @@ -58,6 +63,7 @@ def test_add_implied_coordinate_from_dimension_prepend(): assert torch.allclose(result, expected) +@device_test def test_get_grid_coordinates(): coords = coordinate_grid(image_shape=(3, 2)) assert coords.shape == (3, 2, 2) @@ -74,6 +80,7 @@ def test_get_grid_coordinates(): assert torch.allclose(coords, expected) +@device_test def test_homogenise_coordinates(): coords = torch.rand(size=(2, 3)) homogenised = homogenise_coordinates(coords) diff --git a/src/libtilt/_tests/test_fft_utils.py b/src/libtilt/_tests/test_fft_utils.py index 5daa3d2..d4418ed 100644 --- a/src/libtilt/_tests/test_fft_utils.py +++ b/src/libtilt/_tests/test_fft_utils.py @@ -15,8 +15,14 @@ fftfreq_to_dft_coordinates, ) from libtilt.grids.fftfreq_grid import _construct_fftfreq_grid_2d +from libtilt.pytest_utils import device_test +# seed the random number generator to ensure tests are consistent +torch.manual_seed(0) + + +@device_test def test_rfft_shape_from_signal_shape(): # even signal_shape = (2, 4, 8, 16) @@ -31,6 +37,7 @@ def test_rfft_shape_from_signal_shape(): assert rfft.shape == rfft_shape(signal_shape) +@device_test def test_construct_fftfreq_grid_2d(): image_shape = (10, 30) # no rfft @@ -46,6 +53,7 @@ def test_construct_fftfreq_grid_2d(): assert torch.allclose(grid[0, :, 1], torch.fft.rfftfreq(30)) +@device_test def test_rfft_to_symmetrised_dft_2d(): image = torch.rand((10, 10)) fft = torch.fft.fftshift(torch.fft.fftn(image, dim=(-2, -1)), dim=(-2, -1)) @@ -54,6 +62,7 @@ def test_rfft_to_symmetrised_dft_2d(): assert torch.allclose(fft, symmetrised_dft[:-1, :-1], atol=1e-7) +@device_test def test_rfft_to_symmetrised_dft_2d_batched(): image = torch.rand((2, 10, 10)) # (b, h, w) fft = torch.fft.fftshift(torch.fft.fftn(image, dim=(-2, -1)), dim=(-2, -1)) @@ -62,13 +71,13 @@ def test_rfft_to_symmetrised_dft_2d_batched(): assert torch.allclose(fft, symmetrised_dft[..., :-1, :-1], atol=1e-7) +@device_test def test_rfft_to_symmetrised_dft_3d(): image = torch.rand((10, 10, 10)) fft_dims = (-3, -2, -1) fft = torch.fft.fftshift(torch.fft.fftn(image, dim=fft_dims), dim=fft_dims) rfft = torch.fft.rfftn(image, dim=(-3, -2, -1)) symmetrised_dft = _rfft_to_symmetrised_dft_3d(rfft) - np.array(fft - symmetrised_dft[:-1, :-1, :-1]) assert torch.allclose(fft, symmetrised_dft[:-1, :-1, :-1], atol=1e-5) @@ -76,6 +85,7 @@ def test_rfft_to_symmetrised_dft_3d(): "inplace", [(True,), (False,)] ) +@device_test def test_symmetrised_dft_to_dft_2d(inplace: bool): image = torch.rand((10, 10)) rfft = torch.fft.rfftn(image, dim=(-2, -1)) @@ -90,6 +100,7 @@ def test_symmetrised_dft_to_dft_2d(inplace: bool): "inplace", [(True,), (False,)] ) +@device_test def test_symmetrised_dft_to_dft_2d_batched(inplace: bool): image = torch.rand((2, 10, 10)) rfft = torch.fft.rfftn(image, dim=(-2, -1)) @@ -104,6 +115,7 @@ def test_symmetrised_dft_to_dft_2d_batched(inplace: bool): "inplace", [(True,), (False,)] ) +@device_test def test_symmetrised_dft_to_rfft_2d(inplace: bool): image = torch.rand((10, 10)) rfft = torch.fft.rfftn(image, dim=(-2, -1)) @@ -117,6 +129,7 @@ def test_symmetrised_dft_to_rfft_2d(inplace: bool): "inplace", [(True,), (False,)] ) +@device_test def test_symmetrised_dft_to_dft_2d_batched(inplace: bool): image = torch.rand((2, 10, 10)) rfft = torch.fft.rfftn(image, dim=(-2, -1)) @@ -131,6 +144,7 @@ def test_symmetrised_dft_to_dft_2d_batched(inplace: bool): "inplace", [(True,), (False,)] ) +@device_test def test_symmetrised_dft_to_dft_3d(inplace: bool): image = torch.rand((10, 10, 10)) rfft = torch.fft.rfftn(image, dim=(-3, -2, -1)) @@ -147,6 +161,7 @@ def test_symmetrised_dft_to_dft_3d(inplace: bool): "inplace", [(True,), (False,)] ) +@device_test def test_symmetrised_dft_to_dft_3d_batched(inplace: bool): image = torch.rand((2, 10, 10, 10)) rfft = torch.fft.rfftn(image, dim=(-3, -2, -1)) @@ -161,21 +176,25 @@ def test_symmetrised_dft_to_dft_3d_batched(inplace: bool): @pytest.mark.parametrize( "fftshifted, rfft, input, expected", [ - (False, False, (5, 5, 5), torch.tensor([0., 0., 0.])), - (False, True, (5, 5, 5), torch.tensor([0., 0., 0.])), - (True, False, (5, 5, 5), torch.tensor([2., 2., 2.])), - (True, True, (5, 5, 5), torch.tensor([2., 2., 0.])), - (False, False, (4, 4, 4), torch.tensor([0., 0., 0.])), - (False, True, (4, 4, 4), torch.tensor([0., 0., 0.])), - (True, False, (4, 4, 4), torch.tensor([2., 2., 2.])), - (True, True, (4, 4, 4), torch.tensor([2., 2., 0.])), + (False, False, (5, 5, 5), [0., 0., 0.]), + (False, True, (5, 5, 5), [0., 0., 0.]), + (True, False, (5, 5, 5), [2., 2., 2.]), + (True, True, (5, 5, 5), [2., 2., 0.]), + (False, False, (4, 4, 4), [0., 0., 0.]), + (False, True, (4, 4, 4), [0., 0., 0.]), + (True, False, (4, 4, 4), [2., 2., 2.]), + (True, True, (4, 4, 4), [2., 2., 0.]), ], ) +@device_test def test_fft_center(fftshifted, rfft, input, expected): result = dft_center(input, fftshifted=fftshifted, rfft=rfft) - assert torch.allclose(result, expected.long()) + assert torch.allclose(result, torch.tensor( + expected, device=result.device, dtype=result.dtype + )) +@device_test def test_fftfreq_to_spatial_frequency(): fftfreq = torch.fft.fftfreq(10) k = fftfreq_to_spatial_frequency(fftfreq, spacing=2) @@ -187,6 +206,7 @@ def test_fftfreq_to_spatial_frequency(): assert torch.allclose(k, expected) +@device_test def test_spatial_frequency_to_fftfreq(): k = torch.fft.fftfreq(10, d=2) fftfreq = spatial_frequency_to_fftfreq(k, spacing=2) @@ -199,6 +219,7 @@ def test_spatial_frequency_to_fftfreq(): assert torch.allclose(fftfreq, expected) +@device_test def test_fftfreq_to_dft_coords(): from libtilt.grids import fftfreq_grid, coordinate_grid diff --git a/src/libtilt/_tests/test_transformations.py b/src/libtilt/_tests/test_transformations.py index 6e7b244..a3c6c2b 100644 --- a/src/libtilt/_tests/test_transformations.py +++ b/src/libtilt/_tests/test_transformations.py @@ -1,8 +1,10 @@ import torch from libtilt.transformations import Rx, Ry, Rz, T +from libtilt.pytest_utils import device_test +@device_test def test_rotation_around_x(): """Rotation of y around x should become z.""" R = Rx(90) @@ -16,6 +18,7 @@ def test_rotation_around_x(): assert torch.allclose(R @ v, expected, atol=1e-6) +@device_test def test_rotation_around_y(): """Rotation of z around y should be x""" R = Ry(90) @@ -29,6 +32,7 @@ def test_rotation_around_y(): assert torch.allclose(R @ v, expected, atol=1e-6) +@device_test def test_rotation_around_z(): """Rotation of x around z should give y.""" R = Rz(90) @@ -42,6 +46,7 @@ def test_rotation_around_z(): assert torch.allclose(R @ v, expected, atol=1e-6) +@device_test def test_translation(): """Translations""" M = T([1, 1, 1]) diff --git a/src/libtilt/alignment/find_shift.py b/src/libtilt/alignment/find_shift.py index bb2089d..d9af489 100644 --- a/src/libtilt/alignment/find_shift.py +++ b/src/libtilt/alignment/find_shift.py @@ -1,5 +1,4 @@ import torch -import numpy as np import torch.nn.functional as F import einops @@ -47,7 +46,7 @@ def find_image_shift( normalize=True ) maximum_idx = torch.tensor( # explicitly put tensor on CPU in case input is on GPU - np.unravel_index(correlation.argmax().cpu(), shape=image_a.shape), + torch.unravel_index(correlation.argmax().cpu(), shape=image_a.shape), device=image_a.device ) shift = maximum_idx - center @@ -76,7 +75,7 @@ def find_image_shift( upsampled.shape, rfft=False, fftshifted=True, device=image_a.device ) upsampled_shift = torch.tensor( - np.unravel_index(upsampled.argmax().cpu(), shape=upsampled.shape), + torch.unravel_index(upsampled.argmax().cpu(), shape=upsampled.shape), device=image_a.device ) - upsampled_center full_shift = shift + upsampled_shift / upsampling_factor diff --git a/src/libtilt/alignment/tests/test_find_shift.py b/src/libtilt/alignment/tests/test_find_shift.py index dce7484..91ca3e7 100644 --- a/src/libtilt/alignment/tests/test_find_shift.py +++ b/src/libtilt/alignment/tests/test_find_shift.py @@ -2,8 +2,10 @@ import pytest from libtilt.alignment import find_image_shift +from libtilt.pytest_utils import device_test +@device_test def test_find_image_shift(): a = torch.zeros((4, 4)) a[1, 1] = 1 diff --git a/src/libtilt/atomic_models/tests/test_coordinates_to_image.py b/src/libtilt/atomic_models/tests/test_coordinates_to_image.py index 47174c2..cb95939 100644 --- a/src/libtilt/atomic_models/tests/test_coordinates_to_image.py +++ b/src/libtilt/atomic_models/tests/test_coordinates_to_image.py @@ -3,8 +3,10 @@ from libtilt.atomic_models.coordinates_to_image import ( coordinates_to_image_2d, coordinates_to_image_3d ) +from libtilt.pytest_utils import device_test +@device_test def test_coordinates_to_image_2d(): coordinates = torch.as_tensor([14, 14]) image = coordinates_to_image_2d(coordinates=coordinates, image_shape=(28, 28)) @@ -14,6 +16,7 @@ def test_coordinates_to_image_2d(): assert torch.allclose(image, expected) +@device_test def test_coordinates_to_image_3d(): coordinates = torch.as_tensor([14, 14, 14]) image = coordinates_to_image_3d(coordinates=coordinates, image_shape=(28, 28, 28)) diff --git a/src/libtilt/correlation/tests/test_correlate.py b/src/libtilt/correlation/tests/test_correlate.py index 867bb6b..b45993e 100644 --- a/src/libtilt/correlation/tests/test_correlate.py +++ b/src/libtilt/correlation/tests/test_correlate.py @@ -1,17 +1,18 @@ import torch -import numpy as np from libtilt.correlation import correlate_2d, correlate_dft_2d from libtilt.fft_utils import fftshift_2d +from libtilt.pytest_utils import device_test +@device_test def test_correlate_2d(): a = torch.zeros((10, 10)) a[5, 5] = 1 b = torch.zeros((10, 10)) b[6, 6] = 1 cross_correlation = correlate_2d(a, b, normalize=True) - peak_position = np.unravel_index( + peak_position = torch.unravel_index( indices=torch.argmax(cross_correlation), shape=cross_correlation.shape ) shift = torch.as_tensor(peak_position) - torch.tensor([5, 5]) @@ -20,6 +21,7 @@ def test_correlate_2d(): assert torch.allclose(cross_correlation[peak_position], torch.tensor([1.])) +@device_test def test_correlate_dft_2d(): a = torch.zeros((10, 10)) a[5, 5] = 1 diff --git a/src/libtilt/ctf/tests/test_ctf.py b/src/libtilt/ctf/tests/test_ctf.py index 8208c87..cee550f 100644 --- a/src/libtilt/ctf/tests/test_ctf.py +++ b/src/libtilt/ctf/tests/test_ctf.py @@ -3,8 +3,10 @@ from libtilt.ctf.ctf_1d import calculate_ctf as calculate_ctf_1d from libtilt.ctf.relativistic_wavelength import \ calculate_relativistic_electron_wavelength +from libtilt.pytest_utils import device_test +@device_test def test_1d_ctf_single(): result = calculate_ctf_1d( defocus=1.5, @@ -34,6 +36,7 @@ def test_1d_ctf_single(): assert torch.allclose(result[0], expected, atol=1e-4) +@device_test def test_1d_ctf_batch_defocus(): result = calculate_ctf_1d( defocus=[1.5, 2.5], @@ -56,6 +59,7 @@ def test_1d_ctf_batch_defocus(): assert torch.allclose(result, expected, atol=1e-4) +@device_test def test_calculate_relativistic_electron_wavelength(): """Check function matches expected value from literature. diff --git a/src/libtilt/filters/bandpass.py b/src/libtilt/filters/bandpass.py index d0aca67..d75283a 100644 --- a/src/libtilt/filters/bandpass.py +++ b/src/libtilt/filters/bandpass.py @@ -56,6 +56,7 @@ def bandpass_dft( falloff=falloff, image_shape=image_shape, rfft=rfft, - fftshift=fftshifted + fftshift=fftshifted, + device=dft.device ) return dft * filter diff --git a/src/libtilt/filters/bfactors.py b/src/libtilt/filters/bfactors.py index 6440cb2..0f877eb 100644 --- a/src/libtilt/filters/bfactors.py +++ b/src/libtilt/filters/bfactors.py @@ -50,6 +50,7 @@ def bfactor_dft( image_shape=image_shape, pixel_size=pixel_size, rfft=rfft, - fftshift=fftshifted + fftshift=fftshifted, + device=dft.device ) return dft * b_env diff --git a/src/libtilt/filters/tests/test_b.py b/src/libtilt/filters/tests/test_b.py index 77c730e..8c61875 100644 --- a/src/libtilt/filters/tests/test_b.py +++ b/src/libtilt/filters/tests/test_b.py @@ -1,7 +1,10 @@ import torch from libtilt.filters import bfactor_2d +from libtilt.pytest_utils import device_test + +@device_test def test_bfactor_2d(): # Generate an image image = torch.zeros((4,4)) diff --git a/src/libtilt/filters/tests/test_bandpass.py b/src/libtilt/filters/tests/test_bandpass.py index de16207..1e92c7e 100644 --- a/src/libtilt/filters/tests/test_bandpass.py +++ b/src/libtilt/filters/tests/test_bandpass.py @@ -1,9 +1,18 @@ import torch from libtilt.filters.filters import bandpass_filter +from libtilt.pytest_utils import device_test +@device_test def test_bandpass_filter(): + freqs = torch.fft.fftfreq(20) + in_band_idx = torch.logical_and(freqs >= 0.2, freqs <= 0.4) + lower_idx = torch.logical_and(freqs >= 0.1, freqs <= 0.2) + upper_idx = freqs > 0.4 + lower_falloff = torch.cos((torch.pi / 2) * ((freqs[lower_idx] - 0.2) / 0.1)) + upper_falloff = torch.cos((torch.pi / 2) * ((freqs[upper_idx] - 0.4) / 0.1)) + filter = bandpass_filter( low=0.2, high=0.4, @@ -11,13 +20,9 @@ def test_bandpass_filter(): image_shape=(20, 1), rfft=False, fftshift=False, + device=freqs.device ) - freqs = torch.fft.fftfreq(20) - in_band_idx = torch.logical_and(freqs >= 0.2, freqs <= 0.4) - lower_idx = torch.logical_and(freqs >= 0.1, freqs <= 0.2) - upper_idx = freqs > 0.4 - lower_falloff = torch.cos((torch.pi / 2) * ((freqs[lower_idx] - 0.2) / 0.1)) - upper_falloff = torch.cos((torch.pi / 2) * ((freqs[upper_idx] - 0.4) / 0.1)) + assert torch.all(filter[in_band_idx] == 1) assert torch.allclose(filter[lower_idx], lower_falloff.view((-1, 1)), atol=1e-6) assert torch.allclose(filter[upper_idx], upper_falloff.view((-1, 1)), atol=1e-6) diff --git a/src/libtilt/fsc/_tests/test_fsc.py b/src/libtilt/fsc/_tests/test_fsc.py index dcbd433..54ae86e 100644 --- a/src/libtilt/fsc/_tests/test_fsc.py +++ b/src/libtilt/fsc/_tests/test_fsc.py @@ -1,20 +1,24 @@ import torch from libtilt.fsc.fsc import fsc +from libtilt.pytest_utils import device_test +@device_test def test_fsc_identical_images(): a = torch.rand(size=(10, 10)) result = fsc(a, a) assert torch.allclose(result, torch.ones(6)) +@device_test def test_fsc_identical_volumes(): a = torch.rand(size=(10, 10, 10)) result = fsc(a, a) assert torch.allclose(result, torch.ones(6)) +@device_test def test_fsc_identical_images_with_index_subset(): a = torch.rand(size=(10, 10)) rfft_mask = torch.zeros(size=(10, 6), dtype=torch.bool) @@ -24,6 +28,7 @@ def test_fsc_identical_images_with_index_subset(): assert torch.allclose(result, torch.ones(6)) +@device_test def test_fsc_identical_volumes_with_index_subset(): a = torch.rand(size=(10, 10, 10)) rfft_mask = torch.zeros(size=(10, 10, 6), dtype=torch.bool) diff --git a/src/libtilt/fsc/_tests/test_fsc_conical.py b/src/libtilt/fsc/_tests/test_fsc_conical.py index 0970fd1..9fa401f 100644 --- a/src/libtilt/fsc/_tests/test_fsc_conical.py +++ b/src/libtilt/fsc/_tests/test_fsc_conical.py @@ -1,8 +1,10 @@ import torch from libtilt.fsc.fsc_conical import fsc_conical +from libtilt.pytest_utils import device_test +@device_test def test_fsc_conical(): a = torch.rand((10, 10, 10)) result = fsc_conical(a, a, cone_direction=(1, 0, 0), cone_aperture=30) diff --git a/src/libtilt/fsc/fsc.py b/src/libtilt/fsc/fsc.py index 7f6f3e7..6648988 100644 --- a/src/libtilt/fsc/fsc.py +++ b/src/libtilt/fsc/fsc.py @@ -47,7 +47,8 @@ def fsc( # find indices of all components in each shell sorted_frequencies, sort_idx = torch.sort(frequencies, descending=False) split_idx = torch.searchsorted(sorted_frequencies, split_points) - shell_idx = torch.tensor_split(sort_idx, split_idx)[:-1] + # tensor_split requires the split_idx to live on cpu + shell_idx = torch.tensor_split(sort_idx, split_idx.to('cpu'))[:-1] # calculate normalised cross correlation in each shell fsc = [ diff --git a/src/libtilt/grids/tests/test_central_slice.py b/src/libtilt/grids/tests/test_central_slice.py index 8806f9f..552135d 100644 --- a/src/libtilt/grids/tests/test_central_slice.py +++ b/src/libtilt/grids/tests/test_central_slice.py @@ -1,8 +1,10 @@ import torch from libtilt.grids import central_slice_grid +from libtilt.pytest_utils import device_test +@device_test def test_central_slice_grid(): input_shape = (6, 6, 6) diff --git a/src/libtilt/grids/tests/test_coordinate_grid.py b/src/libtilt/grids/tests/test_coordinate_grid.py index f2d1980..6d8778f 100644 --- a/src/libtilt/grids/tests/test_coordinate_grid.py +++ b/src/libtilt/grids/tests/test_coordinate_grid.py @@ -2,8 +2,10 @@ import torch from libtilt.grids import coordinate_grid +from libtilt.pytest_utils import device_test +@device_test def test_coordinate_grid_simple(): image_shape = (5, 3, 2) result = coordinate_grid( @@ -15,6 +17,7 @@ def test_coordinate_grid_simple(): assert torch.allclose(result[4, 2, 1], torch.tensor([4, 2, 1], dtype=torch.float)) +@device_test def test_coordinate_grid_centered(): image_shape = (28, 28) result = coordinate_grid( @@ -25,6 +28,7 @@ def test_coordinate_grid_centered(): assert torch.allclose(result[0, 0], torch.tensor([-14, -14], dtype=torch.float)) +@device_test def test_coordinate_grid_centered_batched(): image_shape = (28, 28) centers = [[0, 0], [14, 14]] @@ -38,6 +42,7 @@ def test_coordinate_grid_centered_batched(): torch.as_tensor([-14, -14], dtype=torch.float)) +@device_test def test_coordinate_grid_centered_stacked(): image_shape = (28, 28) centers = [[0, 0], [14, 14]] @@ -51,6 +56,7 @@ def test_coordinate_grid_centered_stacked(): assert torch.allclose(result[1, 0, 0, 0, 0], torch.as_tensor([-14, -14]).float()) +@device_test def test_coordinate_with_norm(): image_shape = (5, 5) result = coordinate_grid( diff --git a/src/libtilt/grids/tests/test_fftfreq_grid.py b/src/libtilt/grids/tests/test_fftfreq_grid.py index d1d2dfa..4dfe7b1 100644 --- a/src/libtilt/grids/tests/test_fftfreq_grid.py +++ b/src/libtilt/grids/tests/test_fftfreq_grid.py @@ -1,8 +1,10 @@ import torch from libtilt.grids import fftfreq_grid, central_slice_grid +from libtilt.pytest_utils import device_test +@device_test def test_fftfreq_grid_2d(): input_shape = (6, 6) grid = fftfreq_grid( diff --git a/src/libtilt/interpolation/tests/test_interpolate_dft_3d.py b/src/libtilt/interpolation/tests/test_interpolate_dft_3d.py index aa412dd..49836a3 100644 --- a/src/libtilt/interpolation/tests/test_interpolate_dft_3d.py +++ b/src/libtilt/interpolation/tests/test_interpolate_dft_3d.py @@ -3,8 +3,10 @@ import torch.nn.functional as F from libtilt.interpolation.interpolate_dft_3d import sample_dft_3d +from libtilt.pytest_utils import device_test +@device_test def test_extract_slices(): volume = torch.zeros(4, 4, 4, dtype=torch.complex64) diff --git a/src/libtilt/patch_extraction/tests/test_patch_extraction_2d.py b/src/libtilt/patch_extraction/tests/test_patch_extraction_2d.py index 18c1766..598bc72 100644 --- a/src/libtilt/patch_extraction/tests/test_patch_extraction_2d.py +++ b/src/libtilt/patch_extraction/tests/test_patch_extraction_2d.py @@ -2,8 +2,10 @@ from libtilt.patch_extraction.subpixel_square_patch_extraction import extract_squares, \ _extract_square_patches_from_single_2d_image +from libtilt.pytest_utils import device_test +@device_test def test_single_square_patch_from_single_image(): """Test square patch extraction from single image.""" img = torch.zeros((28, 28)) @@ -17,6 +19,7 @@ def test_single_square_patch_from_single_image(): assert torch.allclose(patches, expected_image, atol=1e-6) +@device_test def test_extract_square_patches_single(): """Test extracting patches from a stack of images.""" img = torch.zeros((2, 28, 28)) @@ -34,6 +37,7 @@ def test_extract_square_patches_single(): assert torch.allclose(patches[0, 1], expected_image_1, atol=1e-6) +@device_test def test_extract_square_patches_batched(): """Test batched particle extraction from single image.""" img = torch.zeros((28, 28)) diff --git a/src/libtilt/patch_extraction/tests/test_patch_extraction_3d.py b/src/libtilt/patch_extraction/tests/test_patch_extraction_3d.py index faf951f..fa831c3 100644 --- a/src/libtilt/patch_extraction/tests/test_patch_extraction_3d.py +++ b/src/libtilt/patch_extraction/tests/test_patch_extraction_3d.py @@ -2,8 +2,10 @@ from libtilt.patch_extraction.subpixel_cubic_patch_extraction import extract_cubes, \ _extract_cubic_patches_from_single_3d_image +from libtilt.pytest_utils import device_test +@device_test def test_single_cubic_patch_from_single_image(): """Test cubic patch extraction from single 3D image.""" img = torch.zeros((28, 28, 28)) @@ -17,6 +19,7 @@ def test_single_cubic_patch_from_single_image(): assert torch.allclose(patches, expected_image, atol=1e-6) +@device_test def test_extract_cubic_patches(): """Test extracting cubic patches from a 3D image.""" img = torch.zeros((28, 28, 28)) diff --git a/src/libtilt/projection/tests/test_project_fourier.py b/src/libtilt/projection/tests/test_project_fourier.py index 144e56e..917b6b3 100644 --- a/src/libtilt/projection/tests/test_project_fourier.py +++ b/src/libtilt/projection/tests/test_project_fourier.py @@ -2,8 +2,10 @@ from scipy.spatial.transform import Rotation as R from libtilt.projection.project_fourier import project_fourier +from libtilt.pytest_utils import device_test +@device_test def test_project_no_rotation(): volume = torch.zeros((10, 10, 10)) volume[5, 5, 5] = 1 @@ -15,6 +17,7 @@ def test_project_no_rotation(): assert torch.allclose(projection, expected) +@device_test def test_project_with_rotation(): volume = torch.zeros((10, 10, 10)) volume[5, 5, 5] = 1 diff --git a/src/libtilt/projection/tests/test_project_real.py b/src/libtilt/projection/tests/test_project_real.py index 1bfe645..304809c 100644 --- a/src/libtilt/projection/tests/test_project_real.py +++ b/src/libtilt/projection/tests/test_project_real.py @@ -1,8 +1,10 @@ import torch from libtilt.projection.project_real import project_image_real, project_volume_real +from libtilt.pytest_utils import device_test +@device_test def test_real_space_projection_3d(): volume_shape = (2, 10, 10) volume = torch.arange(2*10*10).reshape(volume_shape).float() @@ -11,6 +13,7 @@ def test_real_space_projection_3d(): assert torch.allclose(projection.squeeze(), torch.sum(volume, dim=0)) +@device_test def test_real_space_projection_2d(): image_shape = (8, 12) image = torch.arange(8 * 12).reshape(image_shape).float() diff --git a/src/libtilt/pytest_utils.py b/src/libtilt/pytest_utils.py new file mode 100644 index 0000000..8a4c477 --- /dev/null +++ b/src/libtilt/pytest_utils.py @@ -0,0 +1,19 @@ +import functools + +import torch + +AVAILABLE_DEVICES = ["cpu"] +if torch.backends.mps.is_available(): + AVAILABLE_DEVICES.append("mps") +if torch.cuda.is_available(): + AVAILABLE_DEVICES.append("cuda") + + +def device_test(test_func): + @functools.wraps(test_func) + def run_devices(*args, **kwargs): + for device in AVAILABLE_DEVICES: + with torch.device(device): + test_func(*args, **kwargs) + + return run_devices diff --git a/src/libtilt/rotational_averaging/rotational_average_dft.py b/src/libtilt/rotational_averaging/rotational_average_dft.py index 1df5be4..1a6ea40 100644 --- a/src/libtilt/rotational_averaging/rotational_average_dft.py +++ b/src/libtilt/rotational_averaging/rotational_average_dft.py @@ -95,7 +95,8 @@ def _find_shell_indices_1d( """Find indices which index to give values either side of split points.""" sorted, sort_idx = torch.sort(values, descending=False) split_idx = torch.searchsorted(sorted, split_values) - return torch.tensor_split(sort_idx, split_idx) + # tensor_split requires the split_idx to live on cpu + return torch.tensor_split(sort_idx, split_idx.to('cpu')) def _find_shell_indices_2d( @@ -107,7 +108,8 @@ def _find_shell_indices_2d( idx_2d = einops.rearrange(idx_2d, 'h w idx -> (h w) idx') sorted, sort_idx = torch.sort(values, descending=False) split_idx = torch.searchsorted(sorted, split_values) - return torch.tensor_split(idx_2d[sort_idx], split_idx) + # tensor_split requires the split_idx to live on cpu + return torch.tensor_split(idx_2d[sort_idx], split_idx.to('cpu')) def _find_shell_indices_3d( @@ -119,7 +121,8 @@ def _find_shell_indices_3d( idx_3d = einops.rearrange(idx_3d, 'd h w idx -> (d h w) idx') sorted, sort_idx = torch.sort(values, descending=False) split_idx = torch.searchsorted(sorted, split_values) - return torch.tensor_split(idx_3d[sort_idx], split_idx) + # tensor_split requires the split_idx to live on cpu + return torch.tensor_split(idx_3d[sort_idx], split_idx.to('cpu')) def _split_into_frequency_bins_2d( diff --git a/src/libtilt/rotational_averaging/tests/test_rotational_average_dft.py b/src/libtilt/rotational_averaging/tests/test_rotational_average_dft.py index 28209e2..861be4d 100644 --- a/src/libtilt/rotational_averaging/tests/test_rotational_average_dft.py +++ b/src/libtilt/rotational_averaging/tests/test_rotational_average_dft.py @@ -8,8 +8,10 @@ rotational_average_dft_2d, rotational_average_dft_3d, ) +from libtilt.pytest_utils import device_test +@device_test def test_split_into_frequency_bins_2d(): # no rfft, fftshifted frequencies = fftfreq_grid( @@ -55,6 +57,7 @@ def test_split_into_frequency_bins_2d(): assert torch.allclose(shells[0], frequencies[:, 14, 0].reshape((2, 1))) +@device_test def test_split_into_shells_3d(): # no rfft, fftshifted frequencies = fftfreq_grid( @@ -100,6 +103,7 @@ def test_split_into_shells_3d(): assert torch.allclose(shells[0], frequencies[:, 14, 14, 0].reshape((2, 1))) +@device_test def test_rotational_average_dft_2d(): # single image dft = fftfreq_grid(image_shape=(28, 28), rfft=False, fftshift=True, norm=True) @@ -136,6 +140,7 @@ def test_rotational_average_dft_2d(): assert rotational_average.shape == expected_shape +@device_test def test_rotational_average_return_2d(): # no batching image = fftfreq_grid(image_shape=(28, 28), rfft=False, fftshift=True, norm=True) @@ -159,6 +164,7 @@ def test_rotational_average_return_2d(): atol=1e-2) +@device_test def test_rotational_average_dft_3d(): # single image dft = fftfreq_grid(image_shape=(28, 28, 28), rfft=False, fftshift=True, norm=True) diff --git a/src/libtilt/shapes/soft_edge.py b/src/libtilt/shapes/soft_edge.py index ef7d456..5b42063 100644 --- a/src/libtilt/shapes/soft_edge.py +++ b/src/libtilt/shapes/soft_edge.py @@ -8,7 +8,8 @@ def _add_soft_edge_single_binary_image( ) -> torch.FloatTensor: if smoothing_radius == 0: return image.float() - distances = ndi.distance_transform_edt(torch.logical_not(image)) + # move explicitly to cpu for scipy + distances = ndi.distance_transform_edt(torch.logical_not(image).to("cpu")) distances = torch.as_tensor(distances, device=image.device).float() idx = torch.logical_and(distances > 0, distances <= smoothing_radius) output = torch.clone(image).float() @@ -19,7 +20,7 @@ def _add_soft_edge_single_binary_image( def add_soft_edge_2d( image: torch.Tensor, smoothing_radius: torch.Tensor | float ) -> torch.Tensor: - image_packed, ps = einops.pack([image], '* h w') + image_packed, ps = einops.pack([image], "* h w") b = image_packed.shape[0] if isinstance(smoothing_radius, float | int): @@ -30,18 +31,17 @@ def add_soft_edge_2d( results = [ _add_soft_edge_single_binary_image(_image, smoothing_radius=_smoothing_radius) - for _image, _smoothing_radius - in zip(image_packed, smoothing_radius) + for _image, _smoothing_radius in zip(image_packed, smoothing_radius) ] results = torch.stack(results, dim=0) - [results] = einops.unpack(results, pattern='* h w', packed_shapes=ps) + [results] = einops.unpack(results, pattern="* h w", packed_shapes=ps) return results def add_soft_edge_3d( image: torch.Tensor, smoothing_radius: torch.Tensor | float ) -> torch.Tensor: - image_packed, ps = einops.pack([image], '* d h w') + image_packed, ps = einops.pack([image], "* d h w") b = image_packed.shape[0] if isinstance(smoothing_radius, float | int): smoothing_radius = torch.as_tensor( @@ -50,9 +50,8 @@ def add_soft_edge_3d( smoothing_radius = torch.broadcast_to(smoothing_radius, (b,)) results = [ _add_soft_edge_single_binary_image(_image, smoothing_radius=_smoothing_radius) - for _image, _smoothing_radius - in zip(image_packed, smoothing_radius) + for _image, _smoothing_radius in zip(image_packed, smoothing_radius) ] results = torch.stack(results, dim=0) - [results] = einops.unpack(results, pattern='* d h w', packed_shapes=ps) - return results \ No newline at end of file + [results] = einops.unpack(results, pattern="* d h w", packed_shapes=ps) + return results diff --git a/src/libtilt/shapes/tests/test_soft_edge.py b/src/libtilt/shapes/tests/test_soft_edge.py index 261d326..e105812 100644 --- a/src/libtilt/shapes/tests/test_soft_edge.py +++ b/src/libtilt/shapes/tests/test_soft_edge.py @@ -1,20 +1,28 @@ import torch -from libtilt.shapes.soft_edge import _add_soft_edge_single_binary_image, add_soft_edge_2d +from libtilt.pytest_utils import device_test +from libtilt.shapes.soft_edge import ( + _add_soft_edge_single_binary_image, + add_soft_edge_2d, +) +@device_test def test_add_soft_edge_single_binary_image(): dim_length = 5 smoothing_radius = 4 image = torch.zeros(size=(dim_length, 1)) image[0, 0] = 1 - smoothed = _add_soft_edge_single_binary_image(image, smoothing_radius=smoothing_radius) + smoothed = _add_soft_edge_single_binary_image( + image, smoothing_radius=smoothing_radius + ) # cosine falloff, 1 to zero over smooothing radius expected = torch.cos((torch.pi / 2) * torch.arange(5) / smoothing_radius) expected = expected.view((dim_length, 1)) assert torch.allclose(smoothed, expected) +@device_test def test_add_soft_edge_2d(): # single image image = torch.zeros(size=(5, 5)) @@ -29,4 +37,4 @@ def test_add_soft_edge_2d(): images[..., 0, 0] = 1 results = add_soft_edge_2d(images, smoothing_radius=smoothing_radius) assert results.shape == images.shape - assert torch.allclose(results[0], results[1]) \ No newline at end of file + assert torch.allclose(results[0], results[1]) diff --git a/src/libtilt/shift/tests/test_phase_shift_2d.py b/src/libtilt/shift/tests/test_phase_shift_2d.py index c5db0f2..8b5577c 100644 --- a/src/libtilt/shift/tests/test_phase_shift_2d.py +++ b/src/libtilt/shift/tests/test_phase_shift_2d.py @@ -1,9 +1,11 @@ import torch -from libtilt.shift.shift_image import shift_2d +from libtilt.pytest_utils import device_test from libtilt.shift.phase_shift_dft import get_phase_shifts_2d +from libtilt.shift.shift_image import shift_2d +@device_test def test_get_phase_shifts_2d_full_fft(): shifts = torch.zeros(size=(1, 2)) phase_shifts = get_phase_shifts_2d(shifts, image_shape=(2, 2), rfft=False) @@ -11,11 +13,13 @@ def test_get_phase_shifts_2d_full_fft(): shifts = torch.tensor([[1, 2]]) phase_shifts = get_phase_shifts_2d(shifts, image_shape=(2, 2), rfft=False) - expected = torch.tensor([[[1 + 0.0000e+00j, 1 + 1.7485e-07j], - [-1 - 8.7423e-08j, -1 - 2.3850e-08j]]]) + expected = torch.tensor( + [[[1 + 0.0000e00j, 1 + 1.7485e-07j], [-1 - 8.7423e-08j, -1 - 2.3850e-08j]]] + ) assert torch.allclose(phase_shifts, expected) +@device_test def test_get_phase_shifts_2d_rfft(): shifts = torch.zeros(size=(1, 2)) phase_shifts = get_phase_shifts_2d(shifts, image_shape=(2, 2), rfft=True) @@ -25,11 +29,13 @@ def test_get_phase_shifts_2d_rfft(): shifts = torch.tensor([[1, 2]]) phase_shifts = get_phase_shifts_2d(shifts, image_shape=(2, 2), rfft=False) - expected = torch.tensor([[[1 + 0.0000e+00j, 1 + 1.7485e-07j], - [-1 - 8.7423e-08j, -1 - 2.3850e-08j]]]) + expected = torch.tensor( + [[[1 + 0.0000e00j, 1 + 1.7485e-07j], [-1 - 8.7423e-08j, -1 - 2.3850e-08j]]] + ) assert torch.allclose(phase_shifts, expected) +@device_test def test_phase_shift_images_2d(): image = torch.zeros((4, 4)) image[2, 2] = 1