diff --git a/test/test_transforms.py b/test/test_transforms.py index f19c5480b02..978ab823c95 100644 --- a/test/test_transforms.py +++ b/test/test_transforms.py @@ -512,6 +512,49 @@ def test_accimage_to_tensor(self): self.assertEqual(expected_output.size(), output.size()) self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) + def test_pil_to_tensor(self): + test_channels = [1, 3, 4] + height, width = 4, 4 + trans = transforms.PILToTensor() + + with self.assertRaises(TypeError): + trans(np.random.rand(1, height, width).tolist()) + trans(np.random.rand(1, height, width)) + + for channels in test_channels: + input_data = torch.ByteTensor(channels, height, width).random_(0, 255) + img = transforms.ToPILImage()(input_data) + output = trans(img) + self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) + + input_data = np.random.randint(low=0, high=255, size=(height, width, channels)).astype(np.uint8) + img = transforms.ToPILImage()(input_data) + output = trans(img) + expected_output = input_data.transpose((2, 0, 1)) + self.assertTrue(np.allclose(output.numpy(), expected_output)) + + input_data = torch.as_tensor(np.random.rand(channels, height, width).astype(np.float32)) + img = transforms.ToPILImage()(input_data) # CHW -> HWC and (* 255).byte() + output = trans(img) # HWC -> CHW + expected_output = (input_data * 255).byte() + self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) + + # separate test for mode '1' PIL images + input_data = torch.ByteTensor(1, height, width).bernoulli_() + img = transforms.ToPILImage()(input_data.mul(255)).convert('1') + output = trans(img) + self.assertTrue(np.allclose(input_data.numpy(), output.numpy())) + + @unittest.skipIf(accimage is None, 'accimage not available') + def test_accimage_pil_to_tensor(self): + trans = transforms.PILToTensor() + + expected_output = trans(Image.open(GRACE_HOPPER).convert('RGB')) + output = trans(accimage.Image(GRACE_HOPPER)) + + self.assertEqual(expected_output.size(), output.size()) + self.assertTrue(np.allclose(output.numpy(), expected_output.numpy())) + @unittest.skipIf(accimage is None, 'accimage not available') def test_accimage_resize(self): trans = transforms.Compose([ diff --git a/torchvision/transforms/functional.py b/torchvision/transforms/functional.py index 7ce1fb6ab36..bdfa6567a82 100644 --- a/torchvision/transforms/functional.py +++ b/torchvision/transforms/functional.py @@ -82,6 +82,33 @@ def to_tensor(pic): return img +def pil_to_tensor(pic): + """Convert a ``PIL Image`` to a tensor of the same type. + + See ``AsTensor`` for more details. + + Args: + pic (PIL Image): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + if not(_is_pil_image(pic)): + raise TypeError('pic should be PIL Image. Got {}'.format(type(pic))) + + if accimage is not None and isinstance(pic, accimage.Image): + nppic = np.zeros([pic.channels, pic.height, pic.width], dtype=np.float32) + pic.copyto(nppic) + return torch.as_tensor(nppic) + + # handle PIL Image + img = torch.as_tensor(np.asarray(pic)) + img = img.view(pic.size[1], pic.size[0], len(pic.getbands())) + # put it from HWC to CHW format + img = img.permute((2, 0, 1)) + return img + + def to_pil_image(pic, mode=None): """Convert a tensor or an ndarray to PIL Image. diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py index 10783c8e53d..eb49b99be93 100644 --- a/torchvision/transforms/transforms.py +++ b/torchvision/transforms/transforms.py @@ -15,7 +15,7 @@ from . import functional as F -__all__ = ["Compose", "ToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", +__all__ = ["Compose", "ToTensor", "PILToTensor", "ToPILImage", "Normalize", "Resize", "Scale", "CenterCrop", "Pad", "Lambda", "RandomApply", "RandomChoice", "RandomOrder", "RandomCrop", "RandomHorizontalFlip", "RandomVerticalFlip", "RandomResizedCrop", "RandomSizedCrop", "FiveCrop", "TenCrop", "LinearTransformation", "ColorJitter", "RandomRotation", "RandomAffine", "Grayscale", "RandomGrayscale", @@ -95,6 +95,26 @@ def __repr__(self): return self.__class__.__name__ + '()' +class PILToTensor(object): + """Convert a ``PIL Image`` to a tensor of the same type. + + Converts a PIL Image (H x W x C) to a torch.Tensor of shape (C x H x W). + """ + + def __call__(self, pic): + """ + Args: + pic (PIL Image): Image to be converted to tensor. + + Returns: + Tensor: Converted image. + """ + return F.pil_to_tensor(pic) + + def __repr__(self): + return self.__class__.__name__ + '()' + + class ToPILImage(object): """Convert a tensor or an ndarray to PIL Image.