Skip to content

Commit

Permalink
More tests
Browse files Browse the repository at this point in the history
Signed-off-by: Kamil Tokarski <[email protected]>
  • Loading branch information
stiepan committed Mar 8, 2023
1 parent 24b0ac6 commit 134aabf
Show file tree
Hide file tree
Showing 2 changed files with 127 additions and 11 deletions.
17 changes: 9 additions & 8 deletions dali/python/nvidia/dali/auto_aug/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,52 +45,53 @@ def warp_y_param(magnitude):


@augmentation(mag_range=(0, 0.3), randomly_negate=True, as_param=warp_x_param)
def shear_x(sample, shear, fill_value=0, interp_type=None):
def shear_x(sample, shear, fill_value=128, interp_type=None):
mt = fn.transforms.shear(shear=shear)
return fn.warp_affine(sample, matrix=mt, fill_value=fill_value, interp_type=interp_type,
inverse_map=False)


@augmentation(mag_range=(0, 0.3), randomly_negate=True, as_param=warp_y_param)
def shear_y(sample, shear, fill_value=0, interp_type=None):
def shear_y(sample, shear, fill_value=128, interp_type=None):
mt = fn.transforms.shear(shear=shear)
return fn.warp_affine(sample, matrix=mt, fill_value=fill_value, interp_type=interp_type,
inverse_map=False)


@augmentation(mag_range=(0., 1.), randomly_negate=True, as_param=warp_x_param)
def translate_x(sample, rel_offset, shape, fill_value=0, interp_type=None):
def translate_x(sample, rel_offset, shape, fill_value=128, interp_type=None):
offset = rel_offset * shape[-2]
mt = fn.transforms.translation(offset=offset)
return fn.warp_affine(sample, matrix=mt, fill_value=fill_value, interp_type=interp_type,
inverse_map=False)


@augmentation(mag_range=(0, 250), randomly_negate=True, as_param=warp_x_param, name="translate_x")
def translate_x_no_shape(sample, offset, fill_value=0, interp_type=None):
def translate_x_no_shape(sample, offset, fill_value=128, interp_type=None):
mt = fn.transforms.translation(offset=offset)
return fn.warp_affine(sample, matrix=mt, fill_value=fill_value, interp_type=interp_type,
inverse_map=False)


@augmentation(mag_range=(0., 1.), randomly_negate=True, as_param=warp_y_param)
def translate_y(sample, rel_offset, shape, fill_value=0, interp_type=None):
def translate_y(sample, rel_offset, shape, fill_value=128, interp_type=None):
offset = rel_offset * shape[-3]
mt = fn.transforms.translation(offset=offset)
return fn.warp_affine(sample, matrix=mt, fill_value=fill_value, interp_type=interp_type,
inverse_map=False)


@augmentation(mag_range=(0, 250), randomly_negate=True, as_param=warp_y_param, name="translate_y")
def translate_y_no_shape(sample, offset, fill_value=0, interp_type=None):
def translate_y_no_shape(sample, offset, fill_value=128, interp_type=None):
mt = fn.transforms.translation(offset=offset)
return fn.warp_affine(sample, matrix=mt, fill_value=fill_value, interp_type=interp_type,
inverse_map=False)


@augmentation(mag_range=(0, 30), randomly_negate=True)
def rotate(sample, angle, fill_value=0, interp_type=None):
return fn.rotate(sample, angle=angle, fill_value=fill_value, interp_type=interp_type)
def rotate(sample, angle, fill_value=128, interp_type=None, rotate_keep_size=True):
return fn.rotate(sample, angle=angle, fill_value=fill_value, interp_type=interp_type,
keep_size=rotate_keep_size)


def shift_enhance_range(magnitude):
Expand Down
121 changes: 118 additions & 3 deletions dali/test/python/auto_aug/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import os

import numpy as np
import cv2
from PIL import Image, ImageEnhance, ImageOps
from nose2.tools import params

Expand Down Expand Up @@ -50,7 +51,7 @@ def inner(sample, param=None):
return inner


