Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test decorator for cpu and cuda #75

Merged
merged 22 commits into from
Jun 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,42 +7,42 @@ default_install_hook_types: [pre-commit, commit-msg]

repos:
- repo: https://github.com/compilerla/conventional-pre-commit
rev: v1.3.0
rev: v3.2.0
hooks:
- id: conventional-pre-commit
stages: [commit-msg]

- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.3.0
rev: v4.6.0
hooks:
- id: check-docstring-first
- id: end-of-file-fixer
- id: trailing-whitespace

- repo: https://github.com/myint/autoflake
rev: v1.4
rev: v2.3.1
hooks:
- id: autoflake
args: ["--in-place", "--remove-all-unused-imports"]

- repo: https://github.com/PyCQA/isort
rev: 5.10.1
rev: 5.13.2
hooks:
- id: isort

- repo: https://github.com/asottile/pyupgrade
rev: v2.34.0
rev: v3.16.0
hooks:
- id: pyupgrade
args: [--py38-plus, --keep-runtime-typing]

- repo: https://github.com/psf/black
rev: 22.3.0
rev: 24.4.2
hooks:
- id: black

- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
rev: 7.1.0
hooks:
- id: flake8
additional_dependencies:
Expand All @@ -52,7 +52,7 @@ repos:
- flake8-typing-imports

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.961
rev: v1.10.0
hooks:
- id: mypy
files: "^source_spacing/"
28 changes: 28 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,31 @@
[![codecov](https://codecov.io/gh/alisterburt/libtilt/branch/main/graph/badge.svg)](https://codecov.io/gh/alisterburt/libtilt)

Image processing for cryo-electron tomography in PyTorch


## For developers

We advise forking the repository and making a local clone of your fork. After setting
up an environment (e.g. via miniconda), the development version can be installed with
(`-e` performs an editable install):

```commandline
python -m pip install -e '.[dev,test]'
```

Then set up the pre-commit hooks for automated code checks and styling:

```commandline
pre-commit install
```

Before opening a pull request, please make sure all the unit tests pass:

```commandline
python -m pytest
```

The libtilt `device_test` decorator should be added to each test; it attempts to run
the test code on a GPU, either via CUDA or via the 'mps' backend on Apple silicon (M1)
chips. Please keep in mind that if no GPU is available on your system, the code is not
explicitly tested on GPU.
25 changes: 17 additions & 8 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,16 +94,25 @@ src_paths = ["src/libtilt", "tests"]
[tool.flake8]
exclude = "docs,.eggs,examples,_version.py"
max-line-length = 88
ignore = "E203"
ignore = ["E203",]
min-python-version = "3.8.0"
docstring-convention = "all" # use numpy convention, while allowing D417
extend-ignore = """
E203 # whitespace before ':'
D107,D203,D212,D213,D402,D413,D415,D416 # numpy
D100 # missing docstring in public module
D401 # imperative mood
W503 # line break before binary operator
"""
extend-ignore = [
"E203", # whitespace before ':'
"D107",
"D203",
"D212",
"D213",
"D402",
"D413",
"D415",
"D416", # numpy,
"E203",
"D100", # missing docstring in public module
"D401", # imperative mood
"W503", # line break before binary operator
"D103", # skip missin docstrings TODO update docstrings and remove
]
per-file-ignores = [
"tests/*:D",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
insert_into_dft_3d
from libtilt.grids import rotated_central_slice_grid
from libtilt.fft_utils import fftfreq_to_dft_coordinates
from libtilt.pytest_utils import device_test


@device_test
def test_fourier_slice_extraction_insertion_cycle_no_rotation():
sidelength = 64
volume = torch.zeros((sidelength, sidelength, sidelength), dtype=torch.complex64)
Expand Down Expand Up @@ -36,6 +38,7 @@ def test_fourier_slice_extraction_insertion_cycle_no_rotation():
assert torch.allclose(input_slice, output_slice)


@device_test
def test_fourier_slice_extraction_insertion_cycle_with_rotation():
sidelength = 64

Expand Down
7 changes: 7 additions & 0 deletions src/libtilt/_tests/test_coordinate_tils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
homogenise_coordinates,
)
from libtilt.grids.coordinate_grid import coordinate_grid
from libtilt.pytest_utils import device_test


@device_test
def test_array_coordinates_to_grid_sample_coordinates_nd():
array_shape = z, y, x = (4, 8, 12)
array_coordinates = einops.rearrange(torch.tensor(np.indices(array_shape)),
Expand All @@ -27,6 +29,7 @@ def test_array_coordinates_to_grid_sample_coordinates_nd():
assert torch.allclose(grid_sample_coordinates[:, 0, 0, 2], expected_z)


@device_test
def test_grid_sample_coordinates_to_array_coordinates_nd():
array_shape = (4, 8, 12)
expected_array_coordinates = einops.rearrange(
Expand All @@ -41,6 +44,7 @@ def test_grid_sample_coordinates_to_array_coordinates_nd():
assert torch.allclose(array_coordinates, expected_array_coordinates)


@device_test
def test_add_implied_coordinate_from_dimension():
batch_of_stacked_2d_coords = torch.zeros(size=(1, 5, 2)) # (b, stack, 2)
result = add_positional_coordinate(batch_of_stacked_2d_coords, dim=1)
Expand All @@ -49,6 +53,7 @@ def test_add_implied_coordinate_from_dimension():
assert torch.allclose(result, expected)


@device_test
def test_add_implied_coordinate_from_dimension_prepend():
batch_of_stacked_2d_coords = torch.zeros(size=(1, 5, 2)) # (b, stack, 2)
result = add_positional_coordinate(batch_of_stacked_2d_coords, dim=1,
Expand All @@ -58,6 +63,7 @@ def test_add_implied_coordinate_from_dimension_prepend():
assert torch.allclose(result, expected)


@device_test
def test_get_grid_coordinates():
coords = coordinate_grid(image_shape=(3, 2))
assert coords.shape == (3, 2, 2)
Expand All @@ -74,6 +80,7 @@ def test_get_grid_coordinates():
assert torch.allclose(coords, expected)


@device_test
def test_homogenise_coordinates():
coords = torch.rand(size=(2, 3))
homogenised = homogenise_coordinates(coords)
Expand Down
41 changes: 31 additions & 10 deletions src/libtilt/_tests/test_fft_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,14 @@
fftfreq_to_dft_coordinates,
)
from libtilt.grids.fftfreq_grid import _construct_fftfreq_grid_2d
from libtilt.pytest_utils import device_test


# seed the random number generator to ensure tests are consistent
torch.manual_seed(0)


@device_test
def test_rfft_shape_from_signal_shape():
# even
signal_shape = (2, 4, 8, 16)
Expand All @@ -31,6 +37,7 @@ def test_rfft_shape_from_signal_shape():
assert rfft.shape == rfft_shape(signal_shape)


@device_test
def test_construct_fftfreq_grid_2d():
image_shape = (10, 30)
# no rfft
Expand All @@ -46,6 +53,7 @@ def test_construct_fftfreq_grid_2d():
assert torch.allclose(grid[0, :, 1], torch.fft.rfftfreq(30))


@device_test
def test_rfft_to_symmetrised_dft_2d():
image = torch.rand((10, 10))
fft = torch.fft.fftshift(torch.fft.fftn(image, dim=(-2, -1)), dim=(-2, -1))
Expand All @@ -54,6 +62,7 @@ def test_rfft_to_symmetrised_dft_2d():
assert torch.allclose(fft, symmetrised_dft[:-1, :-1], atol=1e-7)


@device_test
def test_rfft_to_symmetrised_dft_2d_batched():
image = torch.rand((2, 10, 10)) # (b, h, w)
fft = torch.fft.fftshift(torch.fft.fftn(image, dim=(-2, -1)), dim=(-2, -1))
Expand All @@ -62,20 +71,21 @@ def test_rfft_to_symmetrised_dft_2d_batched():
assert torch.allclose(fft, symmetrised_dft[..., :-1, :-1], atol=1e-7)


@device_test
def test_rfft_to_symmetrised_dft_3d():
image = torch.rand((10, 10, 10))
fft_dims = (-3, -2, -1)
fft = torch.fft.fftshift(torch.fft.fftn(image, dim=fft_dims), dim=fft_dims)
rfft = torch.fft.rfftn(image, dim=(-3, -2, -1))
symmetrised_dft = _rfft_to_symmetrised_dft_3d(rfft)
np.array(fft - symmetrised_dft[:-1, :-1, :-1])
assert torch.allclose(fft, symmetrised_dft[:-1, :-1, :-1], atol=1e-5)


@pytest.mark.parametrize(
"inplace",
[(True,), (False,)]
)
@device_test
def test_symmetrised_dft_to_dft_2d(inplace: bool):
image = torch.rand((10, 10))
rfft = torch.fft.rfftn(image, dim=(-2, -1))
Expand All @@ -90,6 +100,7 @@ def test_symmetrised_dft_to_dft_2d(inplace: bool):
"inplace",
[(True,), (False,)]
)
@device_test
def test_symmetrised_dft_to_dft_2d_batched(inplace: bool):
image = torch.rand((2, 10, 10))
rfft = torch.fft.rfftn(image, dim=(-2, -1))
Expand All @@ -104,6 +115,7 @@ def test_symmetrised_dft_to_dft_2d_batched(inplace: bool):
"inplace",
[(True,), (False,)]
)
@device_test
def test_symmetrised_dft_to_rfft_2d(inplace: bool):
image = torch.rand((10, 10))
rfft = torch.fft.rfftn(image, dim=(-2, -1))
Expand All @@ -117,6 +129,7 @@ def test_symmetrised_dft_to_rfft_2d(inplace: bool):
"inplace",
[(True,), (False,)]
)
@device_test
def test_symmetrised_dft_to_dft_2d_batched(inplace: bool):
image = torch.rand((2, 10, 10))
rfft = torch.fft.rfftn(image, dim=(-2, -1))
Expand All @@ -131,6 +144,7 @@ def test_symmetrised_dft_to_dft_2d_batched(inplace: bool):
"inplace",
[(True,), (False,)]
)
@device_test
def test_symmetrised_dft_to_dft_3d(inplace: bool):
image = torch.rand((10, 10, 10))
rfft = torch.fft.rfftn(image, dim=(-3, -2, -1))
Expand All @@ -147,6 +161,7 @@ def test_symmetrised_dft_to_dft_3d(inplace: bool):
"inplace",
[(True,), (False,)]
)
@device_test
def test_symmetrised_dft_to_dft_3d_batched(inplace: bool):
image = torch.rand((2, 10, 10, 10))
rfft = torch.fft.rfftn(image, dim=(-3, -2, -1))
Expand All @@ -161,21 +176,25 @@ def test_symmetrised_dft_to_dft_3d_batched(inplace: bool):
@pytest.mark.parametrize(
"fftshifted, rfft, input, expected",
[
(False, False, (5, 5, 5), torch.tensor([0., 0., 0.])),
(False, True, (5, 5, 5), torch.tensor([0., 0., 0.])),
(True, False, (5, 5, 5), torch.tensor([2., 2., 2.])),
(True, True, (5, 5, 5), torch.tensor([2., 2., 0.])),
(False, False, (4, 4, 4), torch.tensor([0., 0., 0.])),
(False, True, (4, 4, 4), torch.tensor([0., 0., 0.])),
(True, False, (4, 4, 4), torch.tensor([2., 2., 2.])),
(True, True, (4, 4, 4), torch.tensor([2., 2., 0.])),
(False, False, (5, 5, 5), [0., 0., 0.]),
(False, True, (5, 5, 5), [0., 0., 0.]),
(True, False, (5, 5, 5), [2., 2., 2.]),
(True, True, (5, 5, 5), [2., 2., 0.]),
(False, False, (4, 4, 4), [0., 0., 0.]),
(False, True, (4, 4, 4), [0., 0., 0.]),
(True, False, (4, 4, 4), [2., 2., 2.]),
(True, True, (4, 4, 4), [2., 2., 0.]),
],
)
@device_test
def test_fft_center(fftshifted, rfft, input, expected):
result = dft_center(input, fftshifted=fftshifted, rfft=rfft)
assert torch.allclose(result, expected.long())
assert torch.allclose(result, torch.tensor(
expected, device=result.device, dtype=result.dtype
))


