Skip to content

Commit

Permalink
Even even more
Browse files Browse the repository at this point in the history
Signed-off-by: Kamil Tokarski <[email protected]>
  • Loading branch information
stiepan committed Mar 8, 2023
1 parent 72dfbd3 commit b8cd469
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 30 deletions.
15 changes: 12 additions & 3 deletions dali/python/nvidia/dali/auto_aug/augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,10 @@ def contrast_mean_centered(sample, parameter):
mean = fn.reductions.mean(sample, axes=[0, 1])
rgb_weights = types.Constant(np.array([0.299, 0.587, 0.114], dtype=np.float32))
center = fn.reductions.sum(mean * rgb_weights)
return fn.contrast(sample, contrast=parameter, contrast_center=center)
# it could be just `fn.contrast(sample, contrast=parameter, contrast_center=center)`
# but for GPU `sample` the `center` is in GPU mem, and that cannot be passed
# as named arg (i.e. `contrast_center`) to the operator
return fn.cast_like(center + (sample - center) * parameter, sample)


@augmentation(mag_range=(0, 0.9), randomly_negate=True, as_param=shift_enhance_range)
Expand All @@ -151,6 +154,11 @@ def sharpness_kernel_shifted(magnitude):
@augmentation(mag_range=(0, 0.9), randomly_negate=True, as_param=sharpness_kernel,
param_device="gpu")
def sharpness(sample, kernel):
"""
The outputs correspond to PIL's ImageEnhance.Sharpness with the exception for 1px
border around the output. PIL computes convolution with smoothing filter only for
valid positions (no out-of-bounds filter positions) and pads the output with the input.
"""
return fn.experimental.filter(sample, kernel)


Expand Down Expand Up @@ -205,8 +213,9 @@ def invert(sample, _):
@augmentation
def equalize(sample, _):
"""
DALI's equalize is open-cv conformant; the PIL impl uses slightly different formula when
scaling histogram's cumulative sum to create lookup table
DALI's equalize follows OpenCV's histogram equalization.
The PIL uses slightly different formula when transforming histogram's
cumulative sum into lookup table.
"""
return fn.experimental.equalize(sample)

Expand Down
134 changes: 107 additions & 27 deletions dali/test/python/auto_aug/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,15 +52,16 @@ def inner(sample, param=None):


def compare_against_baseline(dali_aug, baseline_op, get_data, batch_size, dev="gpu", eps=1e-7,
max_allowed_error=1e-6, params=None, post_proc=None):
max_allowed_error=1e-6, params=None, post_proc=None, use_shape=False):

@pipeline_def(batch_size=batch_size, num_threads=4, device_id=0, seed=42)
def pipeline():
data = get_data()
op_data = data if dev != "gpu" else data.gpu()
mag_bin = fn.external_source(lambda info: np.array(info.idx_in_batch, dtype=np.int32),
batch=False)
output = dali_aug(op_data, num_magnitude_bins=batch_size, magnitude_bin=mag_bin)
extra = {} if not use_shape else {"shape": fn.shapes(data)}
output = dali_aug(op_data, num_magnitude_bins=batch_size, magnitude_bin=mag_bin, **extra)
return output, data

p = pipeline()
Expand All @@ -85,11 +86,11 @@ def pipeline():
check_batch(output, ref_output, eps=eps, max_allowed_error=max_allowed_error)


def get_images(dev):
def get_images():

def inner():
image, _ = fn.readers.file(name="Reader", file_root=images_dir)
return fn.decoders.image(image, device="cpu" if dev == "cpu" else "mixed")
return fn.decoders.image(image, device="cpu")

return inner

Expand All @@ -105,7 +106,7 @@ def shear_x_ref(img, magnitude):
fillcolor=(128, ) * 3)

batch_size = 16
data_source = get_images(dev)
data_source = get_images()
shear_x = a.shear_x.augmentation(mag_range=(-0.3, 0.3), randomly_negate=False)
magnitudes = shear_x._get_magnitudes(batch_size)
compare_against_baseline(shear_x, pil_baseline(shear_x_ref), data_source, batch_size=batch_size,
Expand All @@ -123,13 +124,91 @@ def shear_y_ref(img, magnitude):
fillcolor=(128, ) * 3)

batch_size = 16
data_source = get_images(dev)
data_source = get_images()
shear_y = a.shear_y.augmentation(mag_range=(-0.3, 0.3), randomly_negate=False)
magnitudes = shear_y._get_magnitudes(batch_size)
compare_against_baseline(shear_y, pil_baseline(shear_y_ref), data_source, batch_size=batch_size,
dev=dev, params=magnitudes, max_allowed_error=None, eps=3)


@params(("cpu", ), ("gpu", ))
def test_translate_x_no_shape(dev):

