diff --git a/references/classification/transforms.py b/references/classification/transforms.py
index 5443437d29d..96236608eec 100644
--- a/references/classification/transforms.py
+++ b/references/classification/transforms.py
@@ -19,9 +19,9 @@ def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_classes, use_v2):
         )
     if cutmix_alpha > 0:
         mixup_cutmix.append(
-            transforms_module.CutMix(alpha=mixup_alpha, num_classes=num_classes)
+            transforms_module.CutMix(alpha=cutmix_alpha, num_classes=num_classes)
             if use_v2
-            else RandomCutMix(num_classes=num_classes, p=1.0, alpha=mixup_alpha)
+            else RandomCutMix(num_classes=num_classes, p=1.0, alpha=cutmix_alpha)
         )
     if not mixup_cutmix:
         return None
diff --git a/test/test_ops.py b/test/test_ops.py
index 52a66f380a6..99b259f73f5 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -14,6 +14,7 @@
 from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
 from PIL import Image
 from torch import nn, Tensor
+from torch._dynamo.utils import is_compile_supported
 from torch.autograd import gradcheck
 from torch.nn.modules.utils import _pair
 from torchvision import models, ops
@@ -529,6 +530,10 @@ def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
     def test_backward(self, seed, device, contiguous, deterministic):
         if deterministic and device == "cpu":
             pytest.skip("cpu is always deterministic, don't retest")
+        if deterministic and device == "mps":
+            pytest.skip("no deterministic implementation for mps")
+        if deterministic and not is_compile_supported(device):
+            pytest.skip("deterministic implementation only if torch.compile supported")
         super().test_backward(seed, device, contiguous, deterministic)
 
     def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
diff --git a/torchvision/ops/roi_align.py b/torchvision/ops/roi_align.py
index 0d505c140ee..ac1ec8b429a 100644
--- a/torchvision/ops/roi_align.py
+++ b/torchvision/ops/roi_align.py
@@ -1,9 +1,11 @@
+import functools
 from typing import List, Union
 
 import torch
 import torch._dynamo
 import torch.fx
 from torch import nn, Tensor
+from torch._dynamo.utils import is_compile_supported
 from torch.jit.annotations import BroadcastingList2
 from torch.nn.modules.utils import _pair
 from torchvision.extension import _assert_has_ops, _has_ops
@@ -12,6 +14,24 @@
 from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format
 
 
+def lazy_compile(**compile_kwargs):
+    """Lazily wrap a function with torch.compile on the first call
+
+    This avoids eagerly importing dynamo.
+    """
+
+    def decorate_fn(fn):
+        @functools.wraps(fn)
+        def compile_hook(*args, **kwargs):
+            compiled_fn = torch.compile(fn, **compile_kwargs)
+            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
+            return compiled_fn(*args, **kwargs)
+
+        return compile_hook
+
+    return decorate_fn
+
+
 # NB: all inputs are tensors
 def _bilinear_interpolate(
     input,  # [N, C, H, W]
@@ -86,15 +106,13 @@ def maybe_cast(tensor):
         return tensor
 
 
-# This is a slow but pure Python and differentiable implementation of
-# roi_align. It potentially is a good basis for Inductor compilation
-# (but I have not benchmarked it) but today it is solely used for the
-# fact that its backwards can be implemented deterministically,
-# which is needed for the PT2 benchmark suite.
-#
+# This is a pure Python and differentiable implementation of roi_align. When
+# run in eager mode, it uses a lot of memory, but when compiled it has
+# acceptable memory usage. The main point of this implementation is that
+# its backwards is deterministic.
 # It is transcribed directly off of the roi_align CUDA kernel, see
 # https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
-@torch._dynamo.allow_in_graph
+@lazy_compile(dynamic=True)
 def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
     orig_dtype = input.dtype
 
@@ -232,7 +250,9 @@ def roi_align(
     if not isinstance(rois, torch.Tensor):
         rois = convert_boxes_to_roi_format(rois)
     if not torch.jit.is_scripting():
-        if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
+        if (
+            not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
+        ) and is_compile_supported(input.device.type):
             return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
     _assert_has_ops()
     return torch.ops.torchvision.roi_align(
diff --git a/torchvision/transforms/_functional_tensor.py b/torchvision/transforms/_functional_tensor.py
index 88dc9ca21cc..348f01bb1e6 100644
--- a/torchvision/transforms/_functional_tensor.py
+++ b/torchvision/transforms/_functional_tensor.py
@@ -722,10 +722,10 @@ def perspective(
     return _apply_grid_transform(img, grid, interpolation, fill=fill)
 
 
-def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
+def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
     ksize_half = (kernel_size - 1) * 0.5
 
-    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
+    x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, dtype=dtype, device=device)
     pdf = torch.exp(-0.5 * (x / sigma).pow(2))
     kernel1d = pdf / pdf.sum()
 
@@ -735,8 +735,8 @@ def _get_gaussian_kernel2d(
     kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
 ) -> Tensor:
-    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
-    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
+    kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
+    kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
     kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
     return kernel2d
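
Note on the lazy_compile helper added in roi_align.py: it defers torch.compile until the decorated function is first invoked, then swaps the compiled artifact into the module's globals so later calls bypass the hook entirely. A minimal, self-contained sketch of that behavior (the "double" function is a made-up example, not part of the patch):

import functools

import torch


def lazy_compile(**compile_kwargs):
    # Same decorator as in the patch: compile on first call, not at import.
    def decorate_fn(fn):
        @functools.wraps(fn)
        def compile_hook(*args, **kwargs):
            compiled_fn = torch.compile(fn, **compile_kwargs)
            # Replace the module-level name so subsequent calls go straight
            # to the compiled function (this works here because the decorator
            # and the decorated function live in the same module).
            globals()[fn.__name__] = functools.wraps(fn)(compiled_fn)
            return compiled_fn(*args, **kwargs)

        return compile_hook

    return decorate_fn


@lazy_compile(dynamic=True)
def double(x):
    return x * 2


print(double(torch.ones(3)))  # first call: triggers torch.compile
print(double(torch.ones(5)))  # later calls: already-compiled function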
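
Note on the roi_align dispatch change: with deterministic algorithms enabled, a CUDA or MPS input is routed to the compiled pure-Python _roi_align only when torch.compile is supported for that device type; otherwise the C++/CUDA kernel runs as before. A usage sketch of the deterministic path (not part of the patch):

import torch
from torchvision.ops import roi_align

# Requesting deterministic algorithms makes roi_align on CUDA fall back to
# the compiled pure-Python implementation, whose backward is deterministic.
torch.use_deterministic_algorithms(True)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.rand(1, 3, 32, 32, device=device, requires_grad=True)
# One box per row: (batch_index, x1, y1, x2, y2)
boxes = torch.tensor([[0.0, 4.0, 4.0, 24.0, 24.0]], device=device)

out = roi_align(x, boxes, output_size=(7, 7), spatial_scale=1.0, sampling_ratio=2, aligned=True)
out.sum().backward()  # deterministic backward on CUDA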
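
Note on the gaussian kernel change: _get_gaussian_kernel1d now builds the kernel directly with the caller's dtype and device instead of creating it in the default dtype on CPU and converting afterwards, so a float64 (or CUDA) kernel no longer passes through a float32 CPU intermediate. An illustrative comparison of the old and new construction (assumed float64 input; not part of the patch):

import torch

ksize, sigma = 5, 1.5
half = (ksize - 1) * 0.5

# Old behavior: kernel computed in the default dtype (float32), then cast.
x32 = torch.linspace(-half, half, steps=ksize)
k32 = torch.exp(-0.5 * (x32 / sigma).pow(2))
k32 = (k32 / k32.sum()).to(torch.float64)

# New behavior: kernel computed directly in the requested dtype.
x64 = torch.linspace(-half, half, steps=ksize, dtype=torch.float64)
k64 = torch.exp(-0.5 * (x64 / sigma).pow(2))
k64 = k64 / k64.sum()

print((k32 - k64).abs().max())  # small but nonzero float32 round-off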