diff --git a/dali/python/nvidia/dali/auto_aug/auto_augment.py b/dali/python/nvidia/dali/auto_aug/auto_augment.py index 6ab3b26e793..5e2a694e0b0 100644 --- a/dali/python/nvidia/dali/auto_aug/auto_augment.py +++ b/dali/python/nvidia/dali/auto_aug/auto_augment.py @@ -17,7 +17,7 @@ from nvidia.dali import fn from nvidia.dali import types from nvidia.dali.auto_aug import augmentations as a -from nvidia.dali.auto_aug.core.utils import operation_idx_random_choice, select +from nvidia.dali.auto_aug.core.utils import select class Policy: @@ -160,13 +160,12 @@ def apply_auto_augment(policy: Policy, samples, seed=None, augment_kwargs=None): if not use_signed_magnitudes: random_sign = None else: - random_sign = fn.random.uniform(range=[0, 1], dtype=types.INT32, seed=seed, - shape=(max_policy_len, )) + random_sign = fn.random.uniform(values=[0, 1], seed=seed, shape=(max_policy_len, )) should_run = fn.random.uniform(range=[0, 1], shape=(max_policy_len, ), dtype=types.FLOAT) op_kwargs = dict(samples=samples, should_run=should_run, random_sign=random_sign, num_magnitude_bins=policy.num_magnitude_bins, **augment_kwargs) sub_policies = [apply_sub_policy(sub_policy) for sub_policy in sub_policies] - policy_id = operation_idx_random_choice(len(sub_policies), 1, seed) + policy_id = fn.random.uniform(values=list(range(len(sub_policies))), seed=seed) return select(sub_policies, policy_id, op_kwargs) diff --git a/dali/python/nvidia/dali/auto_aug/core/utils.py b/dali/python/nvidia/dali/auto_aug/core/utils.py index 0981ee57aa6..188985e52a7 100644 --- a/dali/python/nvidia/dali/auto_aug/core/utils.py +++ b/dali/python/nvidia/dali/auto_aug/core/utils.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -from nvidia.dali import fn - try: import numpy as np except ImportError: @@ -33,16 +31,6 @@ def remap_bin_idx(bin_idx): return np.array([remap_bin_idx(bin_idx) for bin_idx in range(2 * len(magnitudes))]) -def operation_idx_random_choice(num_total_ops, num_levels=1, rng_seed=42): - shape = tuple() if num_levels == 1 else (num_levels, ) - rng = np.random.default_rng(rng_seed) - - def random_choice(_): - return rng.choice(range(num_total_ops), shape) - - return fn.external_source(source=random_choice, batch=False) - - def split_samples_among_ops(op_range_lo, op_range_hi, ops, selected_op_idx, op_kwargs): assert op_range_lo <= op_range_hi if op_range_lo == op_range_hi: diff --git a/dali/python/nvidia/dali/auto_aug/rand_augment.py b/dali/python/nvidia/dali/auto_aug/rand_augment.py index 61784ad6821..e6b65bd9f1f 100644 --- a/dali/python/nvidia/dali/auto_aug/rand_augment.py +++ b/dali/python/nvidia/dali/auto_aug/rand_augment.py @@ -15,7 +15,7 @@ from nvidia.dali import fn from nvidia.dali import types from nvidia.dali.auto_aug import augmentations as a -from nvidia.dali.auto_aug.core.utils import operation_idx_random_choice, select +from nvidia.dali.auto_aug.core.utils import select rand_augment_suite = { "shear_x": a.shear_x.augmentation((0, 0.3), True), @@ -152,12 +152,12 @@ def apply_rand_augment(augmentations, samples, n, m, num_magnitude_bins, seed, a return samples augment_kwargs = augment_kwargs or {} use_signed_magnitudes = any(aug.randomly_negate for aug in augmentations) + shape = tuple() if n == 1 else (n, ) if not use_signed_magnitudes: random_sign = None else: - random_sign = fn.random.uniform(range=[0, 1], dtype=types.INT32, seed=seed, - shape=tuple() if n == 1 else (n, )) - op_idx = operation_idx_random_choice(len(augmentations), n, seed) + random_sign = fn.random.uniform(values=[0, 1], seed=seed, shape=shape) + op_idx = fn.random.uniform(values=list(range(len(augmentations))), seed=seed, shape=shape) for level_idx in range(n): if not use_signed_magnitudes or n == 1: level_random_sign = random_sign diff --git a/dali/python/nvidia/dali/auto_aug/trivial_augment.py b/dali/python/nvidia/dali/auto_aug/trivial_augment.py index 95ca51b0fc7..7ad9567211f 100644 --- a/dali/python/nvidia/dali/auto_aug/trivial_augment.py +++ b/dali/python/nvidia/dali/auto_aug/trivial_augment.py @@ -15,7 +15,7 @@ from nvidia.dali import fn from nvidia.dali import types from nvidia.dali.auto_aug import augmentations as a -from nvidia.dali.auto_aug.core.utils import operation_idx_random_choice, select +from nvidia.dali.auto_aug.core.utils import select trivial_augment_wide_suite = { "shear_x": a.shear_x.augmentation((0, 0.99), True), @@ -129,5 +129,5 @@ def apply_trivial_augment(augmentations, samples, num_magnitude_bins, seed, augm op_kwargs = dict(samples=samples, magnitude_bin_idx=magnitude_bin_idx, num_magnitude_bins=num_magnitude_bins, random_sign=random_sign, **augment_kwargs) - op_idx = operation_idx_random_choice(len(augmentations), 1, seed) + op_idx = fn.random.uniform(values=list(range(len(augmentations))), seed=seed) return select(augmentations, op_idx, op_kwargs)