Skip to content

Commit

Permalink
Typing fix
Browse files Browse the repository at this point in the history
  • Loading branch information
georgeyiasemis committed Jun 5, 2024
1 parent c686711 commit c1dd9a5
Showing 1 changed file with 38 additions and 31 deletions.
69 changes: 38 additions & 31 deletions direct/data/transforms.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,26 @@
# coding=utf-8
# Copyright (c) DIRECT Contributors

# Code and comments can be shared with code of FastMRI under the same MIT license:
# https://github.com/facebookresearch/fastMRI/
# The code can have been adjusted to our needs.

from typing import Callable, List, Optional, Tuple, Union
"""Direct transforms module.
This module contains functions for complex-valued data manipulation in PyTorch. This includes functions for complex
multiplication, division, modulus, fft, ifft, fftshift, ifftshift, and more. The functions are designed to work with
complex-valued data where the last axis denotes the real and imaginary parts respectively. The functions are designed to
work with complex-valued data where the last axis denotes the real and imaginary parts respectively."""

from __future__ import annotations

from typing import Callable, Optional, Union

import numpy as np
import torch
import torch.fft
from numpy.typing import ArrayLike

from direct.data.bbox import crop_to_bbox
from direct.types import DirectEnum
from direct.utils import ensure_list, is_complex_data, is_power_of_two
from direct.utils.asserts import assert_complex, assert_same_shape

Expand All @@ -35,7 +42,7 @@ def to_tensor(data: np.ndarray) -> torch.Tensor:
return torch.from_numpy(data)