# adapted implementation from DeepLearningExamples:
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
# Classification/ConvNets/image_classification/autoaugment.py
def translate_x_ref(img, magnitude):
return img.transform(img.size, Image.AFFINE, (1, 0, -magnitude, 0, 1, 0), Image.BICUBIC,
fillcolor=(128, ) * 3)

batch_size = 16
data_source = get_images()
translate_x_no_shape = a.translate_x_no_shape.augmentation(mag_range=(-250, 250),
randomly_negate=False)
magnitudes = translate_x_no_shape._get_magnitudes(batch_size)
compare_against_baseline(translate_x_no_shape, pil_baseline(translate_x_ref), data_source,
batch_size=batch_size, dev=dev, params=magnitudes,
max_allowed_error=None, eps=2)


@params(("cpu", ), ("gpu", ))
def test_translate_x(dev):

# adapted implementation from DeepLearningExamples:
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
# Classification/ConvNets/image_classification/autoaugment.py
def translate_x_ref(img, magnitude):
return img.transform(img.size, Image.AFFINE, (1, 0, -magnitude * img.width, 0, 1, 0),
Image.BICUBIC, fillcolor=(128, ) * 3)

batch_size = 16
data_source = get_images()
translate_x = a.translate_x.augmentation(mag_range=(-1, 1), randomly_negate=False)
magnitudes = translate_x._get_magnitudes(batch_size)
compare_against_baseline(translate_x, pil_baseline(translate_x_ref), data_source,
batch_size=batch_size, dev=dev, params=magnitudes,
max_allowed_error=None, eps=2, use_shape=True)


@params(("cpu", ), ("gpu", ))
def test_translate_y_no_shape(dev):

# adapted implementation from DeepLearningExamples:
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
# Classification/ConvNets/image_classification/autoaugment.py
def translate_y_ref(img, magnitude):
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, -magnitude), Image.BICUBIC,
fillcolor=(128, ) * 3)

batch_size = 16
data_source = get_images()
translate_y_no_shape = a.translate_y_no_shape.augmentation(mag_range=(-250, 250),
randomly_negate=False)
magnitudes = translate_y_no_shape._get_magnitudes(batch_size)
compare_against_baseline(translate_y_no_shape, pil_baseline(translate_y_ref), data_source,
batch_size=batch_size, dev=dev, params=magnitudes,
max_allowed_error=None, eps=2)


@params(("cpu", ), ("gpu", ))
def test_translate_y(dev):

# adapted implementation from DeepLearningExamples:
# https://github.com/NVIDIA/DeepLearningExamples/blob/master/PyTorch/
# Classification/ConvNets/image_classification/autoaugment.py
def translate_y_ref(img, magnitude):
return img.transform(img.size, Image.AFFINE, (1, 0, 0, 0, 1, -magnitude * img.height),
Image.BICUBIC, fillcolor=(128, ) * 3)

batch_size = 16
data_source = get_images()
translate_y = a.translate_y.augmentation(mag_range=(-1, 1), randomly_negate=False)
magnitudes = translate_y._get_magnitudes(batch_size)
compare_against_baseline(translate_y, pil_baseline(translate_y_ref), data_source,
batch_size=batch_size, dev=dev, params=magnitudes,
max_allowed_error=None, eps=2, use_shape=True)


@params(("cpu", ), ("gpu", ))
def test_rotate(dev):

Expand All @@ -141,7 +220,7 @@ def rotate_with_fill(img, magnitude):
return Image.composite(rot, Image.new("RGBA", img.size, (128, ) * 4), rot).convert(img.mode)

