diff --git a/test/test_transforms_video.py b/test/test_transforms_video.py
new file mode 100644
index 00000000000..b0a237e9318
--- /dev/null
+++ b/test/test_transforms_video.py
@@ -0,0 +1,171 @@
+from __future__ import division
+import torch
+import torchvision.transforms as transforms
+import unittest
+import random
+import numpy as np
+
+try:
+    from scipy import stats
+except ImportError:
+    stats = None
+
+
+class Tester(unittest.TestCase):
+
+    def test_random_crop_video(self):
+        numFrames = random.randint(4, 128)
+        height = random.randint(10, 32) * 2
+        width = random.randint(10, 32) * 2
+        oheight = random.randint(5, (height - 2) // 2) * 2
+        owidth = random.randint(5, (width - 2) // 2) * 2
+        clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
+        result = transforms.Compose([
+            transforms.ToTensorVideo(),
+            transforms.RandomCropVideo((oheight, owidth)),
+        ])(clip)
+        assert result.size(2) == oheight
+        assert result.size(3) == owidth
+
+        transforms.RandomCropVideo((oheight, owidth)).__repr__()
+
+    def test_random_resized_crop_video(self):
+        numFrames = random.randint(4, 128)
+        height = random.randint(10, 32) * 2
+        width = random.randint(10, 32) * 2
+        oheight = random.randint(5, (height - 2) // 2) * 2
+        owidth = random.randint(5, (width - 2) // 2) * 2
+        clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
+        result = transforms.Compose([
+            transforms.ToTensorVideo(),
+            transforms.RandomResizedCropVideo((oheight, owidth)),
+        ])(clip)
+        assert result.size(2) == oheight
+        assert result.size(3) == owidth
+
+        transforms.RandomResizedCropVideo((oheight, owidth)).__repr__()
+
+    def test_center_crop_video(self):
+        numFrames = random.randint(4, 128)
+        height = random.randint(10, 32) * 2
+        width = random.randint(10, 32) * 2
+        oheight = random.randint(5, (height - 2) // 2) * 2
+        owidth = random.randint(5, (width - 2) // 2) * 2
+
+        clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255
+        oh1 = (height - oheight) // 2
+        ow1 = (width - owidth) // 2
+        clipNarrow = clip[:, oh1:oh1 + oheight, ow1:ow1 + owidth, :]
+        clipNarrow.fill_(0)
+        result = transforms.Compose([
+            transforms.ToTensorVideo(),
+            transforms.CenterCropVideo((oheight, owidth)),
+        ])(clip)
+
+        msg = "height: " + str(height) + " width: " \
+            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+        self.assertEqual(result.sum().item(), 0, msg)
+
+        oheight += 1
+        owidth += 1
+        result = transforms.Compose([
+            transforms.ToTensorVideo(),
+            transforms.CenterCropVideo((oheight, owidth)),
+        ])(clip)
+        sum1 = result.sum()
+
+        msg = "height: " + str(height) + " width: " \
+            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+        self.assertTrue(sum1.item() > 1, msg)
+
+        oheight += 1
+        owidth += 1
+        result = transforms.Compose([
+            transforms.ToTensorVideo(),
+            transforms.CenterCropVideo((oheight, owidth)),
+        ])(clip)
+        sum2 = result.sum()
+
+        msg = "height: " + str(height) + " width: " \
+            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
+        self.assertTrue(sum2.item() > 1, msg)
+        self.assertTrue(sum2.item() > sum1.item(), msg)
+
+    @unittest.skipIf(stats is None, 'scipy.stats is not available')
+    def test_normalize_video(self):
+        def samples_from_standard_normal(tensor):
+            p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
+            return p_value > 0.0001
+
+        random_state = random.getstate()
+        random.seed(42)
+        for channels in [1, 3]:
+            numFrames = random.randint(4, 128)
+            height = random.randint(32, 256)
+            width = random.randint(32, 256)
+            mean = random.random()
+            std = random.random()
+            clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
+            mean = [clip[c].mean().item() for c in range(channels)]
+            std = [clip[c].std().item() for c in range(channels)]
+            normalized = transforms.NormalizeVideo(mean, std)(clip)
+            assert samples_from_standard_normal(normalized)
+        random.setstate(random_state)
+
+        # Checking the optional in-place behaviour
+        tensor = torch.rand((3, 128, 16, 16))
+        tensor_inplace = transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)(tensor)
+        assert torch.equal(tensor, tensor_inplace)
+
+        transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True).__repr__()
+
+    def test_to_tensor_video(self):
+        numFrames, height, width = 64, 4, 4
+        trans = transforms.ToTensorVideo()
+
+        with self.assertRaises(TypeError):
+            trans(np.random.rand(numFrames, height, width, 1).tolist())
+            trans(torch.rand((numFrames, height, width, 1), dtype=torch.float))
+
+        with self.assertRaises(ValueError):
+            trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8))
+            trans(torch.ones((height, width, 3), dtype=torch.uint8))
+            trans(torch.ones((width, 3), dtype=torch.uint8))
+            trans(torch.ones((3), dtype=torch.uint8))
+
+        trans.__repr__()
+
+    @unittest.skipIf(stats is None, 'scipy.stats not available')
+    def test_random_horizontal_flip_video(self):
+        random_state = random.getstate()
+        random.seed(42)
+        clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
+        hclip = clip.flip((-1))
+
+        num_samples = 250
+        num_horizontal = 0
+        for _ in range(num_samples):
+            out = transforms.RandomHorizontalFlipVideo()(clip)
+            if torch.all(torch.eq(out, hclip)):
+                num_horizontal += 1
+
+        p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
+        random.setstate(random_state)
+        assert p_value > 0.0001
+
+        num_samples = 250
+        num_horizontal = 0
+        for _ in range(num_samples):
+            out = transforms.RandomHorizontalFlipVideo(p=0.7)(clip)
+            if torch.all(torch.eq(out, hclip)):
+                num_horizontal += 1
+
+        p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
+        random.setstate(random_state)
+        assert p_value > 0.0001
+
+        transforms.RandomHorizontalFlipVideo().__repr__()
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/torchvision/transforms/__init__.py b/torchvision/transforms/__init__.py
index 7986cdd6429..175a8a8dc1b 100644
--- a/torchvision/transforms/__init__.py
+++ b/torchvision/transforms/__init__.py
@@ -1 +1,2 @@
 from .transforms import *
+from .transforms_video import *
diff --git a/torchvision/transforms/functional_video.py b/torchvision/transforms/functional_video.py
new file mode 100644
index 00000000000..06c30716908
--- /dev/null
+++ b/torchvision/transforms/functional_video.py
@@ -0,0 +1,101 @@
+import torch
+
+
+def _is_tensor_video_clip(clip):
+    if not torch.is_tensor(clip):
+        raise TypeError("clip should be Tensor. Got %s" % type(clip))
+
+    if not clip.ndimension() == 4:
+        raise ValueError("clip should be 4D. Got %dD" % clip.dim())
+
+    return True
+
+
+def crop(clip, i, j, h, w):
+    """
+    Args:
+        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+    """
+    assert len(clip.size()) == 4, "clip should be a 4D tensor"
+    return clip[..., i:i + h, j:j + w]
+
+
+def resize(clip, target_size, interpolation_mode):
+    assert len(target_size) == 2, "target size should be tuple (height, width)"
+    return torch.nn.functional.interpolate(
+        clip, size=target_size, mode=interpolation_mode
+    )
+
+
+def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
+    """
+    Do spatial cropping and resizing to the video clip
+    Args:
+        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+        i (int): Vertical coordinate of the upper-left corner of the crop.
+        j (int): Horizontal coordinate of the upper-left corner of the crop.
+        h (int): Height of the cropped region.
+        w (int): Width of the cropped region.
+        size (tuple(int, int)): height and width of resized clip
+    Returns:
+        clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
+    """
+    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
+    clip = crop(clip, i, j, h, w)
+    clip = resize(clip, size, interpolation_mode)
+    return clip
+
+
+def center_crop(clip, crop_size):
+    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
+    h, w = clip.size(-2), clip.size(-1)
+    th, tw = crop_size
+    assert h >= th and w >= tw, "height and width must be no smaller than crop_size"
+
+    i = int(round((h - th) / 2.0))
+    j = int(round((w - tw) / 2.0))
+    return crop(clip, i, j, th, tw)
+
+
+def to_tensor(clip):
+    """
+    Convert tensor data type from uint8 to float, divide value by 255.0 and
+    permute the dimensions of clip tensor
+    Args:
+        clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
+    Return:
+        clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
+    """
+    _is_tensor_video_clip(clip)
+    if not clip.dtype == torch.uint8:
+        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
+    return clip.float().permute(3, 0, 1, 2) / 255.0
+
+
+def normalize(clip, mean, std, inplace=False):
+    """
+    Args:
+        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
+        mean (tuple): pixel RGB mean. Size is (3)
+        std (tuple): pixel standard deviation. Size is (3)
+    Returns:
+        normalized clip (torch.tensor): Size is (C, T, H, W)
+    """
+    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
+    if not inplace:
+        clip = clip.clone()
+    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
+    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
+    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
+    return clip
+
+
+def hflip(clip):
+    """
+    Args:
+        clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
+    Returns:
+        flipped clip (torch.tensor): Size is (C, T, H, W)
+    """
+    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
+    return clip.flip((-1))
diff --git a/torchvision/transforms/transforms.py b/torchvision/transforms/transforms.py
index b21a6d86eef..3ec84aae84c 100644
--- a/torchvision/transforms/transforms.py
+++ b/torchvision/transforms/transforms.py
@@ -40,6 +40,15 @@
 }
 
 
+def _get_image_size(img):
+    if F._is_pil_image(img):
+        return img.size
+    elif isinstance(img, torch.Tensor) and img.dim() > 2:
+        return img.shape[-2:][::-1]
+    else:
+        raise TypeError("Unexpected type {}".format(type(img)))
+
+
 class Compose(object):
     """Composes several transforms together.
 
@@ -444,7 +453,7 @@ def get_params(img, output_size):
         Returns:
             tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
         """
-        w, h = img.size
+        w, h = _get_image_size(img)
         th, tw = output_size
         if w == tw and h == th:
             return 0, 0, h, w
@@ -635,7 +644,8 @@ def get_params(img, scale, ratio):
             tuple: params (i, j, h, w) to be passed to ``crop`` for a random
                 sized crop.
         """
-        area = img.size[0] * img.size[1]
+        width, height = _get_image_size(img)
+        area = height * width
 
         for attempt in range(10):
             target_area = random.uniform(*scale) * area
@@ -645,24 +655,24 @@ def get_params(img, scale, ratio):
             w = int(round(math.sqrt(target_area * aspect_ratio)))
             h = int(round(math.sqrt(target_area / aspect_ratio)))
 
-            if 0 < w <= img.size[0] and 0 < h <= img.size[1]:
-                i = random.randint(0, img.size[1] - h)
-                j = random.randint(0, img.size[0] - w)
+            if 0 < w <= width and 0 < h <= height:
+                i = random.randint(0, height - h)
+                j = random.randint(0, width - w)
                 return i, j, h, w
 
         # Fallback to central crop
-        in_ratio = img.size[0] / img.size[1]
+        in_ratio = float(width) / float(height)
         if (in_ratio < min(ratio)):
-            w = img.size[0]
+            w = width
             h = int(round(w / min(ratio)))
         elif (in_ratio > max(ratio)):
-            h = img.size[1]
+            h = height
             w = int(round(h * max(ratio)))
         else:  # whole image
-            w = img.size[0]
-            h = img.size[1]
-        i = (img.size[1] - h) // 2
-        j = (img.size[0] - w) // 2
+            w = width
+            h = height
+        i = (height - h) // 2
+        j = (width - w) // 2
         return i, j, h, w
 
     def __call__(self, img):
diff --git a/torchvision/transforms/transforms_video.py b/torchvision/transforms/transforms_video.py
new file mode 100644
index 00000000000..e11d8489eb8
--- /dev/null
+++ b/torchvision/transforms/transforms_video.py
@@ -0,0 +1,173 @@
+#!/usr/bin/env python3
+
+import numbers
+import random
+
+from torchvision.transforms import (
+    RandomCrop,
+    RandomResizedCrop,
+)
+
+from . import functional_video as F
+
+
+__all__ = [
+    "RandomCropVideo",
+    "RandomResizedCropVideo",
+    "CenterCropVideo",
+    "NormalizeVideo",
+    "ToTensorVideo",
+    "RandomHorizontalFlipVideo",
+]
+
+
+class RandomCropVideo(RandomCrop):
+    def __init__(self, size):
+        if isinstance(size, numbers.Number):
+            self.size = (int(size), int(size))
+        else:
+            self.size = size
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+        Returns:
+            torch.tensor: randomly cropped video clip.
+                size is (C, T, OH, OW)
+        """
+        i, j, h, w = self.get_params(clip, self.size)
+        return F.crop(clip, i, j, h, w)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(size={0})'.format(self.size)
+
+
+class RandomResizedCropVideo(RandomResizedCrop):
+    def __init__(
+        self,
+        size,
+        scale=(0.08, 1.0),
+        ratio=(3.0 / 4.0, 4.0 / 3.0),
+        interpolation_mode="bilinear",
+    ):
+        if isinstance(size, tuple):
+            assert len(size) == 2, "size should be tuple (height, width)"
+            self.size = size
+        else:
+            self.size = (size, size)
+
+        self.interpolation_mode = interpolation_mode
+        self.scale = scale
+        self.ratio = ratio
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+        Returns:
+            torch.tensor: randomly cropped/resized video clip.
+                size is (C, T, OH, OW)
+        """
+        i, j, h, w = self.get_params(clip, self.scale, self.ratio)
+        return F.resized_crop(clip, i, j, h, w, self.size, self.interpolation_mode)
+
+    def __repr__(self):
+        return self.__class__.__name__ + \
+            '(size={0}, interpolation_mode={1}, scale={2}, ratio={3})'.format(
+                self.size, self.interpolation_mode, self.scale, self.ratio
+            )
+
+
+class CenterCropVideo(object):
+    def __init__(self, crop_size):
+        if isinstance(crop_size, numbers.Number):
+            self.crop_size = (int(crop_size), int(crop_size))
+        else:
+            self.crop_size = crop_size
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
+        Returns:
+            torch.tensor: central cropping of video clip. Size is
+            (C, T, crop_size, crop_size)
+        """
+        return F.center_crop(clip, self.crop_size)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(crop_size={0})'.format(self.crop_size)
+
+
+class NormalizeVideo(object):
+    """
+    Normalize the video clip by mean subtraction and division by standard deviation
+    Args:
+        mean (3-tuple): pixel RGB mean
+        std (3-tuple): pixel RGB standard deviation
+        inplace (boolean): whether to do in-place normalization
+    """
+
+    def __init__(self, mean, std, inplace=False):
+        self.mean = mean
+        self.std = std
+        self.inplace = inplace
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (torch.tensor): video clip to be normalized. Size is (C, T, H, W)
+        """
+        return F.normalize(clip, self.mean, self.std, self.inplace)
+
+    def __repr__(self):
+        return self.__class__.__name__ + '(mean={0}, std={1}, inplace={2})'.format(
+            self.mean, self.std, self.inplace)
+
+
+class ToTensorVideo(object):
+    """
+    Convert tensor data type from uint8 to float, divide value by 255.0 and
+    permute the dimensions of clip tensor
+    """
+
+    def __init__(self):
+        pass
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
+        Return:
+            clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
+        """
+        return F.to_tensor(clip)
+
+    def __repr__(self):
+        return self.__class__.__name__
+
+
+class RandomHorizontalFlipVideo(object):
+    """
+    Flip the video clip along the horizontal direction with a given probability
+    Args:
+        p (float): probability of the clip being flipped. Default value is 0.5
+    """
+
+    def __init__(self, p=0.5):
+        self.p = p
+
+    def __call__(self, clip):
+        """
+        Args:
+            clip (torch.tensor): Size is (C, T, H, W)
+        Return:
+            clip (torch.tensor): Size is (C, T, H, W)
+        """
+        if random.random() < self.p:
+            clip = F.hflip(clip)
+        return clip
+
+    def __repr__(self):
+        return self.__class__.__name__ + "(p={0})".format(self.p)
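
For reference, the sketch below shows how the transforms added in this patch are intended to compose, following the docstrings above: ToTensorVideo expects a uint8 clip in (T, H, W, C) layout, and the remaining transforms operate on the resulting (C, T, H, W) float tensor. It is illustrative only and not part of the patch; the clip shape and the mean/std values are placeholders.

import torch
import torchvision.transforms as transforms

# Illustrative 16-frame RGB clip in (T, H, W, C) uint8 layout.
clip = torch.randint(0, 256, (16, 128, 171, 3), dtype=torch.uint8)

pipeline = transforms.Compose([
    transforms.ToTensorVideo(),                   # (T, H, W, C) uint8 -> (C, T, H, W) float in [0, 1]
    transforms.RandomResizedCropVideo(112),       # random spatial crop, resized to 112x112
    transforms.RandomHorizontalFlipVideo(p=0.5),  # flip along the width dimension with probability 0.5
    # Placeholder statistics; use dataset-specific values in practice.
    transforms.NormalizeVideo(mean=(0.45, 0.45, 0.45), std=(0.225, 0.225, 0.225)),
])

out = pipeline(clip)
print(out.shape)  # torch.Size([3, 16, 112, 112])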