Skip to content

Commit

Permalink
Changing Imagenet dataset in examples to use a seed for file shufflin…
Browse files Browse the repository at this point in the history
…g to achieve determinism.

PiperOrigin-RevId: 448791468
  • Loading branch information
james-martens authored and KfacJaxDev committed May 17, 2022
1 parent 274e8cc commit 1b7f9f2
Showing 1 changed file with 20 additions and 14 deletions.
34 changes: 20 additions & 14 deletions examples/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,9 +145,9 @@ def imagenet_num_examples_and_split(

def imagenet_dataset(
split: str,
seed: Optional[int],
is_training: bool,
batch_dims: chex.Shape,
seed: int = 123,
shuffle_files: bool = True,
buffer_size_factor: int = 10,
shuffle: bool = False,
Expand All @@ -160,9 +160,10 @@ def imagenet_dataset(
Args:
split: Which data split to load.
seed: Any seed to use for random pre-processing.
is_training: Whether this is on the training or evaluator worker.
batch_dims: The shape of the batch dimensions.
seed: Any seed to use for random pre-processing, shuffling, and file
shuffling.
shuffle_files: Whether to shuffle the ImageNet files.
buffer_size_factor: Batch size factor for computing cache size.
shuffle: Whether to shuffle the cache.
Expand All @@ -175,9 +176,11 @@ def imagenet_dataset(
The ImageNet dataset as a tensorflow dataset.
"""

if is_training and seed is None:
raise ValueError("You need to provide seed when doing training "
"pre-processing.")
preprocess_seed = seed
shuffle_seed = seed + 1
file_shuffle_seed = seed + 2
del seed

num_examples, tfds_split = imagenet_num_examples_and_split(split)

shard_range = np.array_split(np.arange(num_examples),
Expand All @@ -194,34 +197,37 @@ def imagenet_dataset(
tfds_split = tfds.core.ReadInstruction(
tfds_split, from_=start, to=end, unit="abs")

read_config = tfds.ReadConfig(shuffle_seed=file_shuffle_seed)

read_config.options.threading.private_threadpool_size = 48
read_config.options.threading.max_intra_op_parallelism = 1
read_config.options.deterministic = True

ds = tfds.load(
name="imagenet2012:5.*.*",
shuffle_files=shuffle_files,
split=tfds_split,
decoders={"image": tfds.decode.SkipDecoding()},
data_dir=data_dir,
read_config=read_config,
)
options = tf.data.Options()
options.threading.private_threadpool_size = 48
options.threading.max_intra_op_parallelism = 1
options.experimental_optimization.map_parallelization = True
options.experimental_deterministic = True
ds = ds.with_options(options)

if is_training:
if cache:
ds = ds.cache()
ds = ds.repeat()
if shuffle:
ds = ds.shuffle(buffer_size=buffer_size_factor * total_batch_size, seed=0)
ds = ds.shuffle(buffer_size=buffer_size_factor * total_batch_size,
seed=shuffle_seed)

elif num_examples % total_batch_size != 0:
# If the dataset is not divisible by the batch size then just randomize
if shuffle:
ds = ds.shuffle(buffer_size=buffer_size_factor * total_batch_size, seed=0)
ds = ds.shuffle(buffer_size=buffer_size_factor * total_batch_size,
seed=shuffle_seed)

if is_training:
rng = jax.random.PRNGKey(seed)
rng = jax.random.PRNGKey(preprocess_seed)
tf_seed = tf.convert_to_tensor(rng, dtype=tf.int32)

# When training we generate a stateless pipeline, at test we don't need it
Expand Down

0 comments on commit 1b7f9f2

Please sign in to comment.