def verify_fft_dtype_possible(data: torch.Tensor, dims: Tuple[int, ...]) -> bool:
def verify_fft_dtype_possible(data: torch.Tensor, dims: tuple[int, ...]) -> bool:
"""fft and ifft can only be performed on GPU in float16 if the shapes are powers of 2. This function verifies if
this is the case.
Expand Down Expand Up @@ -98,7 +105,7 @@ def view_as_real(data):

def fft2(
data: torch.Tensor,
dim: Tuple[int, ...] = (1, 2),
dim: tuple[int, ...] = (1, 2),
centered: bool = True,
normalized: bool = True,
complex_input: bool = True,
Expand Down Expand Up @@ -159,7 +166,7 @@ def fft2(

def ifft2(
data: torch.Tensor,
dim: Tuple[int, ...] = (1, 2),
dim: tuple[int, ...] = (1, 2),
centered: bool = True,
normalized: bool = True,
complex_input: bool = True,
Expand Down Expand Up @@ -300,16 +307,16 @@ def roll_one_dim(data: torch.Tensor, shift: int, dim: int) -> torch.Tensor:

def roll(
data: torch.Tensor,
shift: List[int],
dim: Union[List[int], Tuple[int, ...]],
shift: list[int],
dim: Union[list[int], tuple[int, ...]],
) -> torch.Tensor:
"""Similar to numpy roll but applies to pytorch tensors.
Parameters
----------
data: torch.Tensor
shift: tuple, int
dim: List or tuple of ints
dim: list or tuple of ints
Returns
-------
Expand All @@ -325,14 +332,14 @@ def roll(
return data


def fftshift(data: torch.Tensor, dim: Union[List[int], Tuple[int, ...], None] = None) -> torch.Tensor:
def fftshift(data: torch.Tensor, dim: Union[list[int], tuple[int, ...], None] = None) -> torch.Tensor:
"""Similar to numpy fftshift but applies to pytorch tensors.
Parameters
----------
data: torch.Tensor
Input data.
dim: List or tuple of ints or None
dim: list or tuple of ints or None
Default: None.
Returns
Expand All @@ -353,14 +360,14 @@ def fftshift(data: torch.Tensor, dim: Union[List[int], Tuple[int, ...], None] =
return roll(data, shift, dim)


def ifftshift(data: torch.Tensor, dim: Union[List[int], Tuple[int, ...], None] = None) -> torch.Tensor:
def ifftshift(data: torch.Tensor, dim: Union[list[int], tuple[int, ...], None] = None) -> torch.Tensor:
"""Similar to numpy ifftshift but applies to pytorch tensors.
Parameters
----------
data: torch.Tensor
Input data.
dim: List or tuple of ints or None
dim: list or tuple of ints or None
Default: None.
Returns
Expand Down Expand Up @@ -413,7 +420,7 @@ def complex_multiplication(input_tensor: torch.Tensor, other_tensor: torch.Tenso
return multiplication


def complex_dot_product(a: torch.Tensor, b: torch.Tensor, dim: List[int]) -> torch.Tensor:
def complex_dot_product(a: torch.Tensor, b: torch.Tensor, dim: list[int]) -> torch.Tensor:
r"""Computes the dot product of the complex tensors :math:`a` and :math:`b`: :math:`a^{*}b = <a, b>`.
Parameters
Expand All @@ -422,7 +429,7 @@ def complex_dot_product(a: torch.Tensor, b: torch.Tensor, dim: List[int]) -> tor
Input :math:`a`.
b : torch.Tensor
Input :math:`b`.
dim : List[int]
dim : list[int]
Dimensions which will be suppressed. Useful when inputs are batched.
Returns
Expand Down Expand Up @@ -583,7 +590,7 @@ def apply_mask(
mask_func: Union[Callable, torch.Tensor],
seed: Optional[int] = None,
return_mask: bool = True,
) -> Union[Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
) -> Union[tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
"""Subsample kspace by setting kspace to zero as given by a binary mask.
Parameters
Expand Down Expand Up @@ -665,13 +672,13 @@ def root_sum_of_squares(data: torch.Tensor, dim: int = 0, complex_dim: int = -1)
return torch.sqrt((data**2).sum(dim))


def center_crop(data: torch.Tensor, shape: Union[List[int], Tuple[int, ...]]) -> torch.Tensor:
def center_crop(data: torch.Tensor, shape: Union[list[int], tuple[int, ...]]) -> torch.Tensor:
"""Apply a center crop along the last two dimensions.
Parameters
----------
data: torch.Tensor
shape: List or tuple of ints
shape: list or tuple of ints
The output shape, should be smaller than the corresponding data dimensions.
Returns
Expand All @@ -691,19 +698,19 @@ def center_crop(data: torch.Tensor, shape: Union[List[int], Tuple[int, ...]]) ->


def complex_center_crop(
data_list: Union[List[torch.Tensor], torch.Tensor],
crop_shape: Union[List[int], Tuple[int, ...]],
data_list: Union[list[torch.Tensor], torch.Tensor],
crop_shape: Union[list[int], tuple[int, ...]],
offset: int = 1,
contiguous: bool = False,
) -> Union[List[torch.Tensor], torch.Tensor]:
) -> Union[list[torch.Tensor], torch.Tensor]:
"""Apply a center crop to the input data, or to a list of complex images.
Parameters
----------
data_list: Union[List[torch.Tensor], torch.Tensor]
data_list: Union[list[torch.Tensor], torch.Tensor]
The complex input tensor to be center cropped. It should have at least 3 dimensions
and the cropping is applied along dimensions didx and didx+1 and the last dimensions should have a size of 2.
crop_shape: List[int] or Tuple[int, ...]
crop_shape: list[int] or tuple[int, ...]
The output shape. The shape should be smaller than the corresponding dimensions of data.
If one value is None, this is filled in by the image shape.
offset: int
Expand All @@ -713,7 +720,7 @@ def complex_center_crop(
Returns
-------
Union[List[torch.Tensor], torch.Tensor]
Union[list[torch.Tensor], torch.Tensor]
The center cropped input_image(s).
"""
data_list = ensure_list(data_list)
Expand Down Expand Up @@ -747,22 +754,22 @@ def complex_center_crop(


def complex_random_crop(
data_list: Union[List[torch.Tensor], torch.Tensor],
crop_shape: Union[List[int], Tuple[int, ...]],
data_list: Union[list[torch.Tensor], torch.Tensor],
crop_shape: Union[list[int], tuple[int, ...]],
offset: int = 1,
contiguous: bool = False,
sampler: str = "uniform",
sigma: Union[float, List[float], None] = None,
sigma: Union[float, list[float], None] = None,
seed: Union[None, int, ArrayLike] = None,
) -> Union[List[torch.Tensor], torch.Tensor]:
) -> Union[list[torch.Tensor], torch.Tensor]:
"""Apply a random crop to the input data tensor or a list of complex.
Parameters
----------
data_list: Union[List[torch.Tensor], torch.Tensor]
data_list: Union[list[torch.Tensor], torch.Tensor]
The complex input tensor to be center cropped. It should have at least 3 dimensions and the cropping is applied
along dimensions -3 and -2 and the last dimensions should have a size of 2.
crop_shape: List[int] or Tuple[int, ...]
crop_shape: list[int] or tuple[int, ...]
The output shape. The shape should be smaller than the corresponding dimensions of data.
offset: int
Starting dimension for cropping.
Expand All @@ -776,7 +783,7 @@ def complex_random_crop(
Returns
-------
Union[List[torch.Tensor], torch.Tensor]
Union[list[torch.Tensor], torch.Tensor]
The center cropped input tensor or list of tensors.
"""
if sampler == "uniform" and sigma is not None:
Expand Down

0 comments on commit c1dd9a5

Please sign in to comment.