From 337fa343c40a97cb76c266243035a5877ce7e955 Mon Sep 17 00:00:00 2001 From: haarisr <122410226+haarisr@users.noreply.github.com> Date: Tue, 28 May 2024 03:12:25 -0700 Subject: [PATCH 1/3] Allow K=1 in `draw_keypoints` (#8439) Co-authored-by: Nicolas Hug --- test/test_utils.py | 7 +++++++ torchvision/utils.py | 8 ++++---- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/test/test_utils.py b/test/test_utils.py index ac394b51d63..e89bef4a6d9 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -355,6 +355,13 @@ def test_draw_keypoints_vanilla(): assert_equal(img, img_cp) +def test_draw_keypoins_K_equals_one(): + # Non-regression test for https://github.com/pytorch/vision/pull/8439 + img = torch.full((3, 100, 100), 0, dtype=torch.uint8) + keypoints = torch.tensor([[[10, 10]]], dtype=torch.float) + utils.draw_keypoints(img, keypoints) + + @pytest.mark.parametrize("colors", ["red", "#FF00FF", (1, 34, 122)]) def test_draw_keypoints_colored(colors): # Keypoints is declared on top as global variable diff --git a/torchvision/utils.py b/torchvision/utils.py index 94b3ec65c87..6b2d19ec3dd 100644 --- a/torchvision/utils.py +++ b/torchvision/utils.py @@ -392,10 +392,10 @@ def draw_keypoints( # validate visibility if visibility is None: # set default visibility = torch.ones(keypoints.shape[:-1], dtype=torch.bool) - # If the last dimension is 1, e.g., after calling split([2, 1], dim=-1) on the output of a keypoint-prediction - # model, make sure visibility has shape (num_instances, K). - # Iff K = 1, this has unwanted behavior, but K=1 does not really make sense in the first place. - visibility = visibility.squeeze(-1) + if visibility.ndim == 3: + # If visibility was passed as pred.split([2, 1], dim=-1), it will be of shape (num_instances, K, 1). + # We make sure it is of shape (num_instances, K). This isn't documented, we're just being nice. + visibility = visibility.squeeze(-1) if visibility.ndim != 2: raise ValueError(f"visibility must be of shape (num_instances, K). Got ndim={visibility.ndim}") if visibility.shape != keypoints.shape[:-1]: From 778ce48b27cb35390bcb5c5971a8dc9e6aba0d75 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Tue, 28 May 2024 13:36:47 +0100 Subject: [PATCH 2/3] Remove x86 MacOS jobs and use M1 instead (#8446) --- .github/scripts/setup-env.sh | 11 ----------- .github/workflows/build-cmake.yml | 1 - .github/workflows/tests.yml | 7 +------ 3 files changed, 1 insertion(+), 18 deletions(-) diff --git a/.github/scripts/setup-env.sh b/.github/scripts/setup-env.sh index a4f113c367f..26a607558d3 100755 --- a/.github/scripts/setup-env.sh +++ b/.github/scripts/setup-env.sh @@ -22,17 +22,6 @@ case $(uname) in ;; esac -if [[ "${OS_TYPE}" == "macos" && $(uname -m) == x86_64 ]]; then - echo '::group::Uninstall system JPEG libraries on macOS' - # The x86 macOS runners, e.g. the GitHub Actions native "macos-12" runner, has some JPEG and PNG libraries - # installed by default that interfere with our build. We uninstall them here and use the one from conda below. - IMAGE_LIBS=$(brew list | grep -E "jpeg|png") - for lib in $IMAGE_LIBS; do - brew uninstall --ignore-dependencies --force "${lib}" - done - echo '::endgroup::' -fi - echo '::group::Create build environment' # See https://github.com/pytorch/vision/issues/7296 for ffmpeg conda create \ diff --git a/.github/workflows/build-cmake.yml b/.github/workflows/build-cmake.yml index 107583235ed..1dce7b8446a 100644 --- a/.github/workflows/build-cmake.yml +++ b/.github/workflows/build-cmake.yml @@ -40,7 +40,6 @@ jobs: strategy: matrix: include: - - runner: macos-12 - runner: macos-m1-stable fail-fast: false uses: pytorch/test-infra/.github/workflows/macos_job.yml@main diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 9cfc6be9d5e..ad327129912 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -53,16 +53,11 @@ jobs: - "3.10" - "3.11" - "3.12" - runner: ["macos-12"] - include: - - python-version: "3.8" - runner: macos-m1-stable + runner: ["macos-m1-stable"] fail-fast: false uses: pytorch/test-infra/.github/workflows/macos_job.yml@main with: repository: pytorch/vision - # We need an increased timeout here, since the macos-12 runner is the free one from GH - # and needs roughly 2 hours to just run the test suite timeout: 240 runner: ${{ matrix.runner }} test-infra-ref: main From c585a515fba2c34d5fbef221fccc0902db588131 Mon Sep 17 00:00:00 2001 From: Mahdi Lamb Date: Tue, 28 May 2024 13:47:58 +0100 Subject: [PATCH 3/3] Enable one-hot-encoded labels in MixUp and CutMix (#8427) Co-authored-by: Nicolas Hug --- test/test_transforms_v2.py | 35 +++++++++++++++------------ torchvision/transforms/v2/_augment.py | 28 +++++++++++++++------ 2 files changed, 40 insertions(+), 23 deletions(-) diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py index 24574eb1a43..07235333af4 100644 --- a/test/test_transforms_v2.py +++ b/test/test_transforms_v2.py @@ -2169,26 +2169,30 @@ def test_image_correctness(self, brightness_factor): class TestCutMixMixUp: class DummyDataset: - def __init__(self, size, num_classes): + def __init__(self, size, num_classes, one_hot_labels): self.size = size self.num_classes = num_classes + self.one_hot_labels = one_hot_labels assert size < num_classes def __getitem__(self, idx): img = torch.rand(3, 100, 100) label = idx # This ensures all labels in a batch are unique and makes testing easier + if self.one_hot_labels: + label = torch.nn.functional.one_hot(torch.tensor(label), num_classes=self.num_classes) return img, label def __len__(self): return self.size @pytest.mark.parametrize("T", [transforms.CutMix, transforms.MixUp]) - def test_supported_input_structure(self, T): + @pytest.mark.parametrize("one_hot_labels", (True, False)) + def test_supported_input_structure(self, T, one_hot_labels): batch_size = 32 num_classes = 100 - dataset = self.DummyDataset(size=batch_size, num_classes=num_classes) + dataset = self.DummyDataset(size=batch_size, num_classes=num_classes, one_hot_labels=one_hot_labels) cutmix_mixup = T(num_classes=num_classes) @@ -2198,7 +2202,7 @@ def test_supported_input_structure(self, T): img, target = next(iter(dl)) input_img_size = img.shape[-3:] assert isinstance(img, torch.Tensor) and isinstance(target, torch.Tensor) - assert target.shape == (batch_size,) + assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,) def check_output(img, target): assert img.shape == (batch_size, *input_img_size) @@ -2209,7 +2213,7 @@ def check_output(img, target): # After Dataloader, as unpacked input img, target = next(iter(dl)) - assert target.shape == (batch_size,) + assert target.shape == (batch_size, num_classes) if one_hot_labels else (batch_size,) img, target = cutmix_mixup(img, target) check_output(img, target) @@ -2264,7 +2268,7 @@ def test_error(self, T): with pytest.raises(ValueError, match="Could not infer where the labels are"): cutmix_mixup({"img": imgs, "Nothing_else": 3}) - with pytest.raises(ValueError, match="labels tensor should be of shape"): + with pytest.raises(ValueError, match="labels should be index based"): # Note: the error message isn't ideal, but that's because the label heuristic found the img as the label # It's OK, it's an edge-case. The important thing is that this fails loudly instead of passing silently cutmix_mixup(imgs) @@ -2272,22 +2276,21 @@ def test_error(self, T): with pytest.raises(ValueError, match="When using the default labels_getter"): cutmix_mixup(imgs, "not_a_tensor") - with pytest.raises(ValueError, match="labels tensor should be of shape"): - cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3))) - with pytest.raises(ValueError, match="Expected a batched input with 4 dims"): cutmix_mixup(imgs[None, None], torch.randint(0, num_classes, size=(batch_size,))) with pytest.raises(ValueError, match="does not match the batch size of the labels"): cutmix_mixup(imgs, torch.randint(0, num_classes, size=(batch_size + 1,))) - with pytest.raises(ValueError, match="labels tensor should be of shape"): - # The purpose of this check is more about documenting the current - # behaviour of what happens on a Compose(), rather than actually - # asserting the expected behaviour. We may support Compose() in the - # future, e.g. for 2 consecutive CutMix? - labels = torch.randint(0, num_classes, size=(batch_size,)) - transforms.Compose([cutmix_mixup, cutmix_mixup])(imgs, labels) + with pytest.raises(ValueError, match="When passing 2D labels"): + wrong_num_classes = num_classes + 1 + T(alpha=0.5, num_classes=num_classes)(imgs, torch.randint(0, 2, size=(batch_size, wrong_num_classes))) + + with pytest.raises(ValueError, match="but got a tensor of shape"): + cutmix_mixup(imgs, torch.randint(0, 2, size=(2, 3, 4))) + + with pytest.raises(ValueError, match="num_classes must be passed"): + T(alpha=0.5)(imgs, torch.randint(0, num_classes, size=(batch_size,))) @pytest.mark.parametrize("key", ("labels", "LABELS", "LaBeL", "SOME_WEIRD_KEY_THAT_HAS_LABeL_IN_IT")) diff --git a/torchvision/transforms/v2/_augment.py b/torchvision/transforms/v2/_augment.py index cc645d6c8a8..f085ef3ca6e 100644 --- a/torchvision/transforms/v2/_augment.py +++ b/torchvision/transforms/v2/_augment.py @@ -1,7 +1,7 @@ import math import numbers import warnings -from typing import Any, Callable, Dict, List, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union import PIL.Image import torch @@ -142,7 +142,7 @@ def _transform(self, inpt: Any, params: Dict[str, Any]) -> Any: class _BaseMixUpCutMix(Transform): - def __init__(self, *, alpha: float = 1.0, num_classes: int, labels_getter="default") -> None: + def __init__(self, *, alpha: float = 1.0, num_classes: Optional[int] = None, labels_getter="default") -> None: super().__init__() self.alpha = float(alpha) self._dist = torch.distributions.Beta(torch.tensor([alpha]), torch.tensor([alpha])) @@ -162,10 +162,21 @@ def forward(self, *inputs): labels = self._labels_getter(inputs) if not isinstance(labels, torch.Tensor): raise ValueError(f"The labels must be a tensor, but got {type(labels)} instead.") - elif labels.ndim != 1: + if labels.ndim not in (1, 2): raise ValueError( - f"labels tensor should be of shape (batch_size,) " f"but got shape {labels.shape} instead." + f"labels should be index based with shape (batch_size,) " + f"or probability based with shape (batch_size, num_classes), " + f"but got a tensor of shape {labels.shape} instead." ) + if labels.ndim == 2 and self.num_classes is not None and labels.shape[-1] != self.num_classes: + raise ValueError( + f"When passing 2D labels, " + f"the number of elements in last dimension must match num_classes: " + f"{labels.shape[-1]} != {self.num_classes}. " + f"You can Leave num_classes to None." + ) + if labels.ndim == 1 and self.num_classes is None: + raise ValueError("num_classes must be passed if the labels are index-based (1D)") params = { "labels": labels, @@ -198,7 +209,8 @@ def _check_image_or_video(self, inpt: torch.Tensor, *, batch_size: int): ) def _mixup_label(self, label: torch.Tensor, *, lam: float) -> torch.Tensor: - label = one_hot(label, num_classes=self.num_classes) + if label.ndim == 1: + label = one_hot(label, num_classes=self.num_classes) # type: ignore[arg-type] if not label.dtype.is_floating_point: label = label.float() return label.roll(1, 0).mul_(1.0 - lam).add_(label.mul(lam)) @@ -223,7 +235,8 @@ class MixUp(_BaseMixUpCutMix): Args: alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1. - num_classes (int): number of classes in the batch. Used for one-hot-encoding. + num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding. + Can be None only if the labels are already one-hot-encoded. labels_getter (callable or "default", optional): indicates how to identify the labels in the input. By default, this will pick the second parameter as the labels if it's a tensor. This covers the most common scenario where this transform is called as ``MixUp()(imgs_batch, labels_batch)``. @@ -271,7 +284,8 @@ class CutMix(_BaseMixUpCutMix): Args: alpha (float, optional): hyperparameter of the Beta distribution used for mixup. Default is 1. - num_classes (int): number of classes in the batch. Used for one-hot-encoding. + num_classes (int, optional): number of classes in the batch. Used for one-hot-encoding. + Can be None only if the labels are already one-hot-encoded. labels_getter (callable or "default", optional): indicates how to identify the labels in the input. By default, this will pick the second parameter as the labels if it's a tensor. This covers the most common scenario where this transform is called as ``CutMix()(imgs_batch, labels_batch)``.