Video transforms #1353
Merged
Changes from all commits (5 commits)
@@ -0,0 +1,171 @@
from __future__ import division
import torch
import torchvision.transforms as transforms
import unittest
import random
import numpy as np

try:
    from scipy import stats
except ImportError:
    stats = None


class Tester(unittest.TestCase):

    def test_random_crop_video(self):
        numFrames = random.randint(4, 128)
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.RandomCropVideo((oheight, owidth)),
        ])(clip)
        assert result.size(2) == oheight
        assert result.size(3) == owidth

        transforms.RandomCropVideo((oheight, owidth)).__repr__()

    def test_random_resized_crop_video(self):
        numFrames = random.randint(4, 128)
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2
        clip = torch.randint(0, 256, (numFrames, height, width, 3), dtype=torch.uint8)
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.RandomResizedCropVideo((oheight, owidth)),
        ])(clip)
        assert result.size(2) == oheight
        assert result.size(3) == owidth

        transforms.RandomResizedCropVideo((oheight, owidth)).__repr__()

    def test_center_crop_video(self):
        numFrames = random.randint(4, 128)
        height = random.randint(10, 32) * 2
        width = random.randint(10, 32) * 2
        oheight = random.randint(5, (height - 2) // 2) * 2
        owidth = random.randint(5, (width - 2) // 2) * 2

        clip = torch.ones((numFrames, height, width, 3), dtype=torch.uint8) * 255
        oh1 = (height - oheight) // 2
        ow1 = (width - owidth) // 2
        clipNarrow = clip[:, oh1:oh1 + oheight, ow1:ow1 + owidth, :]
        clipNarrow.fill_(0)
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.CenterCropVideo((oheight, owidth)),
        ])(clip)

        msg = "height: " + str(height) + " width: " \
            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
        self.assertEqual(result.sum().item(), 0, msg)

        oheight += 1
        owidth += 1
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.CenterCropVideo((oheight, owidth)),
        ])(clip)
        sum1 = result.sum()

        msg = "height: " + str(height) + " width: " \
            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
        self.assertTrue(sum1.item() > 1, msg)

        oheight += 1
        owidth += 1
        result = transforms.Compose([
            transforms.ToTensorVideo(),
            transforms.CenterCropVideo((oheight, owidth)),
        ])(clip)
        sum2 = result.sum()

        msg = "height: " + str(height) + " width: " \
            + str(width) + " oheight: " + str(oheight) + " owidth: " + str(owidth)
        self.assertTrue(sum2.item() > 1, msg)
        self.assertTrue(sum2.item() > sum1.item(), msg)

    @unittest.skipIf(stats is None, 'scipy.stats is not available')
    def test_normalize_video(self):
        def samples_from_standard_normal(tensor):
            p_value = stats.kstest(list(tensor.view(-1)), 'norm', args=(0, 1)).pvalue
            return p_value > 0.0001

        random_state = random.getstate()
        random.seed(42)
        for channels in [1, 3]:
            numFrames = random.randint(4, 128)
            height = random.randint(32, 256)
            width = random.randint(32, 256)
            mean = random.random()
            std = random.random()
            clip = torch.normal(mean, std, size=(channels, numFrames, height, width))
            mean = [clip[c].mean().item() for c in range(channels)]
            std = [clip[c].std().item() for c in range(channels)]
            normalized = transforms.NormalizeVideo(mean, std)(clip)
            assert samples_from_standard_normal(normalized)
        random.setstate(random_state)

        # Checking the optional in-place behaviour
        tensor = torch.rand((3, 128, 16, 16))
        tensor_inplace = transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)(tensor)
        assert torch.equal(tensor, tensor_inplace)

        transforms.NormalizeVideo((0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True).__repr__()

    def test_to_tensor_video(self):
        numFrames, height, width = 64, 4, 4
        trans = transforms.ToTensorVideo()

        # Each call gets its own assertRaises block; statements after the
        # first raise inside a single block would never execute.
        with self.assertRaises(TypeError):
            trans(np.random.rand(numFrames, height, width, 1).tolist())
        with self.assertRaises(TypeError):
            trans(torch.rand((numFrames, height, width, 1), dtype=torch.float))

        with self.assertRaises(ValueError):
            trans(torch.ones((3, numFrames, height, width, 3), dtype=torch.uint8))
        with self.assertRaises(ValueError):
            trans(torch.ones((height, width, 3), dtype=torch.uint8))
        with self.assertRaises(ValueError):
            trans(torch.ones((width, 3), dtype=torch.uint8))
        with self.assertRaises(ValueError):
            trans(torch.ones((3,), dtype=torch.uint8))

        trans.__repr__()

    @unittest.skipIf(stats is None, 'scipy.stats is not available')
    def test_random_horizontal_flip_video(self):
        random_state = random.getstate()
        random.seed(42)
        clip = torch.rand((3, 4, 112, 112), dtype=torch.float)
        hclip = clip.flip((-1))

        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlipVideo()(clip)
            if torch.all(torch.eq(out, hclip)):
                num_horizontal += 1

        p_value = stats.binom_test(num_horizontal, num_samples, p=0.5)
        random.setstate(random_state)
        assert p_value > 0.0001

        num_samples = 250
        num_horizontal = 0
        for _ in range(num_samples):
            out = transforms.RandomHorizontalFlipVideo(p=0.7)(clip)
            if torch.all(torch.eq(out, hclip)):
                num_horizontal += 1

        p_value = stats.binom_test(num_horizontal, num_samples, p=0.7)
        random.setstate(random_state)
        assert p_value > 0.0001

        transforms.RandomHorizontalFlipVideo().__repr__()


if __name__ == '__main__':
    unittest.main()
@@ -1 +1,2 @@
from .transforms import *
from .transforms_video import *
@@ -0,0 +1,101 @@
import torch


def _is_tensor_video_clip(clip):
    if not torch.is_tensor(clip):
        raise TypeError("clip should be Tensor. Got %s" % type(clip))

    if not clip.ndimension() == 4:
        raise ValueError("clip should be 4D. Got %dD" % clip.dim())

    return True


def crop(clip, i, j, h, w):
    """
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
    """
    assert len(clip.size()) == 4, "clip should be a 4D tensor"
    return clip[..., i:i + h, j:j + w]


def resize(clip, target_size, interpolation_mode):
    assert len(target_size) == 2, "target size should be tuple (height, width)"
    return torch.nn.functional.interpolate(
        clip, size=target_size, mode=interpolation_mode
    )


def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
    """
    Do spatial cropping and resizing to the video clip
    Args:
        clip (torch.tensor): Video clip to be cropped. Size is (C, T, H, W)
        i (int): i in (i, j), i.e. row of the upper left corner.
        j (int): j in (i, j), i.e. column of the upper left corner.
        h (int): Height of the cropped region.
        w (int): Width of the cropped region.
        size (tuple(int, int)): height and width of resized clip
    Returns:
        clip (torch.tensor): Resized and cropped clip. Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    clip = crop(clip, i, j, h, w)
    clip = resize(clip, size, interpolation_mode)
    return clip


def center_crop(clip, crop_size):
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    h, w = clip.size(-2), clip.size(-1)
    th, tw = crop_size
    assert h >= th and w >= tw, "height and width must be no smaller than crop_size"

    i = int(round((h - th) / 2.0))
    j = int(round((w - tw) / 2.0))
    return crop(clip, i, j, th, tw)


def to_tensor(clip):
    """
    Convert tensor data type from uint8 to float, divide value by 255.0 and
    permute the dimensions of clip tensor
    Args:
        clip (torch.tensor, dtype=torch.uint8): Size is (T, H, W, C)
    Return:
        clip (torch.tensor, dtype=torch.float): Size is (C, T, H, W)
    """
    _is_tensor_video_clip(clip)
    if not clip.dtype == torch.uint8:
        raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
    return clip.float().permute(3, 0, 1, 2) / 255.0


def normalize(clip, mean, std, inplace=False):
    """
    Args:
        clip (torch.tensor): Video clip to be normalized. Size is (C, T, H, W)
        mean (tuple): pixel RGB mean. Size is (3)
        std (tuple): pixel standard deviation. Size is (3)
    Returns:
        normalized clip (torch.tensor): Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    if not inplace:
        clip = clip.clone()
    mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
    std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
    clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
    return clip


def hflip(clip):
    """
    Args:
        clip (torch.tensor): Video clip to be flipped. Size is (C, T, H, W)
    Returns:
        flipped clip (torch.tensor): Size is (C, T, H, W)
    """
    assert _is_tensor_video_clip(clip), "clip should be a 4D torch.tensor"
    return clip.flip((-1))
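For orientation, here is a minimal end-to-end sketch of how these helpers compose, assuming the functions above are in scope (e.g. imported from this new module); the shapes and parameter values are illustrative only:

import torch

# Decoded video comes in as (T, H, W, C) uint8; the helpers take it to a
# normalized, cropped float clip in (C, T, H, W) layout.
raw = torch.randint(0, 256, (8, 128, 171, 3), dtype=torch.uint8)  # (T, H, W, C)

clip = to_tensor(raw)                      # -> (C, T, H, W), float in [0, 1]
clip = resized_crop(clip, i=8, j=8, h=112, w=112, size=(112, 112))
clip = normalize(clip, mean=(0.45, 0.45, 0.45), std=(0.225, 0.225, 0.225))
clip = hflip(clip)
print(clip.shape)  # torch.Size([3, 8, 112, 112])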
I think I'll be using memory_format in the data reading functionality, so that this permutation is maybe handled automatically for us, in a safer way.
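For reference, a small sketch of that idea (not part of this PR; shapes are made up): a contiguous (T, H, W, C) clip permuted to (T, C, H, W) keeps the same underlying memory and is already channels_last-contiguous, so the permutation is a stride-only view rather than a copy.

import torch

# A decoded clip stored as (T, H, W, C) can be exposed as (T, C, H, W)
# without moving data: permute() only changes strides.
thwc = torch.randint(0, 256, (8, 112, 112, 3), dtype=torch.uint8)  # (T, H, W, C)
tchw = thwc.permute(0, 3, 1, 2)  # (T, C, H, W), same underlying memory

assert tchw.is_contiguous(memory_format=torch.channels_last)
assert tchw.data_ptr() == thwc.data_ptr()  # no copy happened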
And I'm also thinking about creating a new transform for performing image type conversions, like https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/image/convert_image_dtype , which would let us perform the scaling for different dtypes
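A rough sketch of what such a transform might look like (hypothetical: the name ConvertVideoDtype and the handled dtypes are assumptions, loosely modeled on tf.image.convert_image_dtype; only the uint8/float pair is covered here):

import torch

class ConvertVideoDtype(object):
    """Hypothetical sketch of the transform discussed above: convert a video
    tensor to a target dtype with value rescaling. Floats are assumed to lie
    in [0, 1]; uint8 uses the full [0, 255] range."""

    def __init__(self, dtype=torch.float32):
        self.dtype = dtype

    def __call__(self, clip):
        if clip.dtype == self.dtype:
            return clip
        if clip.dtype == torch.uint8 and self.dtype.is_floating_point:
            # integer -> float: rescale [0, 255] to [0.0, 1.0]
            return clip.to(self.dtype) / 255.0
        if clip.dtype.is_floating_point and self.dtype == torch.uint8:
            # float -> integer: rescale [0.0, 1.0] to [0, 255]
            return (clip * 255.0).to(self.dtype)
        return clip.to(self.dtype)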