diff --git a/direct/data/transforms.py b/direct/data/transforms.py index 746d5c59..26e81a9d 100644 --- a/direct/data/transforms.py +++ b/direct/data/transforms.py @@ -1,11 +1,19 @@ -# coding=utf-8 # Copyright (c) DIRECT Contributors # Code and comments can be shared with code of FastMRI under the same MIT license: # https://github.com/facebookresearch/fastMRI/ # The code can have been adjusted to our needs. -from typing import Callable, List, Optional, Tuple, Union +"""Direct transforms module. + +This module contains functions for complex-valued data manipulation in PyTorch. This includes functions for complex +multiplication, division, modulus, fft, ifft, fftshift, ifftshift, and more. The functions are designed to work with +complex-valued data where the last axis denotes the real and imaginary parts respectively. The functions are designed to +work with complex-valued data where the last axis denotes the real and imaginary parts respectively.""" + +from __future__ import annotations + +from typing import Callable, Optional, Union import numpy as np import torch @@ -13,7 +21,6 @@ from numpy.typing import ArrayLike from direct.data.bbox import crop_to_bbox -from direct.types import DirectEnum from direct.utils import ensure_list, is_complex_data, is_power_of_two from direct.utils.asserts import assert_complex, assert_same_shape @@ -35,7 +42,7 @@ def to_tensor(data: np.ndarray) -> torch.Tensor: return torch.from_numpy(data) -def verify_fft_dtype_possible(data: torch.Tensor, dims: Tuple[int, ...]) -> bool: +def verify_fft_dtype_possible(data: torch.Tensor, dims: tuple[int, ...]) -> bool: """fft and ifft can only be performed on GPU in float16 if the shapes are powers of 2. This function verifies if this is the case. @@ -98,7 +105,7 @@ def view_as_real(data): def fft2( data: torch.Tensor, - dim: Tuple[int, ...] = (1, 2), + dim: tuple[int, ...] = (1, 2), centered: bool = True, normalized: bool = True, complex_input: bool = True, @@ -159,7 +166,7 @@ def fft2( def ifft2( data: torch.Tensor, - dim: Tuple[int, ...] = (1, 2), + dim: tuple[int, ...] = (1, 2), centered: bool = True, normalized: bool = True, complex_input: bool = True, @@ -300,8 +307,8 @@ def roll_one_dim(data: torch.Tensor, shift: int, dim: int) -> torch.Tensor: def roll( data: torch.Tensor, - shift: List[int], - dim: Union[List[int], Tuple[int, ...]], + shift: list[int], + dim: Union[list[int], tuple[int, ...]], ) -> torch.Tensor: """Similar to numpy roll but applies to pytorch tensors. @@ -309,7 +316,7 @@ def roll( ---------- data: torch.Tensor shift: tuple, int - dim: List or tuple of ints + dim: list or tuple of ints Returns ------- @@ -325,14 +332,14 @@ def roll( return data -def fftshift(data: torch.Tensor, dim: Union[List[int], Tuple[int, ...], None] = None) -> torch.Tensor: +def fftshift(data: torch.Tensor, dim: Union[list[int], tuple[int, ...], None] = None) -> torch.Tensor: """Similar to numpy fftshift but applies to pytorch tensors. Parameters ---------- data: torch.Tensor Input data. - dim: List or tuple of ints or None + dim: list or tuple of ints or None Default: None. Returns @@ -353,14 +360,14 @@ def fftshift(data: torch.Tensor, dim: Union[List[int], Tuple[int, ...], None] = return roll(data, shift, dim) -def ifftshift(data: torch.Tensor, dim: Union[List[int], Tuple[int, ...], None] = None) -> torch.Tensor: +def ifftshift(data: torch.Tensor, dim: Union[list[int], tuple[int, ...], None] = None) -> torch.Tensor: """Similar to numpy ifftshift but applies to pytorch tensors. Parameters ---------- data: torch.Tensor Input data. - dim: List or tuple of ints or None + dim: list or tuple of ints or None Default: None. Returns @@ -413,7 +420,7 @@ def complex_multiplication(input_tensor: torch.Tensor, other_tensor: torch.Tenso return multiplication -def complex_dot_product(a: torch.Tensor, b: torch.Tensor, dim: List[int]) -> torch.Tensor: +def complex_dot_product(a: torch.Tensor, b: torch.Tensor, dim: list[int]) -> torch.Tensor: r"""Computes the dot product of the complex tensors :math:`a` and :math:`b`: :math:`a^{*}b = `. Parameters @@ -422,7 +429,7 @@ def complex_dot_product(a: torch.Tensor, b: torch.Tensor, dim: List[int]) -> tor Input :math:`a`. b : torch.Tensor Input :math:`b`. - dim : List[int] + dim : list[int] Dimensions which will be suppressed. Useful when inputs are batched. Returns @@ -583,7 +590,7 @@ def apply_mask( mask_func: Union[Callable, torch.Tensor], seed: Optional[int] = None, return_mask: bool = True, -) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]: +) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]: """Subsample kspace by setting kspace to zero as given by a binary mask. Parameters @@ -665,13 +672,13 @@ def root_sum_of_squares(data: torch.Tensor, dim: int = 0, complex_dim: int = -1) return torch.sqrt((data**2).sum(dim)) -def center_crop(data: torch.Tensor, shape: Union[List[int], Tuple[int, ...]]) -> torch.Tensor: +def center_crop(data: torch.Tensor, shape: Union[list[int], tuple[int, ...]]) -> torch.Tensor: """Apply a center crop along the last two dimensions. Parameters ---------- data: torch.Tensor - shape: List or tuple of ints + shape: list or tuple of ints The output shape, should be smaller than the corresponding data dimensions. Returns @@ -691,19 +698,19 @@ def center_crop(data: torch.Tensor, shape: Union[List[int], Tuple[int, ...]]) -> def complex_center_crop( - data_list: Union[List[torch.Tensor], torch.Tensor], - crop_shape: Union[List[int], Tuple[int, ...]], + data_list: Union[list[torch.Tensor], torch.Tensor], + crop_shape: Union[list[int], tuple[int, ...]], offset: int = 1, contiguous: bool = False, -) -> Union[List[torch.Tensor], torch.Tensor]: +) -> Union[list[torch.Tensor], torch.Tensor]: """Apply a center crop to the input data, or to a list of complex images. Parameters ---------- - data_list: Union[List[torch.Tensor], torch.Tensor] + data_list: Union[list[torch.Tensor], torch.Tensor] The complex input tensor to be center cropped. It should have at least 3 dimensions and the cropping is applied along dimensions didx and didx+1 and the last dimensions should have a size of 2. - crop_shape: List[int] or Tuple[int, ...] + crop_shape: list[int] or tuple[int, ...] The output shape. The shape should be smaller than the corresponding dimensions of data. If one value is None, this is filled in by the image shape. offset: int @@ -713,7 +720,7 @@ def complex_center_crop( Returns ------- - Union[List[torch.Tensor], torch.Tensor] + Union[list[torch.Tensor], torch.Tensor] The center cropped input_image(s). """ data_list = ensure_list(data_list) @@ -747,22 +754,22 @@ def complex_center_crop( def complex_random_crop( - data_list: Union[List[torch.Tensor], torch.Tensor], - crop_shape: Union[List[int], Tuple[int, ...]], + data_list: Union[list[torch.Tensor], torch.Tensor], + crop_shape: Union[list[int], tuple[int, ...]], offset: int = 1, contiguous: bool = False, sampler: str = "uniform", - sigma: Union[float, List[float], None] = None, + sigma: Union[float, list[float], None] = None, seed: Union[None, int, ArrayLike] = None, -) -> Union[List[torch.Tensor], torch.Tensor]: +) -> Union[list[torch.Tensor], torch.Tensor]: """Apply a random crop to the input data tensor or a list of complex. Parameters ---------- - data_list: Union[List[torch.Tensor], torch.Tensor] + data_list: Union[list[torch.Tensor], torch.Tensor] The complex input tensor to be center cropped. It should have at least 3 dimensions and the cropping is applied along dimensions -3 and -2 and the last dimensions should have a size of 2. - crop_shape: List[int] or Tuple[int, ...] + crop_shape: list[int] or tuple[int, ...] The output shape. The shape should be smaller than the corresponding dimensions of data. offset: int Starting dimension for cropping. @@ -776,7 +783,7 @@ def complex_random_crop( Returns ------- - Union[List[torch.Tensor], torch.Tensor] + Union[list[torch.Tensor], torch.Tensor] The center cropped input tensor or list of tensors. """ if sampler == "uniform" and sigma is not None: