Skip to content

Commit

Permalink
Refactor adjust ops tests (#2595)
Browse files Browse the repository at this point in the history
* [WIP] Unify ops Grayscale and RandomGrayscale

* Unified inputs for grayscale op and transforms
- deprecated F.to_grayscale in favor of F.rgb_to_grayscale

* Fixes bug with fp input

* Rewritten adjust_* tests
- split test_adjustments into 3 separate tests
- unified testing approach with test_adjust_gamma

* Added ColorJitter tests

* Relaxed tolerance for functional adjust-* tests

* Removed wrong merge and commented code
  • Loading branch information
vfdev-5 authored Sep 1, 2020
1 parent ab590a4 commit 5f616a2
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 98 deletions.
123 changes: 51 additions & 72 deletions test/test_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,64 +111,6 @@ def test_rgb2hsv(self):

self.assertLess(max_diff, 1e-5)

def test_adjustments(self):
script_adjust_brightness = torch.jit.script(F_t.adjust_brightness)
script_adjust_contrast = torch.jit.script(F_t.adjust_contrast)
script_adjust_saturation = torch.jit.script(F_t.adjust_saturation)

fns = ((F.adjust_brightness, F_t.adjust_brightness, script_adjust_brightness),
(F.adjust_contrast, F_t.adjust_contrast, script_adjust_contrast),
(F.adjust_saturation, F_t.adjust_saturation, script_adjust_saturation))

for _ in range(20):
channels = 3
dims = torch.randint(1, 50, (2,))
shape = (channels, dims[0], dims[1])

if torch.randint(0, 2, (1,)) == 0:
img = torch.rand(*shape, dtype=torch.float, device=self.device)
else:
img = torch.randint(0, 256, shape, dtype=torch.uint8, device=self.device)

factor = 3 * torch.rand(1).item()
img_clone = img.clone()
for f, ft, sft in fns:

ft_img = ft(img, factor).cpu()
sft_img = sft(img, factor).cpu()
if not img.dtype.is_floating_point:
ft_img = ft_img.to(torch.float) / 255
sft_img = sft_img.to(torch.float) / 255

img_pil = transforms.ToPILImage()(img)
f_img_pil = f(img_pil, factor)
f_img = transforms.ToTensor()(f_img_pil)

# F uses uint8 and F_t uses float, so there is a small
# difference in values caused by (at most 5) truncations.
max_diff = (ft_img - f_img).abs().max()
max_diff_scripted = (sft_img - f_img).abs().max()
self.assertLess(max_diff, 5 / 255 + 1e-5)
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
self.assertTrue(torch.equal(img, img_clone))

# test for class interface
f = transforms.ColorJitter(brightness=factor)
scripted_fn = torch.jit.script(f)
scripted_fn(img)

f = transforms.ColorJitter(contrast=factor)
scripted_fn = torch.jit.script(f)
scripted_fn(img)

f = transforms.ColorJitter(saturation=factor)
scripted_fn = torch.jit.script(f)
scripted_fn(img)

f = transforms.ColorJitter(brightness=1)
scripted_fn = torch.jit.script(f)
scripted_fn(img)

def test_rgb_to_grayscale(self):
script_rgb_to_grayscale = torch.jit.script(F.rgb_to_grayscale)

Expand Down Expand Up @@ -267,32 +209,69 @@ def test_pad(self):
with self.assertRaises(ValueError, msg="Padding can not be negative for symmetric padding_mode"):
F_t.pad(tensor, (-2, -3), padding_mode="symmetric")

def test_adjust_gamma(self):
script_fn = torch.jit.script(F.adjust_gamma)
tensor, pil_img = self._create_data(26, 36, device=self.device)
def _test_adjust_fn(self, fn, fn_pil, fn_t, configs):
script_fn = torch.jit.script(fn)

for dt in [torch.float64, torch.float32, None]:
torch.manual_seed(15)

tensor, pil_img = self._create_data(26, 34, device=self.device)

for dt in [None, torch.float32, torch.float64]:

if dt is not None:
tensor = F.convert_image_dtype(tensor, dt)

gammas = [0.8, 1.0, 1.2]
gains = [0.7, 1.0, 1.3]
for gamma, gain in zip(gammas, gains):
for config in configs:

adjusted_tensor = F.adjust_gamma(tensor, gamma, gain)
adjusted_pil = F.adjust_gamma(pil_img, gamma, gain)
scripted_result = script_fn(tensor, gamma, gain)
self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype)
self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1])
adjusted_tensor = fn_t(tensor, **config)
adjusted_pil = fn_pil(pil_img, **config)
scripted_result = script_fn(tensor, **config)
msg = "{}, {}".format(dt, config)
self.assertEqual(adjusted_tensor.dtype, scripted_result.dtype, msg=msg)
self.assertEqual(adjusted_tensor.size()[1:], adjusted_pil.size[::-1], msg=msg)