def compare_against_baseline(dali_aug, baseline_op, get_data, batch_size, dev="gpu",
def compare_against_baseline(dali_aug, baseline_op, get_data, batch_size, dev="gpu", eps=1e-7,
max_allowed_error=1e-6, params=None, post_proc=None):

@pipeline_def(batch_size=batch_size, num_threads=4, device_id=0, seed=42)
Expand Down Expand Up @@ -81,7 +82,7 @@ def pipeline():
if post_proc is not None:
output = [post_proc(sample) for sample in output]
ref_output = [post_proc(sample) for sample in ref_output]
check_batch(output, ref_output, max_allowed_error=max_allowed_error)
check_batch(output, ref_output, eps=eps, max_allowed_error=max_allowed_error)


def get_images(dev):
Expand All @@ -93,6 +94,106 @@ def inner():
return inner


@params(("cpu", ), ("gpu", ))
def test_shear_x(dev):

# adapted implementation from DeepLearningExamples:
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
# Classification/ConvNets/image_classification/autoaugment.py
def shear_x_ref(img, magnitude):
return img.transform(img.size, Image.AFFINE, (1, -magnitude, 0, 0, 1, 0), Image.BICUBIC,
fillcolor=(128, ) * 3)

batch_size = 16
data_source = get_images(dev)
shear_x = a.shear_x.augmentation(mag_range=(-0.3, 0.3), randomly_negate=False)
magnitudes = shear_x._get_magnitudes(batch_size)
compare_against_baseline(shear_x, pil_baseline(shear_x_ref), data_source, batch_size=batch_size,
dev=dev, params=magnitudes, max_allowed_error=None, eps=3)


@params(("cpu", ), ("gpu", ))
def test_shear_y(dev):

# adapted implementation from DeepLearningExamples:
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
# Classification/ConvNets/image_classification/autoaugment.py
def shear_y_ref(img, magnitude):
return img.transform(img.size, Image.AFFINE, (1, 0, 0, -magnitude, 1, 0), Image.BICUBIC,
fillcolor=(128, ) * 3)

batch_size = 16
data_source = get_images(dev)
shear_y = a.shear_y.augmentation(mag_range=(-0.3, 0.3), randomly_negate=False)
magnitudes = shear_y._get_magnitudes(batch_size)
compare_against_baseline(shear_y, pil_baseline(shear_y_ref), data_source, batch_size=batch_size,
dev=dev, params=magnitudes, max_allowed_error=None, eps=3)


@params(("cpu", ), ("gpu", ))
def test_rotate(dev):

# adapted implementation from DeepLearningExamples:
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
# Classification/ConvNets/image_classification/autoaugment.py
def rotate_with_fill(img, magnitude):
rot = img.convert("RGBA").rotate(magnitude)
return Image.composite(rot, Image.new("RGBA", img.size, (128, ) * 4), rot).convert(img.mode)

batch_size = 16
data_source = get_images(dev)
rotate = a.rotate.augmentation(mag_range=(-30, 30), randomly_negate=False)
magnitudes = rotate._get_magnitudes(batch_size)
compare_against_baseline(rotate, pil_baseline(rotate_with_fill), data_source,
batch_size=batch_size, dev=dev, params=magnitudes,
max_allowed_error=None, eps=4)


@params(("cpu", ), ("gpu", ))
def test_brightness(dev):

def brightness_ref(img, magnitude):
return ImageEnhance.Brightness(img).enhance(magnitude)

batch_size = 16
data_source = get_images(dev)
brightness = a.brightness.augmentation(mag_range=(0.1, 1), randomly_negate=False, as_param=None)
magnitudes = brightness._get_magnitudes(batch_size)
compare_against_baseline(brightness, pil_baseline(brightness_ref), data_source,
batch_size=batch_size, max_allowed_error=1, dev=dev, params=magnitudes)


@params(("cpu", ), ("gpu", ))
def test_contrast(dev):

# adapted from test_brightness_contrast.py
def contrast_ref(input, contrast):
output = (0.5 + contrast * (np.float32(input) / 255 - 0.5)) * 255
return np.clip(output, 0, 255)

batch_size = 16
data_source = get_images(dev)
contrast = a.contrast.augmentation(mag_range=(0.1, 1), randomly_negate=False, as_param=None)
magnitudes = contrast._get_magnitudes(batch_size)
compare_against_baseline(contrast, pil_baseline(contrast_ref), data_source,
batch_size=batch_size, max_allowed_error=1, dev=dev, params=magnitudes)


@params(("cpu", ), ("gpu", ))
def test_color(dev):
max_allowed_error = 2

def color_ref(img, magnitude):
return ImageEnhance.Color(img).enhance(magnitude)

batch_size = 16
data_source = get_images(dev)
color = a.color.augmentation(mag_range=(0.1, 1.9), randomly_negate=False, as_param=None)
magnitudes = color._get_magnitudes(batch_size)
compare_against_baseline(color, pil_baseline(color_ref), data_source, batch_size=batch_size,
max_allowed_error=max_allowed_error, dev=dev, params=magnitudes)


@params(("gpu", ))
def test_sharpness(dev):

Expand Down Expand Up @@ -141,7 +242,7 @@ def test_solarize(dev):
@params(("cpu", ), ("gpu", ))
def test_solarize_add(dev):

# the implementation from DeepLearningExamples:
# adapted the implementation from DeepLearningExamples:
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
# Classification/ConvNets/image_classification/autoaugment.py
def solarize_add_ref(image, magnitude):
Expand Down Expand Up @@ -172,6 +273,20 @@ def test_invert(dev):
max_allowed_error=1, dev=dev)


@params(("gpu", ))
def test_equalize(dev):

# pil's equalization uses slightly different formula when
# transforming cumulative-sum of histogram into lookup table than open-cv
def open_cv_ref(img):
img = img.transpose(2, 0, 1)
return np.stack([cv2.equalizeHist(channel) for channel in img], axis=2)

data_source = get_images(dev)
compare_against_baseline(a.equalize, open_cv_ref, data_source, batch_size=16,
max_allowed_error=1, dev=dev)


@params(("cpu", ), ("gpu", ))
def test_auto_contrast(dev):
data_source = get_images(dev)
Expand Down

0 comments on commit 134aabf

Please sign in to comment.