From 51d7938538773e40ce06b70362668d7419836304 Mon Sep 17 00:00:00 2001 From: James Martens Date: Sun, 15 May 2022 06:03:54 -0700 Subject: [PATCH] Changing Imagenet dataset in examples to use a seed for file shuffling to achieve determinism. PiperOrigin-RevId: 448791468 --- examples/datasets.py | 34 ++++++++++++++++++++-------------- 1 file changed, 20 insertions(+), 14 deletions(-) diff --git a/examples/datasets.py b/examples/datasets.py index 27e47ae..0aa46d3 100644 --- a/examples/datasets.py +++ b/examples/datasets.py @@ -145,9 +145,9 @@ def imagenet_num_examples_and_split( def imagenet_dataset( split: str, - seed: Optional[int], is_training: bool, batch_dims: chex.Shape, + seed: int = 123, shuffle_files: bool = True, buffer_size_factor: int = 10, shuffle: bool = False, @@ -160,9 +160,10 @@ def imagenet_dataset( Args: split: Which data split to load. - seed: Any seed to use for random pre-processing. is_training: Whether this is on the training or evaluator worker. batch_dims: The shape of the batch dimensions. + seed: Any seed to use for random pre-processing, shuffling, and file + shuffling. shuffle_files: Whether to shuffle the ImageNet files. buffer_size_factor: Batch size factor for computing cache size. shuffle: Whether to shuffle the cache. @@ -175,9 +176,11 @@ def imagenet_dataset( The ImageNet dataset as a tensorflow dataset. """ - if is_training and seed is None: - raise ValueError("You need to provide seed when doing training " - "pre-processing.") + preprocess_seed = seed + shuffle_seed = seed + 1 + file_shuffle_seed = seed + 2 + del seed + num_examples, tfds_split = imagenet_num_examples_and_split(split) shard_range = np.array_split(np.arange(num_examples), @@ -194,34 +197,37 @@ def imagenet_dataset( tfds_split = tfds.core.ReadInstruction( tfds_split, from_=start, to=end, unit="abs") + read_config = tfds.ReadConfig(shuffle_seed=file_shuffle_seed) + + read_config.options.threading.private_threadpool_size = 48 + read_config.options.threading.max_intra_op_parallelism = 1 + read_config.options.deterministic = True + ds = tfds.load( name="imagenet2012:5.*.*", shuffle_files=shuffle_files, split=tfds_split, decoders={"image": tfds.decode.SkipDecoding()}, data_dir=data_dir, + read_config=read_config, ) - options = tf.data.Options() - options.threading.private_threadpool_size = 48 - options.threading.max_intra_op_parallelism = 1 - options.experimental_optimization.map_parallelization = True - options.experimental_deterministic = True - ds = ds.with_options(options) if is_training: if cache: ds = ds.cache() ds = ds.repeat() if shuffle: - ds = ds.shuffle(buffer_size=buffer_size_factor * total_batch_size, seed=0) + ds = ds.shuffle(buffer_size=buffer_size_factor * total_batch_size, + seed=shuffle_seed) elif num_examples % total_batch_size != 0: # If the dataset is not divisible by the batch size then just randomize if shuffle: - ds = ds.shuffle(buffer_size=buffer_size_factor * total_batch_size, seed=0) + ds = ds.shuffle(buffer_size=buffer_size_factor * total_batch_size, + seed=shuffle_seed) if is_training: - rng = jax.random.PRNGKey(seed) + rng = jax.random.PRNGKey(preprocess_seed) tf_seed = tf.convert_to_tensor(rng, dtype=tf.int32) # When training we generate a stateless pipeline, at test we don't need it