diff --git a/test/datasets_utils.py b/test/datasets_utils.py index bd9f7ea3a0f..43b4103646a 100644 --- a/test/datasets_utils.py +++ b/test/datasets_utils.py @@ -27,7 +27,11 @@ import torchvision.io from common_utils import disable_console_output, get_tmp_dir from torch.utils._pytree import tree_any +from torch.utils.data import DataLoader +from torchvision import tv_tensors +from torchvision.datasets import wrap_dataset_for_transforms_v2 from torchvision.transforms.functional import get_dimensions +from torchvision.transforms.v2.functional import get_size __all__ = [ @@ -568,9 +572,6 @@ def test_transforms(self, config): @test_all_configs def test_transforms_v2_wrapper(self, config): - from torchvision import tv_tensors - from torchvision.datasets import wrap_dataset_for_transforms_v2 - try: with self.create_dataset(config) as (dataset, info): for target_keys in [None, "all"]: @@ -709,26 +710,29 @@ def _no_collate(batch): return batch -def check_transforms_v2_wrapper_spawn(dataset): - # On Linux and Windows, the DataLoader forks the main process by default. This is not available on macOS, so new - # subprocesses are spawned. This requires the whole pipeline including the dataset to be pickleable, which is what - # we are enforcing here. - if platform.system() != "Darwin": - pytest.skip("Multiprocessing spawning is only checked on macOS.") +def check_transforms_v2_wrapper_spawn(dataset, expected_size): + # This check ensures that the wrapped datasets can be used with multiprocessing_context="spawn" in the DataLoader. + # We also check that transforms are applied correctly as a non-regression test for + # https://github.com/pytorch/vision/issues/8066 + # Implicitly, this also checks that the wrapped datasets are pickleable. - from torch.utils.data import DataLoader - from torchvision import tv_tensors - from torchvision.datasets import wrap_dataset_for_transforms_v2 + # To save CI/test time, we only check on Windows where "spawn" is the default + if platform.system() != "Windows": + pytest.skip("Multiprocessing spawning is only checked on macOS.") wrapped_dataset = wrap_dataset_for_transforms_v2(dataset) dataloader = DataLoader(wrapped_dataset, num_workers=2, multiprocessing_context="spawn", collate_fn=_no_collate) - for wrapped_sample in dataloader: - assert tree_any( - lambda item: isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)), wrapped_sample + def resize_was_applied(item): + # Checking the size of the output ensures that the Resize transform was correctly applied + return isinstance(item, (tv_tensors.Image, tv_tensors.Video, PIL.Image.Image)) and get_size(item) == list( + expected_size ) + for wrapped_sample in dataloader: + assert tree_any(resize_was_applied, wrapped_sample) + def create_image_or_video_tensor(size: Sequence[int]) -> torch.Tensor: r"""Create a random uint8 tensor. diff --git a/test/test_datasets.py b/test/test_datasets.py index 1270201d53e..832aefe5e09 100644 --- a/test/test_datasets.py +++ b/test/test_datasets.py @@ -24,6 +24,7 @@ import torch.nn.functional as F from common_utils import combinations_grid from torchvision import datasets +from torchvision.transforms import v2 class STL10TestCase(datasets_utils.ImageDatasetTestCase): @@ -184,8 +185,9 @@ def test_combined_targets(self): f"{actual} is not {expected}", def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset(target_type="category") as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(target_type="category", transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class Caltech256TestCase(datasets_utils.ImageDatasetTestCase): @@ -263,8 +265,9 @@ def inject_fake_data(self, tmpdir, config): return split_to_num_examples[config["split"]] def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class CityScapesTestCase(datasets_utils.ImageDatasetTestCase): @@ -391,9 +394,10 @@ def test_feature_types_target_polygon(self): (polygon_target, info["expected_polygon_target"]) def test_transforms_v2_wrapper_spawn(self): + expected_size = (123, 321) for target_type in ["instance", "semantic", ["instance", "semantic"]]: - with self.create_dataset(target_type=target_type) as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class ImageNetTestCase(datasets_utils.ImageDatasetTestCase): @@ -427,8 +431,9 @@ def inject_fake_data(self, tmpdir, config): return num_examples def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class CIFAR10TestCase(datasets_utils.ImageDatasetTestCase): @@ -625,9 +630,10 @@ def test_images_names_split(self): assert merged_imgs_names == all_imgs_names def test_transforms_v2_wrapper_spawn(self): + expected_size = (123, 321) for target_type in ["identity", "bbox", ["identity", "bbox"]]: - with self.create_dataset(target_type=target_type) as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + with self.create_dataset(target_type=target_type, transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class VOCSegmentationTestCase(datasets_utils.ImageDatasetTestCase): @@ -717,8 +723,9 @@ def add_bndbox(obj, bndbox=None): return data def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class VOCDetectionTestCase(VOCSegmentationTestCase): @@ -741,8 +748,9 @@ def test_annotations(self): assert object == info["annotation"] def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class CocoDetectionTestCase(datasets_utils.ImageDatasetTestCase): @@ -815,8 +823,9 @@ def _create_json(self, root, name, content): return file def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class CocoCaptionsTestCase(CocoDetectionTestCase): @@ -1005,9 +1014,11 @@ def inject_fake_data(self, tmpdir, config): ) return num_videos_per_class * len(classes) + @pytest.mark.xfail(reason="FIXME") def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset(output_format="TCHW") as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(output_format="TCHW", transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class HMDB51TestCase(datasets_utils.VideoDatasetTestCase): @@ -1237,8 +1248,9 @@ def _file_stem(self, idx): return f"2008_{idx:06d}" def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset(mode="segmentation") as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(mode="segmentation", transforms=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class FakeDataTestCase(datasets_utils.ImageDatasetTestCase): @@ -1690,8 +1702,9 @@ def inject_fake_data(self, tmpdir, config): return split_to_num_examples[config["train"]] def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class SvhnTestCase(datasets_utils.ImageDatasetTestCase): @@ -2568,8 +2581,9 @@ def _meta_to_split_and_classification_ann(self, meta, idx): return (image_id, class_id, species, breed_id) def test_transforms_v2_wrapper_spawn(self): - with self.create_dataset() as (dataset, _): - datasets_utils.check_transforms_v2_wrapper_spawn(dataset) + expected_size = (123, 321) + with self.create_dataset(transform=v2.Resize(size=expected_size)) as (dataset, _): + datasets_utils.check_transforms_v2_wrapper_spawn(dataset, expected_size=expected_size) class StanfordCarsTestCase(datasets_utils.ImageDatasetTestCase): diff --git a/torchvision/tv_tensors/_dataset_wrapper.py b/torchvision/tv_tensors/_dataset_wrapper.py index ef9260ebde9..04c3bf7133d 100644 --- a/torchvision/tv_tensors/_dataset_wrapper.py +++ b/torchvision/tv_tensors/_dataset_wrapper.py @@ -6,6 +6,7 @@ import contextlib from collections import defaultdict +from copy import copy import torch @@ -198,8 +199,19 @@ def __getitem__(self, idx): def __len__(self): return len(self._dataset) + # TODO: maybe we should use __getstate__ and __setstate__ instead of __reduce__, as recommended in the docs. def __reduce__(self): - return wrap_dataset_for_transforms_v2, (self._dataset, self._target_keys) + # __reduce__ gets called when we try to pickle the dataset. + # In a DataLoader with spawn context, this gets called `num_workers` times from the main process. + + # We have to reset the [target_]transform[s] attributes of the dataset + # to their original values, because we previously set them to None in __init__(). + dataset = copy(self._dataset) + dataset.transform = self.transform + dataset.transforms = self.transforms + dataset.target_transform = self.target_transform + + return wrap_dataset_for_transforms_v2, (dataset, self._target_keys) def raise_not_supported(description):