diff --git a/test/test_transforms_v2.py b/test/test_transforms_v2.py
index 831a7e3b570..b40d04fffdd 100644
--- a/test/test_transforms_v2.py
+++ b/test/test_transforms_v2.py
@@ -4935,15 +4935,24 @@ def test_transform(self, transform, make_input):
         check_transform(transform, make_input())
 
     @pytest.mark.parametrize("num_output_channels", [1, 3])
+    @pytest.mark.parametrize("color_space", ["RGB", "GRAY"])
     @pytest.mark.parametrize("fn", [F.rgb_to_grayscale, transform_cls_to_functional(transforms.Grayscale)])
-    def test_image_correctness(self, num_output_channels, fn):
-        image = make_image(dtype=torch.uint8, device="cpu")
+    def test_image_correctness(self, num_output_channels, color_space, fn):
+        image = make_image(dtype=torch.uint8, device="cpu", color_space=color_space)
 
         actual = fn(image, num_output_channels=num_output_channels)
         expected = F.to_image(F.rgb_to_grayscale(F.to_pil_image(image), num_output_channels=num_output_channels))
 
         assert_equal(actual, expected, rtol=0, atol=1)
 
+    def test_expanded_channels_are_not_views_into_the_same_underlying_tensor(self):
+        image = make_image(dtype=torch.uint8, device="cpu", color_space="GRAY")
+
+        output_image = F.rgb_to_grayscale(image, num_output_channels=3)
+        assert_equal(output_image[0][0][0], output_image[1][0][0])
+        output_image[0][0][0] = output_image[0][0][0] + 1
+        assert output_image[0][0][0] != output_image[1][0][0]
+
     @pytest.mark.parametrize("num_input_channels", [1, 3])
     def test_random_transform_correctness(self, num_input_channels):
         image = make_image(
diff --git a/torchvision/transforms/v2/functional/_color.py b/torchvision/transforms/v2/functional/_color.py
index b0189fd95ef..2b9c1e738ca 100644
--- a/torchvision/transforms/v2/functional/_color.py
+++ b/torchvision/transforms/v2/functional/_color.py
@@ -33,9 +33,13 @@ def rgb_to_grayscale(inpt: torch.Tensor, num_output_channels: int = 1) -> torch.
 def _rgb_to_grayscale_image(
     image: torch.Tensor, num_output_channels: int = 1, preserve_dtype: bool = True
 ) -> torch.Tensor:
-    if image.shape[-3] == 1:
+    # TODO: Maybe move the validation that num_output_channels is 1 or 3 to this function instead of callers.
+    if image.shape[-3] == 1 and num_output_channels == 1:
         return image.clone()
-
+    if image.shape[-3] == 1 and num_output_channels == 3:
+        s = [1] * len(image.shape)
+        s[-3] = 3
+        return image.repeat(s)
     r, g, b = image.unbind(dim=-3)
     l_img = r.mul(0.2989).add_(g, alpha=0.587).add_(b, alpha=0.114)
     l_img = l_img.unsqueeze(dim=-3)