Add Tensor support for some transforms #1104
Changes from all commits: f9788f3, f8a5a17, b13c020, 68aa258, 7062c78, 6c0b75e, 934a257, ed3424f
@@ -31,6 +31,15 @@ def _is_tensor_image(img):
     return torch.is_tensor(img) and img.ndimension() == 3


+def _get_image_size(img):
+    if _is_pil_image(img):
+        return img.size
+    elif isinstance(img, torch.Tensor) and img.dim() > 2:
+        return img.shape[-2:][::-1]
+    else:
+        raise TypeError("Unexpected type {}".format(type(img)))
+
+
 def _is_numpy(img):
     return isinstance(img, np.ndarray)
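As a side note (not part of the diff), a small sketch of why the helper reverses the last two dimensions for tensors: PIL reports size as (width, height) while tensors are laid out as (C, H, W), so both branches end up returning (w, h). The sizes below are arbitrary.

```python
import torch
from PIL import Image

pil_img = Image.new("RGB", (320, 240))      # PIL size is (width, height)
tensor_img = torch.rand(3, 240, 320)        # tensor layout is (C, H, W)

print(pil_img.size)                         # (320, 240)
print(tuple(tensor_img.shape[-2:][::-1]))   # (320, 240) -- same (w, h) ordering
```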
@@ -234,26 +243,42 @@ def resize(img, size, interpolation=Image.BILINEAR):
     Returns:
         PIL Image: Resized image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):
         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
     if not (isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)):
         raise TypeError('Got inappropriate size arg: {}'.format(size))

     if isinstance(size, int):
-        w, h = img.size
+        w, h = _get_image_size(img)
         if (w <= h and w == size) or (h <= w and h == size):
             return img
         if w < h:
             ow = size
             oh = int(size * h / w)
-            return img.resize((ow, oh), interpolation)
         else:
             oh = size
             ow = int(size * w / h)
-            return img.resize((ow, oh), interpolation)
-    else:
-        return img.resize(size[::-1], interpolation)
+        size = (oh, ow)
+    if _is_pil_image(img):
+        return img.resize(size[::-1], interpolation)
+
+    # tensor codepath
+    # TODO maybe move this outside
+    _PIL_TO_TORCH_INTERP_MODE = {
+        Image.NEAREST: "nearest",
+        Image.BILINEAR: "bilinear"
+    }
+    should_unsqueeze = False

Review comment (on `should_unsqueeze = False`): this is actually …

+    if img.dim() == 3:
+        img = img[None]
+        should_unsqueeze = True
+    out = torch.nn.functional.interpolate(img, size=size,
+                                          mode=_PIL_TO_TORCH_INTERP_MODE[interpolation],
+                                          align_corners=False)
+    if should_unsqueeze:
+        out = out[0]
+    return out


 def scale(*args, **kwargs):
     warnings.warn("The use of the transforms.Scale transform is deprecated, " +
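For context, a minimal sketch (my own, not from the PR) of what the new tensor branch of resize boils down to for a single float CHW image with BILINEAR interpolation: add a temporary batch dimension, call torch.nn.functional.interpolate, then drop it again.

```python
import torch
import torch.nn.functional as F

img = torch.rand(3, 240, 320)                # float CHW image (assumed input format)

# Roughly what resize(img, (256, 256), Image.BILINEAR) does on the tensor path:
out = F.interpolate(img[None], size=(256, 256),
                    mode="bilinear", align_corners=False)[0]
print(out.shape)                             # torch.Size([3, 256, 256])
```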
@@ -362,16 +387,19 @@ def crop(img, i, j, h, w):
     Returns:
         PIL Image: Cropped image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):

Review comment (on the new type check): same comment as above

         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

-    return img.crop((j, i, j + w, i + h))
+    if _is_pil_image(img):
+        return img.crop((j, i, j + w, i + h))
+
+    return img[..., i:(i + h), j:(j + w)]


 def center_crop(img, output_size):
     if isinstance(output_size, numbers.Number):
         output_size = (int(output_size), int(output_size))
-    w, h = img.size
+    w, h = _get_image_size(img)
     th, tw = output_size
     i = int(round((h - th) / 2.))
     j = int(round((w - tw) / 2.))
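A quick sketch (not part of the diff) of the tensor branch of crop: it is plain indexing on the last two dimensions, and because of the ellipsis the same expression also works for batched N×C×H×W tensors.

```python
import torch

img = torch.arange(3 * 4 * 5).reshape(3, 4, 5)   # CHW tensor with distinct values
i, j, h, w = 1, 2, 2, 3                          # top-left corner (row, col) and size

patch = img[..., i:(i + h), j:(j + w)]
print(patch.shape)                               # torch.Size([3, 2, 3])
```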
@@ -410,10 +438,13 @@ def hflip(img):
     Returns:
         PIL Image: Horizontall flipped image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):

Review comment: same documentation comment as above

         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

-    return img.transpose(Image.FLIP_LEFT_RIGHT)
+    if _is_pil_image(img):

Review comment: it's much cleaner to write explicit …

+        return img.transpose(Image.FLIP_LEFT_RIGHT)
+
+    return img.flip(dims=(-1,))


 def _get_perspective_coeffs(startpoints, endpoints):
@@ -468,10 +499,13 @@ def vflip(img):
     Returns:
         PIL Image: Vertically flipped image.
     """
-    if not _is_pil_image(img):
+    if not (_is_pil_image(img) or isinstance(img, torch.Tensor)):

Review comment: doc comment as above

         raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

-    return img.transpose(Image.FLIP_TOP_BOTTOM)
+    if _is_pil_image(img):
+        return img.transpose(Image.FLIP_TOP_BOTTOM)
+
+    return img.flip(dims=(-2,))


 def five_crop(img, size):
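For reference, a small sketch (not from the PR) of what the two tensor branches do: hflip reverses the last (width) dimension and vflip the second-to-last (height) dimension, which also works for batched tensors.

```python
import torch

img = torch.arange(6).reshape(1, 2, 3)   # 1 channel, 2 rows, 3 columns

print(img.flip(dims=(-1,)))              # horizontal flip: columns reversed
print(img.flip(dims=(-2,)))              # vertical flip: rows reversed
```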
Review comment (on the diff above): the docstring above needs to be updated to say that it also takes a torch.Tensor (and has to specify what range the Tensor's values have to be)
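As an illustration only, here is one hypothetical shape the requested docstring update could take (the wording, layout note, and value-range note are my assumptions, not text from the PR):

```python
def vflip(img):
    """Vertically flip the given PIL Image or torch Tensor.

    Args:
        img (PIL Image or torch.Tensor): Image to be flipped. Tensors are
            assumed to have [..., H, W] layout; the accepted dtype / value
            range (e.g. float in [0, 1] or uint8 in [0, 255]) would be
            stated here, as the reviewer requests.

    Returns:
        PIL Image or torch.Tensor: Vertically flipped image.
    """
```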