rbg_tensor = adjusted_tensor

if adjusted_tensor.dtype != torch.uint8:
rbg_tensor = F.convert_image_dtype(adjusted_tensor, torch.uint8)

self.compareTensorToPIL(rbg_tensor, adjusted_pil)
# Check that max difference does not exceed 2 in [0, 255] range
# Exact matching is not possible due to incompatibility convert_image_dtype and PIL results
tol = 2.0 + 1e-10
self.approxEqualTensorToPIL(rbg_tensor.float(), adjusted_pil, tol, msg=msg, agg_method="max")
self.assertTrue(adjusted_tensor.allclose(scripted_result), msg=msg)

def test_adjust_brightness(self):
self._test_adjust_fn(
F.adjust_brightness,
F_pil.adjust_brightness,
F_t.adjust_brightness,
[{"brightness_factor": f} for f in [0.1, 0.5, 1.0, 1.34, 2.5]]
)

def test_adjust_contrast(self):
self._test_adjust_fn(
F.adjust_contrast,
F_pil.adjust_contrast,
F_t.adjust_contrast,
[{"contrast_factor": f} for f in [0.2, 0.5, 1.0, 1.5, 2.0]]
)

self.assertTrue(adjusted_tensor.allclose(scripted_result))
def test_adjust_saturation(self):
self._test_adjust_fn(
F.adjust_saturation,
F_pil.adjust_saturation,
F_t.adjust_saturation,
[{"saturation_factor": f} for f in [0.5, 0.75, 1.0, 1.5, 2.0]]
)

def test_adjust_gamma(self):
self._test_adjust_fn(
F.adjust_gamma,
F_pil.adjust_gamma,
F_t.adjust_gamma,
[{"gamma": g1, "gain": g2} for g1, g2 in zip([0.8, 1.0, 1.2], [0.7, 1.0, 1.3])]
)

def test_resize(self):
script_fn = torch.jit.script(F_t.resize)
Expand Down
47 changes: 21 additions & 26 deletions test/test_transforms_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _test_class_op(self, method, meth_kwargs=None, test_exact_match=True, **matc
if meth_kwargs is None:
meth_kwargs = {}

tensor, pil_img = self._create_data(height=10, width=10, device=self.device)
tensor, pil_img = self._create_data(26, 34, device=self.device)
# test for class interface
f = getattr(T, method)(**meth_kwargs)
scripted_fn = torch.jit.script(f)
Expand Down Expand Up @@ -57,31 +57,26 @@ def test_random_horizontal_flip(self):
def test_random_vertical_flip(self):
self._test_op('vflip', 'RandomVerticalFlip')

def test_adjustments(self):
fns = ['adjust_brightness', 'adjust_contrast', 'adjust_saturation']
for _ in range(20):
factor = 3 * torch.rand(1).item()
tensor, _ = self._create_data(device=self.device)
pil_img = T.ToPILImage()(tensor)

for func in fns:
adjusted_tensor = getattr(F, func)(tensor, factor)
adjusted_pil_img = getattr(F, func)(pil_img, factor)

adjusted_pil_tensor = T.ToTensor()(adjusted_pil_img).to(self.device)
scripted_fn = torch.jit.script(getattr(F, func))
adjusted_tensor_script = scripted_fn(tensor, factor)

if not tensor.dtype.is_floating_point:
adjusted_tensor = adjusted_tensor.to(torch.float) / 255
adjusted_tensor_script = adjusted_tensor_script.to(torch.float) / 255

# F uses uint8 and F_t uses float, so there is a small
# difference in values caused by (at most 5) truncations.
max_diff = (adjusted_tensor - adjusted_pil_tensor).abs().max()
max_diff_scripted = (adjusted_tensor - adjusted_tensor_script).abs().max()
self.assertLess(max_diff, 5 / 255 + 1e-5)
self.assertLess(max_diff_scripted, 5 / 255 + 1e-5)
def test_color_jitter(self):

tol = 1.0 + 1e-10
for f in [0.1, 0.5, 1.0, 1.34]:
meth_kwargs = {"brightness": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)

for f in [0.2, 0.5, 1.0, 1.5]:
meth_kwargs = {"contrast": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)

for f in [0.5, 0.75, 1.0, 1.25]:
meth_kwargs = {"saturation": f}
self._test_class_op(
"ColorJitter", meth_kwargs=meth_kwargs, test_exact_match=False, tol=tol, agg_method="max"
)

def test_pad(self):

Expand Down

0 comments on commit 5f616a2

Please sign in to comment.