Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Making Vflip and Hflip in Tensor format #1465

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import torchvision.transforms.functional_tensor as F_t
import unittest
import torch

class Tester(unittest.TestCase):

def test_vflip(self):
img_tensor = torch.randn(3,16,16)
vflipped_img = F_t.vflip(img_tensor)
vflipped_img_again = F_t.vflip(vflipped_img)

assert vflipped_img.shape == img_tensor.shape
Copy link
Member

@fmassa fmassa Oct 15, 2019

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we use self.assertEqual here? I'd prefer to move away from the raw asserts and use self.assert* methods from unittest for newer test

assert torch.equal(img_tensor, vflipped_img_again)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I'd prefer using self.assertTrue here.


def test_hflip(self):
img_tensor = torch.randn(3,16,16)
hflipped_img = F_t.hflip(img_tensor)
hflipped_img_again = F_t.hflip(hflipped_img)

assert hflipped_img.shape == img_tensor.shape
assert torch.equal(img_tensor, hflipped_img_again)

if __name__ == '__main__':
unittest.main()
32 changes: 32 additions & 0 deletions torchvision/transforms/functional_tensor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import torch
import torchvision.transforms.functional as F

def vflip(img_tensor):
"""Vertically flip the given the Image Tensor.

Args:
img_tensor (Tensor): Image Tensor to be flipped.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you specify in the documentation what is the expected format of the tensor? Is it HxWxC or CxHxW?


Returns:
Tensor: Vertically flipped image Tensor.
"""
if not F._is_tensor_image(img_tensor):
raise TypeError('tensor is not a torch image.')

return img_tensor.flip(1)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you make the offset relative to the end of the dimension?
This way, we could make it also work for batches of images / videos.
So it would be something like

return img_tensor.flip(-2)



def hflip(img_tensor):
"""Horizontally flip the given the Image Tensor.

Args:
img_tensor (Tensor): Image Tensor to be flipped.

Returns:
Tensor: Horizontally flipped image Tensor.
"""

if not F._is_tensor_image(img_tensor):
raise TypeError('tensor is not a torch image.')

return img_tensor.flip(2)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same comment, can you do .flip(-1) here?