Skip to content

Commit

Permalink
Use random.uniform discrete mode
Browse files Browse the repository at this point in the history
Signed-off-by: Kamil Tokarski <[email protected]>
  • Loading branch information
stiepan committed Feb 13, 2023
1 parent 4dc3119 commit f3fa6f0
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 22 deletions.
7 changes: 3 additions & 4 deletions dali/python/nvidia/dali/auto_aug/auto_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from nvidia.dali import fn
from nvidia.dali import types
from nvidia.dali.auto_aug import augmentations as a
from nvidia.dali.auto_aug.core.utils import operation_idx_random_choice, select
from nvidia.dali.auto_aug.core.utils import select


class Policy:
Expand Down Expand Up @@ -160,13 +160,12 @@ def apply_auto_augment(policy: Policy, samples, seed=None, augment_kwargs=None):
if not use_signed_magnitudes:
random_sign = None
else:
random_sign = fn.random.uniform(range=[0, 1], dtype=types.INT32, seed=seed,
shape=(max_policy_len, ))
random_sign = fn.random.uniform(values=[0, 1], seed=seed, shape=(max_policy_len, ))
should_run = fn.random.uniform(range=[0, 1], shape=(max_policy_len, ), dtype=types.FLOAT)
op_kwargs = dict(samples=samples, should_run=should_run, random_sign=random_sign,
num_magnitude_bins=policy.num_magnitude_bins, **augment_kwargs)
sub_policies = [apply_sub_policy(sub_policy) for sub_policy in sub_policies]
policy_id = operation_idx_random_choice(len(sub_policies), 1, seed)
policy_id = fn.random.uniform(values=list(range(len(sub_policies))), seed=seed)
return select(sub_policies, policy_id, op_kwargs)


Expand Down
12 changes: 0 additions & 12 deletions dali/python/nvidia/dali/auto_aug/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from nvidia.dali import fn

try:
import numpy as np
except ImportError:
Expand All @@ -33,16 +31,6 @@ def remap_bin_idx(bin_idx):
return np.array([remap_bin_idx(bin_idx) for bin_idx in range(2 * len(magnitudes))])


def operation_idx_random_choice(num_total_ops, num_levels=1, rng_seed=42):
shape = tuple() if num_levels == 1 else (num_levels, )
rng = np.random.default_rng(rng_seed)

def random_choice(_):
return rng.choice(range(num_total_ops), shape)

return fn.external_source(source=random_choice, batch=False)


def split_samples_among_ops(op_range_lo, op_range_hi, ops, selected_op_idx, op_kwargs):
assert op_range_lo <= op_range_hi
if op_range_lo == op_range_hi:
Expand Down
8 changes: 4 additions & 4 deletions dali/python/nvidia/dali/auto_aug/rand_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from nvidia.dali import fn
from nvidia.dali import types
from nvidia.dali.auto_aug import augmentations as a
from nvidia.dali.auto_aug.core.utils import operation_idx_random_choice, select
from nvidia.dali.auto_aug.core.utils import select

rand_augment_suite = {
"shear_x": a.shear_x.augmentation((0, 0.3), True),
Expand Down Expand Up @@ -152,12 +152,12 @@ def apply_rand_augment(augmentations, samples, n, m, num_magnitude_bins, seed, a
return samples
augment_kwargs = augment_kwargs or {}
use_signed_magnitudes = any(aug.randomly_negate for aug in augmentations)
shape = tuple() if n == 1 else (n, )
if not use_signed_magnitudes:
random_sign = None
else:
random_sign = fn.random.uniform(range=[0, 1], dtype=types.INT32, seed=seed,
shape=tuple() if n == 1 else (n, ))
op_idx = operation_idx_random_choice(len(augmentations), n, seed)
random_sign = fn.random.uniform(values=[0, 1], seed=seed, shape=shape)
op_idx = fn.random.uniform(values=list(range(len(augmentations))), seed=seed, shape=shape)
for level_idx in range(n):
if not use_signed_magnitudes or n == 1:
level_random_sign = random_sign
Expand Down
4 changes: 2 additions & 2 deletions dali/python/nvidia/dali/auto_aug/trivial_augment.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from nvidia.dali import fn
from nvidia.dali import types
from nvidia.dali.auto_aug import augmentations as a
from nvidia.dali.auto_aug.core.utils import operation_idx_random_choice, select
from nvidia.dali.auto_aug.core.utils import select

trivial_augment_wide_suite = {
"shear_x": a.shear_x.augmentation((0, 0.99), True),
Expand Down Expand Up @@ -129,5 +129,5 @@ def apply_trivial_augment(augmentations, samples, num_magnitude_bins, seed, augm
op_kwargs = dict(samples=samples, magnitude_bin_idx=magnitude_bin_idx,
num_magnitude_bins=num_magnitude_bins, random_sign=random_sign,
**augment_kwargs)
op_idx = operation_idx_random_choice(len(augmentations), 1, seed)
op_idx = fn.random.uniform(values=list(range(len(augmentations))), seed=seed)
return select(augmentations, op_idx, op_kwargs)

0 comments on commit f3fa6f0

Please sign in to comment.