diff --git a/direct/data/transforms.py b/direct/data/transforms.py index 22fa6b08..7aca292e 100644 --- a/direct/data/transforms.py +++ b/direct/data/transforms.py @@ -42,14 +42,14 @@ def to_tensor(data: np.ndarray) -> torch.Tensor: return torch.from_numpy(data) -def verify_fft_dtype_possible(data: torch.Tensor, dims: tuple[int, ...]) -> bool: +def verify_fft_dtype_possible(data: torch.Tensor, dims: tuple[int, int] | tuple[int, int, int]) -> bool: """fft and ifft can only be performed on GPU in float16 if the shapes are powers of 2. This function verifies if this is the case. Parameters ---------- data: torch.Tensor - dims: tuple + dims: tuple of two or three ints Returns -------