@device_test
def test_fftfreq_to_spatial_frequency():
fftfreq = torch.fft.fftfreq(10)
k = fftfreq_to_spatial_frequency(fftfreq, spacing=2)
Expand All @@ -187,6 +206,7 @@ def test_fftfreq_to_spatial_frequency():
assert torch.allclose(k, expected)


@device_test
def test_spatial_frequency_to_fftfreq():
k = torch.fft.fftfreq(10, d=2)
fftfreq = spatial_frequency_to_fftfreq(k, spacing=2)
Expand All @@ -199,6 +219,7 @@ def test_spatial_frequency_to_fftfreq():
assert torch.allclose(fftfreq, expected)


@device_test
def test_fftfreq_to_dft_coords():
from libtilt.grids import fftfreq_grid, coordinate_grid

Expand Down
5 changes: 5 additions & 0 deletions src/libtilt/_tests/test_transformations.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import torch

from libtilt.transformations import Rx, Ry, Rz, T
from libtilt.pytest_utils import device_test


@device_test
def test_rotation_around_x():
"""Rotation of y around x should become z."""
R = Rx(90)
Expand All @@ -16,6 +18,7 @@ def test_rotation_around_x():
assert torch.allclose(R @ v, expected, atol=1e-6)


@device_test
def test_rotation_around_y():
"""Rotation of z around y should be x"""
R = Ry(90)
Expand All @@ -29,6 +32,7 @@ def test_rotation_around_y():
assert torch.allclose(R @ v, expected, atol=1e-6)


@device_test
def test_rotation_around_z():
"""Rotation of x around z should give y."""
R = Rz(90)
Expand All @@ -42,6 +46,7 @@ def test_rotation_around_z():
assert torch.allclose(R @ v, expected, atol=1e-6)


@device_test
def test_translation():
"""Translations"""
M = T([1, 1, 1])
Expand Down
Loading