diff --git a/dali/python/nvidia/dali/auto_aug/auto_augment.py b/dali/python/nvidia/dali/auto_aug/auto_augment.py index 7ca11647d70..5e497c4d25e 100644 --- a/dali/python/nvidia/dali/auto_aug/auto_augment.py +++ b/dali/python/nvidia/dali/auto_aug/auto_augment.py @@ -126,8 +126,8 @@ def apply_auto_augment(policy: Policy, sample: _DataNode, seed: Optional[int] = return sample -def get_image_net_policy(use_shape: bool = False, max_translate_abs: int = None, - max_translate_rel: float = None) -> Policy: +def get_image_net_policy(use_shape: bool = False, max_translate_abs: Optional[int] = None, + max_translate_rel: Optional[float] = None) -> Policy: """ Creates augmentation policy tuned for the ImageNet as described in AutoAugment (https://arxiv.org/abs/1805.09501). @@ -189,8 +189,8 @@ def get_image_net_policy(use_shape: bool = False, max_translate_abs: int = None, ]) -def _get_translate_y(use_shape: bool = False, max_translate_abs: int = None, - max_translate_rel: float = None): +def _get_translate_y(use_shape: bool = False, max_translate_abs: Optional[int] = None, + max_translate_rel: Optional[float] = None): max_translate_height, _ = _parse_validate_offset(use_shape, max_translate_abs=max_translate_abs, max_translate_rel=max_translate_rel, default_translate_abs=250, diff --git a/dali/python/nvidia/dali/auto_aug/core/decorator.py b/dali/python/nvidia/dali/auto_aug/core/decorator.py index a4f28c5d3c1..2e4dd402d02 100644 --- a/dali/python/nvidia/dali/auto_aug/core/decorator.py +++ b/dali/python/nvidia/dali/auto_aug/core/decorator.py @@ -39,7 +39,7 @@ def augmentation(function: Optional[Callable[..., _DataNode]] = None, *, randomly_negate: Optional[bool] = None, as_param: Optional[Callable[[float], _ArrayLike]] = None, param_device: Optional[str] = None, name: Optional[str] = None, - augmentation_cls: Type[Augmentation] = None): + augmentation_cls: Optional[Type[Augmentation]] = None): """ A decorator turning transformations implemented with DALI into augmentations that 
can be used by automatic augmentations (e.g. AutoAugment, RandAugment, TrivialAugment). diff --git a/dali/python/nvidia/dali/auto_aug/core/policy.py b/dali/python/nvidia/dali/auto_aug/core/policy.py index 1c9d778f425..394ded2f97d 100644 --- a/dali/python/nvidia/dali/auto_aug/core/policy.py +++ b/dali/python/nvidia/dali/auto_aug/core/policy.py @@ -90,7 +90,7 @@ def __repr__(self): def _sub_policy_with_unique_names( sub_policies: Sequence[Sequence[Tuple[Augmentation, float, int]]] -) -> Tuple[Tuple[Tuple[Augmentation, float, int]]]: +) -> Sequence[Sequence[Tuple[Augmentation, float, int]]]: """ Check if the augmentations used in the sub-policies have unique names. If not, rename them by adding enumeration to the names. diff --git a/dali/python/nvidia/dali/auto_aug/rand_augment.py b/dali/python/nvidia/dali/auto_aug/rand_augment.py new file mode 100644 index 00000000000..b1025073595 --- /dev/null +++ b/dali/python/nvidia/dali/auto_aug/rand_augment.py @@ -0,0 +1,256 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import warnings + +from typing import List, Optional + +from nvidia.dali import fn +from nvidia.dali import types +from nvidia.dali.auto_aug import augmentations as a +from nvidia.dali.auto_aug.core import signed_bin, _Augmentation +from nvidia.dali.auto_aug.core._args import \ + forbid_unused_kwargs as _forbid_unused_kwargs +from nvidia.dali.auto_aug.core._utils import \ + parse_validate_offset as _parse_validate_offset, \ + pretty_select as _pretty_select +from nvidia.dali.data_node import DataNode as _DataNode + + +def rand_augment(sample: _DataNode, n: int, m: int, num_magnitude_bins: int = 31, + shape: Optional[_DataNode] = None, fill_value: Optional[int] = 128, + interp_type: Optional[types.DALIInterpType] = None, + max_translate_abs: Optional[int] = None, max_translate_rel: Optional[float] = None, + seed: Optional[int] = None, monotonic_mag: bool = True, + excluded: Optional[List[str]] = None) -> _DataNode: + """ + Applies RandAugment (https://arxiv.org/abs/1909.13719) augmentation scheme to the + provided batch of samples. + + Parameter + --------- + sample : DataNode + A batch of samples to be processed. The samples should be images of `HWC` layout, + `uint8` type and reside on GPU. + n: int + The number of randomly sampled operations to be applied to a sample. + m: int + A magnitude (strength) of each operation to be applied, it must be an integer + within `[0, num_magnitude_bins - 1]`. + num_magnitude_bins: int, optional + The number of bins to divide the magnitude ranges into. + shape: DataNode, optional + A batch of shapes of the samples. If specified, the magnitude of `translation` + operations depends on the image shape and spans from 0 to `max_translate_rel * shape`. + Otherwise, the magnitude range is `[0, max_translate_abs]` for any sample. + fill_value: int, optional + A value to be used as a padding for images transformed with warp_affine ops + (translation, shear and rotate). 
If `None` is specified, the images are padded + with the border value repeated (clamped). + interp_type: types.DALIInterpType, optional + Interpolation method used by the warp_affine ops (translation, shear and rotate). + Supported values are `types.INTERP_LINEAR` (default) and `types.INTERP_NN`. + seed: int, optional + Seed to be used to randomly sample operations (and to negate magnitudes). + monotonic_mag: bool, optional + There are two flavours of RandAugment available in different frameworks. For the default + `monotonic_mag=True` the strength of operations that accept magnitude bins increases with + the increasing bins. If set to False, the magnitude ranges for some color operations differ. + There, the `posterize` and `solarize` strength decreases with increasing magnitude bins and + enhance operations (`brightness`, `contrast`, `color`, `sharpness`) use (0.1, 1.9) range, + which means that the strength decreases the closer the magnitudes are to the center + of the range. See `get_rand_augment_non_monotonic_suite`. + excluded: List[str], optional + A list of names of the operations to be excluded from the default suite of augmentations. + If, instead of just limiting the set of operations, you need to include some custom + operations or fine-tune the existing ones, you can use the `apply_rand_augment` + directly, which accepts a list of augmentations. + + Returns + ------- + DataNode + A batch of transformed samples. 
+ """ + aug_kwargs = {"fill_value": fill_value, "interp_type": interp_type} + use_shape = shape is not None + if use_shape: + aug_kwargs["shape"] = shape + if monotonic_mag: + augmentations = get_rand_augment_suite(use_shape, max_translate_abs, max_translate_rel) + else: + augmentations = get_rand_augment_non_monotonic_suite(use_shape, max_translate_abs, + max_translate_rel) + augmentation_names = set(aug.name for aug in augmentations) + assert len(augmentation_names) == len(augmentations) + excluded = excluded or [] + for name in excluded: + if name not in augmentation_names: + raise Exception(f"The `{name}` was specified in `excluded`, but the RandAugment suite " + f"does not contain augmentation with this name. " + f"The augmentations in the suite are: {', '.join(augmentation_names)}.") + selected_augments = [aug for aug in augmentations if aug.name not in excluded] + return apply_rand_augment(selected_augments, sample, n, m, + num_magnitude_bins=num_magnitude_bins, seed=seed, **aug_kwargs) + + +def apply_rand_augment(augmentations: List[_Augmentation], sample: _DataNode, n: int, m: int, + num_magnitude_bins: int = 31, seed: Optional[int] = None, + **kwargs) -> _DataNode: + """ + Applies the list of `augmentations` in RandAugment (https://arxiv.org/abs/1909.13719) fashion. + Each sample is transformed with `n` operations in a sequence randomly selected from the + `augmentations` list. Each operation uses `m` as the magnitude bin. + + Parameter + --------- + augmentations : List[core._Augmentation] + List of augmentations to be sampled and applied in RandAugment fashion. + sample : DataNode + A batch of samples to be processed. + n: int + The number of randomly sampled operations to be applied to a sample. + m: int + A magnitude bin (strength) of each operation to be applied, it must be an integer + within `[0, num_magnitude_bins - 1]`. + num_magnitude_bins: int + The number of bins to divide the magnitude ranges into. 
+ seed: int + Seed to be used to randomly sample operations (and to negate magnitudes). + kwargs: + Any extra parameters to be passed when calling `augmentations`. + The signature of each augmentation is checked for any extra arguments and if + the name of the argument matches one from the `kwargs`, the value is + passed as an argument. For example, some augmentations from the default + random augment suite accept `shape`, `fill_value` and `interp_type`. + Returns + ------- + DataNode + A batch of transformed samples. + """ + if not isinstance(n, int) or n < 0: + raise Exception( + f"The number of operations to apply `n` must be a non-negative integer, got {n}.") + if not isinstance(num_magnitude_bins, int) or num_magnitude_bins < 1: + raise Exception( + f"The `num_magnitude_bins` must be a positive integer, got {num_magnitude_bins}.") + if not isinstance(m, int) or not 0 <= m < num_magnitude_bins: + raise Exception(f"The magnitude bin `m` must be an integer from " + f"`[0, {num_magnitude_bins - 1}]` range. Got {m}.") + if n == 0: + warnings.warn( + "The `apply_rand_augment` was called with `n=0`, " + "no augmentation will be applied.", Warning) + return sample + if len(augmentations) == 0: + raise Exception("The `augmentations` list cannot be empty, unless n=0. 
" + "Got empty list in `apply_rand_augment` call.") + shape = tuple() if n == 1 else (n, ) + op_idx = fn.random.uniform(values=list(range(len(augmentations))), seed=seed, shape=shape, + dtype=types.INT32) + use_signed_magnitudes = any(aug.randomly_negate for aug in augmentations) + _forbid_unused_kwargs(augmentations, kwargs, 'apply_rand_augment') + for level_idx in range(n): + magnitude_bin = signed_bin(m) if use_signed_magnitudes else m + op_kwargs = dict(sample=sample, magnitude_bin=magnitude_bin, + num_magnitude_bins=num_magnitude_bins, **kwargs) + level_op_idx = op_idx if n == 1 else op_idx[level_idx] + sample = _pretty_select(augmentations, level_op_idx, op_kwargs, + auto_aug_name='apply_rand_augment', + ref_suite_name='get_rand_augment_suite') + return sample + + +def get_rand_augment_suite(use_shape: bool = False, max_translate_abs: Optional[int] = None, + max_translate_rel: Optional[float] = None) -> List[_Augmentation]: + """ + Creates a list of RandAugment augmentations. + + Parameter + --------- + use_shape : bool + If true, the translation offset is computed as a percentage of the image. Useful if the + images processed with the auto augment have different shapes. If false, the offsets range + is bounded by a constant (`max_translate_abs`). + max_translate_abs: int or (int, int), optional + Only valid with use_shape=False, specifies the maximal shift (in pixels) in the translation + augmentations. If tuple is specified, the first component limits height, the second the + width. + max_translate_rel: float or (float, float), optional + Only valid with use_shape=True, specifies the maximal shift as a fraction of image shape + in the translation augmentations. If tuple is specified, the first component limits + height, the second the width. 
+ """ + # translations = [translate_x, translate_y] with adjusted magnitude range + translations = _get_translations(use_shape, max_translate_abs, max_translate_rel) + # [.augmentation((mag_low, mag_high), randomly_negate_mag, magnitude_to_param_custom_mapping] + return translations + [ + a.shear_x.augmentation((0, 0.3), True), + a.shear_y.augmentation((0, 0.3), True), + a.rotate.augmentation((0, 30), True), + a.brightness.augmentation((0, 0.9), True, a.shift_enhance_range), + a.contrast.augmentation((0, 0.9), True, a.shift_enhance_range), + a.color.augmentation((0, 0.9), True, a.shift_enhance_range), + a.sharpness.augmentation((0, 0.9), True, a.sharpness_kernel), + a.posterize.augmentation((8, 4), False, a.poster_mask_uint8), + # solarization strength increases with decreasing magnitude (threshold) + a.solarize.augmentation((256, 0)), + a.equalize, + a.auto_contrast, + a.identity, + ] + + +def get_rand_augment_non_monotonic_suite( + use_shape: bool = False, max_translate_abs: Optional[int] = None, + max_translate_rel: Optional[float] = None) -> List[_Augmentation]: + """ + Similarly to `get_rand_augment_suite` creates a list of RandAugment augmentations. + + This variant uses brightness, contrast, color, sharpness, posterize, and solarize + with magnitude ranges as used by the AutoAugment. However, those ranges do not meet + the intuition that the bigger magnitude bin corresponds to stronger operation. 
+ """ + # translations = [translate_x, translate_y] with adjusted magnitude range + translations = _get_translations(use_shape, max_translate_abs, max_translate_rel) + return translations + [ + a.shear_x.augmentation((0, 0.3), True), + a.shear_y.augmentation((0, 0.3), True), + a.rotate.augmentation((0, 30), True), + a.brightness.augmentation((0.1, 1.9), False, None), + a.contrast.augmentation((0.1, 1.9), False, None), + a.color.augmentation((0.1, 1.9), False, None), + a.sharpness.augmentation((0.1, 1.9), False, a.sharpness_kernel_shifted), + a.posterize.augmentation((0, 4), False, a.poster_mask_uint8), + a.solarize.augmentation((0, 256), False, None), + a.equalize, + a.auto_contrast, + a.identity, + ] + + +def _get_translations(use_shape: bool = False, max_translate_abs: Optional[int] = None, + max_translate_rel: Optional[float] = None) -> List[_Augmentation]: + max_translate_height, max_translate_width = _parse_validate_offset( + use_shape, max_translate_abs=max_translate_abs, max_translate_rel=max_translate_rel, + default_translate_abs=100, default_translate_rel=100 / 224) + if use_shape: + return [ + a.translate_x.augmentation((0, max_translate_width), True), + a.translate_y.augmentation((0, max_translate_height), True), + ] + else: + return [ + a.translate_x_no_shape.augmentation((0, max_translate_width), True), + a.translate_y_no_shape.augmentation((0, max_translate_height), True), + ] diff --git a/dali/python/nvidia/dali/auto_aug/trivial_augment.py b/dali/python/nvidia/dali/auto_aug/trivial_augment.py new file mode 100644 index 00000000000..210495e6766 --- /dev/null +++ b/dali/python/nvidia/dali/auto_aug/trivial_augment.py @@ -0,0 +1,195 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import List, Optional + +from nvidia.dali import fn +from nvidia.dali import types +from nvidia.dali.auto_aug import augmentations as a +from nvidia.dali.auto_aug.core import _Augmentation, signed_bin +from nvidia.dali.auto_aug.core._args import \ + forbid_unused_kwargs as _forbid_unused_kwargs +from nvidia.dali.auto_aug.core._utils import \ + parse_validate_offset as _parse_validate_offset, \ + pretty_select as _pretty_select +from nvidia.dali.data_node import DataNode as _DataNode + + +def trivial_augment_wide(sample: _DataNode, num_magnitude_bins: int = 31, + shape: Optional[_DataNode] = None, fill_value: Optional[int] = 128, + interp_type: Optional[types.DALIInterpType] = None, + max_translate_abs: Optional[int] = None, + max_translate_rel: Optional[float] = None, seed: Optional[int] = None, + excluded: Optional[List[str]] = None) -> _DataNode: + """ + Applies TrivialAugment Wide (https://arxiv.org/abs/2103.10158) augmentation scheme to the + provided batch of samples. + + Parameter + --------- + sample : DataNode + A batch of samples to be processed. The samples should be images of `HWC` layout, + `uint8` type and reside on GPU. + num_magnitude_bins: int, optional + The number of bins to divide the magnitude ranges into. + fill_value: int, optional + A value to be used as a padding for images transformed with warp_affine ops + (translation, shear and rotate). If `None` is specified, the images are padded + with the border value repeated (clamped). 
+ interp_type: types.DALIInterpType, optional + Interpolation method used by the warp_affine ops (translation, shear and rotate). + Supported values are `types.INTERP_LINEAR` (default) and `types.INTERP_NN`. + seed: int, optional + Seed to be used to randomly sample operations (and to negate magnitudes). + excluded: List[str], optional + A list of names of the operations to be excluded from the default suite of augmentations. + If, instead of just limiting the set of operations, you need to include some custom + operations or fine-tune the existing ones, you can use the `apply_trivial_augment` + directly, which accepts a list of augmentations. + + Returns + ------- + DataNode + A batch of transformed samples. + """ + aug_kwargs = {"fill_value": fill_value, "interp_type": interp_type} + use_shape = shape is not None + if use_shape: + aug_kwargs["shape"] = shape + augmentations = get_trivial_augment_wide_suite(use_shape=use_shape, + max_translate_abs=max_translate_abs, + max_translate_rel=max_translate_rel) + augmentation_names = set(aug.name for aug in augmentations) + assert len(augmentation_names) == len(augmentations) + excluded = excluded or [] + for name in excluded: + if name not in augmentation_names: + raise Exception( + f"The `{name}` was specified in `excluded`, but the TrivialAugmentWide suite " + f"does not contain augmentation with this name. " + f"The augmentations in the suite are: {', '.join(augmentation_names)}.") + selected_augments = [aug for aug in augmentations if aug.name not in excluded] + return apply_trivial_augment(selected_augments, sample, num_magnitude_bins=num_magnitude_bins, + seed=seed, **aug_kwargs) + + +def apply_trivial_augment(augmentations: List[_Augmentation], sample: _DataNode, + num_magnitude_bins: int = 31, seed: Optional[int] = None, + **kwargs) -> _DataNode: + """ + Applies the list of `augmentations` in TrivialAugment + (https://arxiv.org/abs/2103.10158) fashion. 
+ Each sample is processed with randomly selected transformation from the `augmentations` list. + The magnitude bin for every transformation is randomly selected from + `[0, num_magnitude_bins - 1]`. + + Parameter + --------- + augmentations : List[core._Augmentation] + List of augmentations to be sampled and applied in TrivialAugment fashion. + sample : DataNode + A batch of samples to be processed. The samples should be images of `HWC` layout, + `uint8` type and reside on GPU. + num_magnitude_bins: int, optional + The number of bins to divide the magnitude ranges into. + seed: int, optional + Seed to be used to randomly sample operations (and to negate magnitudes). + kwargs: + Any extra parameters to be passed when calling `augmentations`. + The signature of each augmentation is checked for any extra arguments and if + the name of the argument matches one from the `kwargs`, the value is + passed as an argument. For example, some augmentations from the default + random augment suite accept `shape`, `fill_value` and `interp_type`. + + Returns + ------- + DataNode + A batch of transformed samples. + """ + if not isinstance(num_magnitude_bins, int) or num_magnitude_bins < 1: + raise Exception( + f"The `num_magnitude_bins` must be a positive integer, got {num_magnitude_bins}.") + if len(augmentations) == 0: + raise Exception("The `augmentations` list cannot be empty. 
" + "Got empty list in `apply_trivial_augment` call.") + magnitude_bin = fn.random.uniform(values=list(range(num_magnitude_bins)), dtype=types.INT32, + seed=seed) + use_signed_magnitudes = any(aug.randomly_negate for aug in augmentations) + if use_signed_magnitudes: + magnitude_bin = signed_bin(magnitude_bin) + _forbid_unused_kwargs(augmentations, kwargs, 'apply_trivial_augment') + op_kwargs = dict(sample=sample, magnitude_bin=magnitude_bin, + num_magnitude_bins=num_magnitude_bins, **kwargs) + op_idx = fn.random.uniform(values=list(range(len(augmentations))), seed=seed, dtype=types.INT32) + return _pretty_select(augmentations, op_idx, op_kwargs, auto_aug_name='apply_trivial_augment', + ref_suite_name='get_trivial_augment_wide_suite') + + +def get_trivial_augment_wide_suite( + use_shape: bool = False, max_translate_abs: Optional[int] = None, + max_translate_rel: Optional[float] = None) -> List[_Augmentation]: + """ + Creates a list of 14 augmentations referred as wide augmentation space in TrivialAugment paper + (https://arxiv.org/abs/2103.10158). + + Parameter + --------- + use_shape : bool + If true, the translation offset is computed as a percentage of the image. Useful if the + images processed with the auto augment have different shapes. If false, the offsets range + is bounded by a constant (`max_translate_abs`). + max_translate_abs: int or (int, int), optional + Only valid with use_shape=False, specifies the maximal shift (in pixels) in the translation + augmentations. If tuple is specified, the first component limits height, the second the + width. + max_translate_rel: float or (float, float), optional + Only valid with use_shape=True, specifies the maximal shift as a fraction of image shape + in the translation augmentations. If tuple is specified, the first component limits + height, the second the width. 
+ """ + # translations = [translate_x, translate_y] with adjusted magnitude range + translations = _get_translations(use_shape, max_translate_abs, max_translate_rel) + # [.augmentation((mag_low, mag_high), randomly_negate_mag, custom_magnitude_to_param_mapping] + return translations + [ + a.shear_x.augmentation((0, 0.99), True), + a.shear_y.augmentation((0, 0.99), True), + a.rotate.augmentation((0, 135), True), + a.brightness.augmentation((0.01, 0.99), True, a.shift_enhance_range), + a.contrast.augmentation((0.01, 0.99), True, a.shift_enhance_range), + a.color.augmentation((0.01, 0.99), True, a.shift_enhance_range), + a.sharpness.augmentation((0.01, 0.99), True, a.sharpness_kernel), + a.posterize.augmentation((8, 2), False, a.poster_mask_uint8), + # solarization strength increases with decreasing magnitude (threshold) + a.solarize.augmentation((256, 0)), + a.equalize, + a.auto_contrast, + a.identity, + ] + + +def _get_translations(use_shape: bool = False, max_translate_abs: Optional[int] = None, + max_translate_rel: Optional[float] = None) -> List[_Augmentation]: + max_translate_height, max_translate_width = _parse_validate_offset( + use_shape, max_translate_abs=max_translate_abs, max_translate_rel=max_translate_rel, + default_translate_abs=32, default_translate_rel=1.) + if use_shape: + return [ + a.translate_x.augmentation((0, max_translate_width), True), + a.translate_y.augmentation((0, max_translate_height), True), + ] + else: + return [ + a.translate_x_no_shape.augmentation((0, max_translate_width), True), + a.translate_y_no_shape.augmentation((0, max_translate_height), True), + ] diff --git a/dali/test/python/auto_aug/test_rand_augment.py b/dali/test/python/auto_aug/test_rand_augment.py new file mode 100644 index 00000000000..a5b84fad43e --- /dev/null +++ b/dali/test/python/auto_aug/test_rand_augment.py @@ -0,0 +1,250 @@ +# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import itertools +import os + +import numpy as np +from scipy.stats import chisquare +from nose2.tools import params + +from nvidia.dali import fn, types +from nvidia.dali.pipeline import experimental +from nvidia.dali.auto_aug import rand_augment +from nvidia.dali.auto_aug.core import augmentation + +from test_utils import get_dali_extra_path +from nose_utils import assert_raises + +data_root = get_dali_extra_path() +images_dir = os.path.join(data_root, 'db', 'single', 'jpeg') + + +@params(*tuple(enumerate(itertools.product((True, False), (True, False), (None, 0), + (True, False))))) +def test_run_rand_aug(i, args): + uniformly_resized, use_shape, fill_value, specify_translation_bounds = args + batch_sizes = [1, 8, 7, 64, 13, 64, 128] + ns = [1, 2, 3, 4] + ms = [0, 15, 30] + batch_size = batch_sizes[i % len(batch_sizes)] + n = ns[i % len(ns)] + m = ms[i % len(ms)] + + @experimental.pipeline_def(enable_conditionals=True, batch_size=batch_size, num_threads=4, + device_id=0, seed=43) + def pipeline(): + encoded_image, _ = fn.readers.file(name="Reader", file_root=images_dir) + image = fn.decoders.image(encoded_image, device="mixed") + if uniformly_resized: + image = fn.resize(image, size=(244, 244)) + extra = {} if not use_shape else {"shape": fn.peek_image_shape(encoded_image)} + if fill_value is not None: + extra["fill_value"] = fill_value + if specify_translation_bounds: + if use_shape: + extra["max_translate_rel"] = 
0.9 + else: + extra["max_translate_abs"] = 400 + image = rand_augment.rand_augment(image, n=n, m=m, **extra) + return image + + p = pipeline() + p.build() + for _ in range(3): + p.run() + + +@params(*tuple(itertools.product((True, False), (0, 1), ('height', 'width', 'both')))) +def test_translation(use_shape, offset_fraction, extent): + # make sure the translation helper processes the args properly + # note, it only uses translate_y (as it is in imagenet policy) + shape = [300, 400] + fill_value = 105 + params = {} + if use_shape: + param = offset_fraction + param_name = "max_translate_rel" + else: + param_name = "max_translate_abs" + assert extent in ('height', 'width', 'both'), f"{extent}" + if extent == 'both': + param = [shape[0] * offset_fraction, shape[1] * offset_fraction] + elif extent == 'height': + param = [shape[0] * offset_fraction, 0] + elif extent == 'width': + param = [0, shape[1] * offset_fraction] + params[param_name] = param + translate_x, translate_y = rand_augment._get_translations(use_shape=use_shape, **params) + if extent == 'both': + augments = [translate_x, translate_y] + elif extent == 'height': + augments = [translate_y] + elif extent == 'width': + augments = [translate_x] + + @experimental.pipeline_def(enable_conditionals=True, batch_size=3, num_threads=4, device_id=0, + seed=43) + def pipeline(): + encoded_image, _ = fn.readers.file(name="Reader", file_root=images_dir) + image = fn.decoders.image(encoded_image, device="mixed") + image = fn.resize(image, size=shape) + if use_shape: + return rand_augment.apply_rand_augment(augments, image, n=1, m=30, + fill_value=fill_value, shape=shape) + else: + return rand_augment.apply_rand_augment(augments, image, n=1, m=30, + fill_value=fill_value) + + p = pipeline() + p.build() + output, = p.run() + output = [np.array(sample) for sample in output.as_cpu()] + for i, sample in enumerate(output): + sample = np.array(sample) + if offset_fraction == 1: + assert np.all(sample == fill_value), f"sample_idx: 
{i}" + else: + background_count = np.sum(sample == fill_value) + assert background_count / sample.size < 0.1, \ + f"sample_idx: {i}, {background_count / sample.size}" + + +@params(*tuple(enumerate(itertools.product( + ['cpu', 'gpu'], + [True, False], + [1, 2, 3], + [2, 3], +)))) +def test_ops_selection_and_mags(case_idx, args): + + dev, use_sign, n, num_ops = args + num_magnitude_bins = 9 + # the chisquare expects at least 5 elements in a bin and we can have around + # (num_ops * (2**use_signs)) ** n ops + batch_size = 2048 + magnitude_cases = list(range(num_magnitude_bins)) + m = magnitude_cases[case_idx % len(magnitude_cases)] + + def as_param_with_op_id(op_id): + + def as_param(magnitude): + return np.array([op_id, magnitude], dtype=np.int32) + + return as_param + + @augmentation(param_device=dev) + def op(sample, op_id_mag_id): + return fn.cat(sample, op_id_mag_id) + + augmentations = [ + op.augmentation(mag_range=(10 * i + 1, 10 * i + num_magnitude_bins), + as_param=as_param_with_op_id(i + 1), randomly_negate=use_sign + and i % 3 == 0) for i in range(num_ops) + ] + + expected_counts = {} + seq_prob = 1. 
/ (num_ops**n) + for aug_sequence in itertools.product(*([augmentations] * n)): + possible_signs = [(-1, 1) if aug.randomly_negate else (1, ) for aug in aug_sequence] + possible_signs = tuple(itertools.product(*possible_signs)) + prob = seq_prob / len(possible_signs) + for signs in possible_signs: + assert len(aug_sequence) == len(signs) + outs = [] + for aug, sign in zip(aug_sequence, signs): + mag = aug._get_magnitudes(num_magnitude_bins)[m] + op_id_mag = aug.as_param(mag * sign) + outs.append(op_id_mag) + expected_counts[tuple(el for out in outs for el in out)] = prob + expected_counts = {output: p * batch_size for output, p in expected_counts.items()} + + @experimental.pipeline_def(enable_conditionals=True, batch_size=batch_size, num_threads=4, + device_id=0, seed=42) + def pipeline(): + sample = types.Constant([], dtype=types.INT32) + if dev == "gpu": + sample = sample.gpu() + sample = rand_augment.apply_rand_augment(augmentations, sample, n=n, m=m, + num_magnitude_bins=num_magnitude_bins) + return fn.reshape(sample, shape=(-1, 2)) + + p = pipeline() + p.build() + for i in range(3): + output, = p.run() + output = [np.array(s) for s in (output.as_cpu() if dev == "gpu" else output)] + actual_count = {allowed_out: 0 for allowed_out in expected_counts} + for sample in output: + assert len(sample) == n, f"{i} {sample}" + out = tuple(el for op_mag in sample for el in op_mag) + actual_count[out] += 1 + actual = [] + expected = [] + for out in expected_counts: + actual.append(actual_count[out]) + expected.append(expected_counts[out]) + stat = chisquare(actual, expected) + assert 0.01 <= stat.pvalue <= 0.99, f"{stat} {actual} {expected}" + + +def test_wrong_params_fail(): + + @experimental.pipeline_def(batch_size=4, device_id=0, num_threads=4, seed=42, + enable_conditionals=True) + def pipeline(n, m, num_magnitude_bins): + sample = types.Constant(np.array([[[]]], dtype=np.uint8)) + return rand_augment.rand_augment(sample, n=n, m=m, 
num_magnitude_bins=num_magnitude_bins) + + with assert_raises(Exception, + glob="The number of operations to apply `n` must be a non-negative integer"): + pipeline(n=None, m=1, num_magnitude_bins=11) + + with assert_raises(Exception, glob="The `num_magnitude_bins` must be a positive integer, got"): + pipeline(n=1, m=1, num_magnitude_bins=None) + + with assert_raises(Exception, glob="`m` must be an integer from `[[]0, 14[]]` range. Got 15."): + pipeline(n=1, m=15, num_magnitude_bins=15) + + with assert_raises(Exception, glob="The `augmentations` list cannot be empty"): + + @experimental.pipeline_def(batch_size=4, device_id=0, num_threads=4, seed=42, + enable_conditionals=True) + def no_aug_pipeline(): + sample = types.Constant(np.array([[[]]], dtype=np.uint8)) + return rand_augment.apply_rand_augment([], sample, 1, 20) + + no_aug_pipeline() + + with assert_raises(Exception, glob="The augmentation `translate_x` requires `shape` argument"): + + @experimental.pipeline_def(batch_size=4, device_id=0, num_threads=4, seed=42, + enable_conditionals=True) + def missing_shape(): + sample = types.Constant(np.array([[[]]], dtype=np.uint8)) + augments = rand_augment.get_rand_augment_suite(use_shape=True) + return rand_augment.apply_rand_augment(augments, sample, 1, 20) + + missing_shape() + + with assert_raises(Exception, glob="The kwarg `shhape` is not used by any of the"): + + @experimental.pipeline_def(batch_size=4, device_id=0, num_threads=4, seed=42, + enable_conditionals=True) + def unused_kwarg(): + sample = types.Constant(np.array([[[]]], dtype=np.uint8)) + augments = rand_augment.get_rand_augment_suite(use_shape=True) + return rand_augment.apply_rand_augment(augments, sample, 1, 20, shhape=42) + + unused_kwarg() diff --git a/dali/test/python/auto_aug/test_trivial_augment.py b/dali/test/python/auto_aug/test_trivial_augment.py new file mode 100644 index 00000000000..dc69ed4ebac --- /dev/null +++ b/dali/test/python/auto_aug/test_trivial_augment.py @@ -0,0 +1,182 @@ +# 
# Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import itertools
import os

import numpy as np
from scipy.stats import chisquare
from nose2.tools import params

from nvidia.dali import fn, types
from nvidia.dali.pipeline import experimental
from nvidia.dali.auto_aug import trivial_augment
from nvidia.dali.auto_aug.core import augmentation
from test_utils import get_dali_extra_path

data_root = get_dali_extra_path()
images_dir = os.path.join(data_root, 'db', 'single', 'jpeg')


@params(*tuple(enumerate(itertools.product((True, False), (True, False), (None, 0),
                                           (True, False)))))
def test_run_trivial(i, args):
    # Smoke test: builds and runs `trivial_augment_wide` on decoded images for every
    # combination of the boolean/optional switches; only checks that nothing raises.
    uniformly_resized, use_shape, fill_value, specify_translation_bounds = args
    # the case index `i` cycles through the batch sizes and the numbers of magnitude bins
    batch_sizes = [1, 8, 7, 64, 13, 64, 128]
    num_magnitude_bin_cases = [1, 11, 31, 40]
    batch_size = batch_sizes[i % len(batch_sizes)]
    num_magnitude_bins = num_magnitude_bin_cases[i % len(num_magnitude_bin_cases)]

    @experimental.pipeline_def(enable_conditionals=True, batch_size=batch_size, num_threads=4,
                               device_id=0, seed=43)
    def pipeline():
        encoded_image, _ = fn.readers.file(name="Reader", file_root=images_dir)
        image = fn.decoders.image(encoded_image, device="mixed")
        if uniformly_resized:
            image = fn.resize(image, size=(244, 244))
        # optional keyword arguments forwarded to `trivial_augment_wide`
        extra = {} if not use_shape else {"shape": fn.peek_image_shape(encoded_image)}
        if fill_value is not None:
            extra["fill_value"] = fill_value
        if specify_translation_bounds:
            # relative bound is used together with `shape`, absolute bound otherwise
            if use_shape:
                extra["max_translate_rel"] = 0.9
            else:
                extra["max_translate_abs"] = 400
        image = trivial_augment.trivial_augment_wide(image, num_magnitude_bins=num_magnitude_bins,
                                                     **extra)
        return image

    p = pipeline()
    p.build()
    for _ in range(3):
        p.run()


@params(*tuple(itertools.product((True, False), (0, 1), ('x', 'y'))))
def test_translation(use_shape, offset_fraction, extent):
    # Make sure the translation helper processes the bound arguments properly.
    # `extent` selects which of the two augmentations returned by `_get_translations`
    # is applied: the first one for 'x', the second one for 'y'.
    fill_value = 0
    # The dict is deliberately not called `params`, so that it does not shadow
    # the nose2 `params` decorator imported at the module level.
    if use_shape:
        translation_bounds = {"max_translate_rel": offset_fraction}
    else:
        translation_bounds = {"max_translate_abs": 1000 * offset_fraction}
    translation_x, translation_y = trivial_augment._get_translations(use_shape=use_shape,
                                                                     **translation_bounds)
    augment = [translation_x] if extent == 'x' else [translation_y]

    @experimental.pipeline_def(enable_conditionals=True, batch_size=9, num_threads=4, device_id=0,
                               seed=43)
    def pipeline():
        encoded_image, _ = fn.readers.file(name="Reader", file_root=images_dir)
        image = fn.decoders.image(encoded_image, device="mixed")
        if use_shape:
            shape = fn.peek_image_shape(encoded_image)
            return trivial_augment.apply_trivial_augment(augment, image, num_magnitude_bins=3,
                                                         fill_value=fill_value, shape=shape)
        else:
            return trivial_augment.apply_trivial_augment(augment, image, num_magnitude_bins=3,
                                                         fill_value=fill_value)

    p = pipeline()
    p.build()
    output, = p.run()
    # samples are converted to numpy arrays once, here
    output = [np.array(sample) for sample in output.as_cpu()]
    if offset_fraction == 1:
        # magnitudes are random here, but some should randomly be maximal,
        # so at least one sample is expected to consist of `fill_value` only
        assert any(np.all(sample == fill_value) for sample in output), \
            "expected at least one sample to consist only of `fill_value`"
    else:
        # with the bound set to zero, less than 10% of any sample may be equal to `fill_value`
        for i, sample in enumerate(output):
            background_count = np.sum(sample == fill_value)
            assert background_count / sample.size < 0.1, \
                f"sample_idx: {i}, {background_count / sample.size}"


@params(*tuple(itertools.product(
    ['cpu', 'gpu'],
    [True, False],
    [1, 3, 7],
    [2, 3, 7],
)))
def test_ops_mags_selection(dev, use_sign, num_magnitude_bins, num_ops):
    # Checks, with a chi-square test, that the (operation, magnitude, sign) combinations
    # produced by `apply_trivial_augment` occur with the expected frequencies.
    # the chisquare expects at least 5 elements in a bin and we can have around
    # num_magnitude_bins * num_ops * (2**use_signs)
    batch_size = 2048

    def as_param_with_op_id(op_id):

        def as_param(magnitude):
            # the parameter carries the operation id along with the magnitude
            return np.array([op_id, magnitude], dtype=np.int32)

        return as_param

    @augmentation(param_device=dev)
    def op(sample, op_id_mag_id):
        # appends the received [op_id, magnitude] pair to the (initially empty) sample
        return fn.cat(sample, op_id_mag_id)

    # magnitude ranges of different ops do not overlap; every third op is
    # randomly negated when `use_sign` is set
    augmentations = [
        op.augmentation(mag_range=(10 * i + 1, 10 * i + num_magnitude_bins),
                        as_param=as_param_with_op_id(i + 1),
                        randomly_negate=use_sign and i % 3 == 0) for i in range(num_ops)
    ]

    expected_counts = {}
    # equal probability is expected for every (op, magnitude bin) pair
    prob = 1. / (num_ops * num_magnitude_bins)
    for aug in augmentations:
        magnitudes = aug._get_magnitudes(num_magnitude_bins)
        assert len(magnitudes) == num_magnitude_bins
        for mag in magnitudes:
            if not aug.randomly_negate:
                expected_counts[tuple(aug.as_param(mag))] = prob
            else:
                # the probability is split evenly between the two signs
                expected_counts[tuple(aug.as_param(mag))] = prob / 2
                expected_counts[tuple(aug.as_param(-mag))] = prob / 2
    expected_counts = {output: p * batch_size for output, p in expected_counts.items()}

    @experimental.pipeline_def(enable_conditionals=True, batch_size=batch_size, num_threads=4,
                               device_id=0, seed=42)
    def pipeline():
        sample = types.Constant([], dtype=types.INT32)
        if dev == "gpu":
            sample = sample.gpu()
        sample = trivial_augment.apply_trivial_augment(augmentations, sample,
                                                       num_magnitude_bins=num_magnitude_bins)
        return sample

    p = pipeline()
    p.build()
    stats = []
    for _ in range(3):
        output, = p.run()
        output = [np.array(s) for s in (output.as_cpu() if dev == "gpu" else output)]
        # an output outside of the expected set raises KeyError below
        actual_count = {allowed_out: 0 for allowed_out in expected_counts}
        for sample in output:
            actual_count[tuple(sample)] += 1
        actual = []
        expected = []
        for out in expected_counts:
            actual.append(actual_count[out])
            expected.append(expected_counts[out])
        stat = chisquare(actual, expected)
        stats.append(stat)
    mean_p_val = sum(run_stat.pvalue for run_stat in stats) / len(stats)
    # `stat`, `actual` and `expected` in the message come from the last run only
    assert 0.05 <= mean_p_val <= 0.95, f"{mean_p_val} {stat} {actual} {expected}"