batch_size = 16
data_source = get_images(dev)
data_source = get_images()
rotate = a.rotate.augmentation(mag_range=(-30, 30), randomly_negate=False)
magnitudes = rotate._get_magnitudes(batch_size)
compare_against_baseline(rotate, pil_baseline(rotate_with_fill), data_source,
Expand All @@ -156,8 +235,9 @@ def brightness_ref(img, magnitude):
return ImageEnhance.Brightness(img).enhance(magnitude)

batch_size = 16
data_source = get_images(dev)
brightness = a.brightness.augmentation(mag_range=(0.1, 1), randomly_negate=False, as_param=None)
data_source = get_images()
brightness = a.brightness.augmentation(mag_range=(0.1, 1.9), randomly_negate=False,
as_param=None)
magnitudes = brightness._get_magnitudes(batch_size)
compare_against_baseline(brightness, pil_baseline(brightness_ref), data_source,
batch_size=batch_size, max_allowed_error=1, dev=dev, params=magnitudes)
Expand All @@ -172,22 +252,22 @@ def contrast_ref(input, contrast):
return np.clip(output, 0, 255)

batch_size = 16
data_source = get_images(dev)
contrast = a.contrast.augmentation(mag_range=(0.1, 1), randomly_negate=False, as_param=None)
data_source = get_images()
contrast = a.contrast.augmentation(mag_range=(0.1, 1.9), randomly_negate=False, as_param=None)
magnitudes = contrast._get_magnitudes(batch_size)
compare_against_baseline(contrast, pil_baseline(contrast_ref), data_source,
batch_size=batch_size, max_allowed_error=1, dev=dev, params=magnitudes)


@params(("cpu", ))
@params(("cpu", ), ("gpu", ))
def test_contrast_mean_centered(dev):

def contrast_ref(img, magnitude):
return ImageEnhance.Contrast(img).enhance(magnitude)

batch_size = 16
data_source = get_images(dev)
contrast = a.contrast_mean_centered.augmentation(mag_range=(0.1, 1), randomly_negate=False,
data_source = get_images()
contrast = a.contrast_mean_centered.augmentation(mag_range=(0.1, 1.9), randomly_negate=False,
as_param=None)
magnitudes = contrast._get_magnitudes(batch_size)
compare_against_baseline(contrast, pil_baseline(contrast_ref), data_source,
Expand All @@ -202,7 +282,7 @@ def color_ref(img, magnitude):
return ImageEnhance.Color(img).enhance(magnitude)

batch_size = 16
data_source = get_images(dev)
data_source = get_images()
color = a.color.augmentation(mag_range=(0.1, 1.9), randomly_negate=False, as_param=None)
magnitudes = color._get_magnitudes(batch_size)
compare_against_baseline(color, pil_baseline(color_ref), data_source, batch_size=batch_size,
Expand All @@ -223,7 +303,7 @@ def post_proc(img):
return img[1:-1, 1:-1, :]

batch_size = 16
data_source = get_images(dev)
data_source = get_images()
sharpness = a.sharpness.augmentation(mag_range=(0.1, 1.9), randomly_negate=False,
as_param=a.sharpness_kernel_shifted)
magnitudes = sharpness._get_magnitudes(batch_size)
Expand All @@ -235,7 +315,7 @@ def post_proc(img):
@params(("cpu", ), ("gpu", ))
def test_posterize(dev):
batch_size = 16
data_source = get_images(dev)
data_source = get_images()
# note, 0 is remapped to 1 as in tf implementation referred in the RA paper, thus (1, 8) range
posterize = a.posterize.augmentation(param_device=dev, mag_range=(1, 8))
magnitudes = np.round(posterize._get_magnitudes(batch_size)).astype(np.int32)
Expand All @@ -246,7 +326,7 @@ def test_posterize(dev):
@params(("cpu", ), ("gpu", ))
def test_solarize(dev):
batch_size = 16
data_source = get_images(dev)
data_source = get_images()
solarize = a.solarize.augmentation(param_device=dev)
magnitudes = solarize._get_magnitudes(batch_size)
params = solarize._map_mags_to_params(magnitudes)
Expand All @@ -273,7 +353,7 @@ def solarize_add_ref(image, magnitude):
return ImageOps._lut(image, lut)

batch_size = 16
data_source = get_images(dev)
data_source = get_images()
solarize_add = a.solarize_add.augmentation(param_device=dev)
magnitudes = solarize_add._get_magnitudes(batch_size)
params = solarize_add._map_mags_to_params(magnitudes)
Expand All @@ -283,7 +363,7 @@ def solarize_add_ref(image, magnitude):

@params(("cpu", ), ("gpu", ))
def test_invert(dev):
data_source = get_images(dev)
data_source = get_images()
compare_against_baseline(a.invert, pil_baseline(ImageOps.invert), data_source, batch_size=16,
max_allowed_error=1, dev=dev)

Expand All @@ -293,18 +373,18 @@ def test_equalize(dev):

# pil's equalization uses slightly different formula when
# transforming cumulative-sum of histogram into lookup table than open-cv
def open_cv_ref(img):
img = img.transpose(2, 0, 1)
return np.stack([cv2.equalizeHist(channel) for channel in img], axis=2)
# so the point-wise diffs can be significant, but the average is not
# (comparable to geom transforms)
eps = 5

data_source = get_images(dev)
compare_against_baseline(a.equalize, open_cv_ref, data_source, batch_size=16,
max_allowed_error=1, dev=dev)
data_source = get_images()
compare_against_baseline(a.equalize, pil_baseline(ImageOps.equalize), data_source,
batch_size=16, max_allowed_error=None, dev=dev, eps=eps)


@params(("cpu", ), ("gpu", ))
def test_auto_contrast(dev):
data_source = get_images(dev)
data_source = get_images()
compare_against_baseline(a.auto_contrast, pil_baseline(ImageOps.autocontrast), data_source,
batch_size=16, max_allowed_error=1, dev=dev)

Expand Down

0 comments on commit b8cd469

Please sign in to comment.