Skip to content

Commit

Permalink
Added typing annotations to transforms/functional_pil (#4234)
Browse files Browse the repository at this point in the history
* fix

* add functional PIL typings

* fix types

* fix types

* fix a small one

* small fix

* fix type

* fix interpolation types

Co-authored-by: Vasilis Vryniotis <[email protected]>
  • Loading branch information
oke-aditya and datumbox authored Aug 18, 2021
1 parent 98cb4ea commit 759c5b6
Showing 1 changed file with 71 additions and 22 deletions.
93 changes: 71 additions & 22 deletions torchvision/transforms/functional_pil.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import numbers
from typing import Any, List, Sequence
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -34,23 +34,23 @@ def _get_image_num_channels(img: Any) -> int:


@torch.jit.unused
def hflip(img: Image.Image) -> Image.Image:
    """Horizontally flip the given PIL Image (mirror left/right).

    Args:
        img (PIL.Image.Image): Image to flip.

    Returns:
        PIL.Image.Image: Horizontally flipped image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_LEFT_RIGHT)


@torch.jit.unused
def vflip(img: Image.Image) -> Image.Image:
    """Vertically flip the given PIL Image (mirror top/bottom).

    Args:
        img (PIL.Image.Image): Image to flip.

    Returns:
        PIL.Image.Image: Vertically flipped image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    return img.transpose(Image.FLIP_TOP_BOTTOM)


@torch.jit.unused
def adjust_brightness(img, brightness_factor):
def adjust_brightness(img: Image.Image, brightness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -60,7 +60,7 @@ def adjust_brightness(img, brightness_factor):


@torch.jit.unused
def adjust_contrast(img, contrast_factor):
def adjust_contrast(img: Image.Image, contrast_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -70,7 +70,7 @@ def adjust_contrast(img, contrast_factor):


@torch.jit.unused
def adjust_saturation(img, saturation_factor):
def adjust_saturation(img: Image.Image, saturation_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -80,7 +80,7 @@ def adjust_saturation(img, saturation_factor):


@torch.jit.unused
def adjust_hue(img, hue_factor):
def adjust_hue(img: Image.Image, hue_factor: float) -> Image.Image:
if not(-0.5 <= hue_factor <= 0.5):
raise ValueError('hue_factor ({}) is not in [-0.5, 0.5].'.format(hue_factor))

Expand All @@ -104,7 +104,12 @@ def adjust_hue(img, hue_factor):


@torch.jit.unused
def adjust_gamma(img, gamma, gain=1):
def adjust_gamma(
img: Image.Image,
gamma: float,
gain: float = 1.0,
) -> Image.Image:

if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -121,7 +126,13 @@ def adjust_gamma(img, gamma, gain=1):


@torch.jit.unused
def pad(img, padding, fill=0, padding_mode="constant"):
def pad(
img: Image.Image,
padding: Union[int, List[int], Tuple[int, ...]],
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
padding_mode: str = "constant",
) -> Image.Image:

if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))

Expand Down Expand Up @@ -196,15 +207,28 @@ def pad(img, padding, fill=0, padding_mode="constant"):


@torch.jit.unused
def crop(
    img: Image.Image,
    top: int,
    left: int,
    height: int,
    width: int,
) -> Image.Image:
    """Crop the given PIL Image.

    Args:
        img (PIL.Image.Image): Image to crop.
        top (int): Vertical coordinate of the top-left corner of the crop box.
        left (int): Horizontal coordinate of the top-left corner of the crop box.
        height (int): Height of the crop box.
        width (int): Width of the crop box.

    Returns:
        PIL.Image.Image: Cropped image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

    # PIL's crop box is (left, upper, right, lower), exclusive on right/lower.
    return img.crop((left, top, left + width, top + height))


@torch.jit.unused
def resize(img, size, interpolation=Image.BILINEAR, max_size=None):
def resize(
img: Image.Image,
size: Union[Sequence[int], int],
interpolation: int = Image.BILINEAR,
max_size: Optional[int] = None,
) -> Image.Image:

if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
if not (isinstance(size, int) or (isinstance(size, Sequence) and len(size) in (1, 2))):
Expand Down Expand Up @@ -242,7 +266,12 @@ def resize(img, size, interpolation=Image.BILINEAR, max_size=None):


@torch.jit.unused
def _parse_fill(fill, img, name="fillcolor"):
def _parse_fill(
fill: Optional[Union[float, List[float], Tuple[float, ...]]],
img: Image.Image,
name: str = "fillcolor",
) -> Dict[str, Optional[Union[float, List[float], Tuple[float, ...]]]]:

# Process fill color for affine transforms
num_bands = len(img.getbands())
if fill is None:
Expand All @@ -261,7 +290,13 @@ def _parse_fill(fill, img, name="fillcolor"):


@torch.jit.unused
def affine(img, matrix, interpolation=0, fill=None):
def affine(
img: Image.Image,
matrix: List[float],
interpolation: int = Image.NEAREST,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:

if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -271,7 +306,15 @@ def affine(img, matrix, interpolation=0, fill=None):


@torch.jit.unused
def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):
def rotate(
img: Image.Image,
angle: float,
interpolation: int = Image.NEAREST,
expand: bool = False,
center: Optional[Tuple[int, int]] = None,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:

if not _is_pil_image(img):
raise TypeError("img should be PIL Image. Got {}".format(type(img)))

Expand All @@ -280,7 +323,13 @@ def rotate(img, angle, interpolation=0, expand=False, center=None, fill=None):


@torch.jit.unused
def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None):
def perspective(
img: Image.Image,
perspective_coeffs: float,
interpolation: int = Image.BICUBIC,
fill: Optional[Union[float, List[float], Tuple[float, ...]]] = 0,
) -> Image.Image:

if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -290,7 +339,7 @@ def perspective(img, perspective_coeffs, interpolation=Image.BICUBIC, fill=None)


@torch.jit.unused
def to_grayscale(img, num_output_channels):
def to_grayscale(img: Image.Image, num_output_channels: int) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -308,28 +357,28 @@ def to_grayscale(img, num_output_channels):


@torch.jit.unused
def invert(img: Image.Image) -> Image.Image:
    """Invert the colors of the given PIL Image (negate each channel).

    Args:
        img (PIL.Image.Image): Image to invert.

    Returns:
        PIL.Image.Image: Color-inverted image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.invert(img)


@torch.jit.unused
def posterize(img: Image.Image, bits: int) -> Image.Image:
    """Posterize the given PIL Image by reducing the bits per channel.

    Args:
        img (PIL.Image.Image): Image to posterize.
        bits (int): Number of bits to keep for each channel.

    Returns:
        PIL.Image.Image: Posterized image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.posterize(img, bits)


@torch.jit.unused
def solarize(img: Image.Image, threshold: int) -> Image.Image:
    """Solarize the given PIL Image (invert all pixel values above a threshold).

    Args:
        img (PIL.Image.Image): Image to solarize.
        threshold (int): Pixels at or above this value are inverted.

    Returns:
        PIL.Image.Image: Solarized image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.solarize(img, threshold)


@torch.jit.unused
def adjust_sharpness(img, sharpness_factor):
def adjust_sharpness(img: Image.Image, sharpness_factor: float) -> Image.Image:
if not _is_pil_image(img):
raise TypeError('img should be PIL Image. Got {}'.format(type(img)))

Expand All @@ -339,14 +388,14 @@ def adjust_sharpness(img, sharpness_factor):


@torch.jit.unused
def autocontrast(img: Image.Image) -> Image.Image:
    """Maximize the contrast of the given PIL Image.

    Remaps pixel values so the darkest become black and the lightest
    become white (per PIL's ``ImageOps.autocontrast``).

    Args:
        img (PIL.Image.Image): Image to autocontrast.

    Returns:
        PIL.Image.Image: Contrast-maximized image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.autocontrast(img)


@torch.jit.unused
def equalize(img: Image.Image) -> Image.Image:
    """Equalize the histogram of the given PIL Image.

    Args:
        img (PIL.Image.Image): Image to equalize.

    Returns:
        PIL.Image.Image: Histogram-equalized image.

    Raises:
        TypeError: If ``img`` is not a PIL Image.
    """
    if not _is_pil_image(img):
        raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
    return ImageOps.equalize(img)

0 comments on commit 759c5b6

Please sign in to comment.