From a1550e33f00a2b5f2ba45a7f66009bf6a75d52ff Mon Sep 17 00:00:00 2001
From: James Martens
Date: Wed, 15 Feb 2023 11:07:18 -0800
Subject: [PATCH] Modifying examples to only use label smoothing and L2 reg
 loss when training

- Adding extra_preprocessing_func argument to some datasets to support things
  like Per-location Normalization with DKS/TAT
- Updating classifier_mnist example to use classifier_loss_and_stats
- Adding whitespace to improve readability

PiperOrigin-RevId: 509883028
---
 examples/autoencoder_mnist/experiment.py    | 12 +++--
 examples/classifier_mnist/experiment.py     | 24 +++++-----
 examples/datasets.py                        | 51 ++++++++++++++++++++-
 examples/lrelunet101_imagenet/experiment.py |  4 +-
 examples/resnet50_imagenet/experiment.py    |  4 +-
 5 files changed, 73 insertions(+), 22 deletions(-)

diff --git a/examples/autoencoder_mnist/experiment.py b/examples/autoencoder_mnist/experiment.py
index 86d6408..cb1ec91 100644
--- a/examples/autoencoder_mnist/experiment.py
+++ b/examples/autoencoder_mnist/experiment.py
@@ -60,7 +60,7 @@ def autoencoder_loss(
     average_loss: bool = True,
 ) -> Tuple[chex.Array, Dict[str, chex.Array]]:
   """Evaluates the loss of the autoencoder."""
-  del is_training  # not used
+
   if isinstance(batch, Mapping):
     batch = batch["images"]
 
@@ -69,16 +69,18 @@ def autoencoder_loss(
   cross_entropy = jnp.sum(losses.sigmoid_cross_entropy(logits, batch),
                           axis=-1)
   averaged_cross_entropy = jnp.mean(cross_entropy)
-  params_l2 = losses.l2_regularizer(params, False, False)
 
   loss = averaged_cross_entropy if average_loss else cross_entropy
-  regularized_loss = loss + l2_reg * params_l2
+
+  l2_reg_val = losses.l2_regularizer(params, False, False)
+  if is_training:
+    loss = loss + l2_reg * l2_reg_val
 
   error = nn.sigmoid(logits) - batch.reshape([batch.shape[0], -1])
   mean_squared_error = jnp.mean(jnp.sum(error * error, axis=1), axis=0)
 
-  return regularized_loss, dict(
+  return loss, dict(
       cross_entropy=averaged_cross_entropy,
-      regualrizer=params_l2,
+      l2_reg_val=l2_reg_val,
       mean_squared_error=mean_squared_error,
   )

diff --git a/examples/classifier_mnist/experiment.py b/examples/classifier_mnist/experiment.py
index 22b0722..f24b89a 100644
--- a/examples/classifier_mnist/experiment.py
+++ b/examples/classifier_mnist/experiment.py
@@ -58,21 +58,21 @@ def classifier_loss(
     average_loss: bool = True,
 ) -> Tuple[chex.Array, Dict[str, chex.Array]]:
   """Evaluates the loss of the classifier network."""
-  del is_training  # not used
 
   logits = convolutional_classifier().apply(params, batch["images"])
 
-  cross_entropy = losses.softmax_cross_entropy(logits, batch["labels"])
-
-  if average_loss:
-    cross_entropy = jnp.mean(cross_entropy)
-
-  params_l2 = losses.l2_regularizer(params, False, False)
-  regularized_loss = cross_entropy + l2_reg * params_l2
-
-  accuracy = losses.top_k_accuracy(logits, batch["labels"], 1)
-
-  return regularized_loss, dict(accuracy=accuracy)
+  loss, stats = losses.classifier_loss_and_stats(
+      logits=logits,
+      labels_as_int=batch["labels"],
+      params=params,
+      l2_reg=l2_reg if is_training else 0.0,
+      haiku_exclude_batch_norm=False,
+      haiku_exclude_biases=False,
+      average_loss=average_loss,
+      top_k_stats=(1,),
+  )
+
+  return loss, stats
 
 
 class ClassifierMnistExperiment(training.MnistExperiment):

diff --git a/examples/datasets.py b/examples/datasets.py
index d705ed3..f2f551d 100644
--- a/examples/datasets.py
+++ b/examples/datasets.py
@@ -15,7 +15,7 @@
 """
 
 import types
-from typing import Dict, Iterator, Mapping, Optional, Tuple, TypeVar
+from typing import Callable, Dict, Iterator, Mapping, Optional, Tuple, TypeVar
 
 import chex
 import jax
@@ -69,6 +69,7 @@ def mnist_dataset(
   Returns:
     The MNIST dataset as a tensorflow dataset.
   """
+  # Set device counts for the multi-device vs. single-device cases
   num_devices = jax.device_count() if multi_device else 1
   num_local_devices = jax.local_device_count() if multi_device else 1
 
@@ -77,6 +78,7 @@ def mnist_dataset(
     host_batch_shape = [num_local_devices, device_batch_size]
   else:
     host_batch_shape = [device_batch_size]
+
   host_batch_size = num_local_devices * device_batch_size
 
   num_examples = tfds.builder("mnist").info.splits[split].num_examples
 
@@ -100,9 +102,13 @@ def preprocess_batch(
     return dict(images=images)
 
   ds = tfds.load(name="mnist", split=split, as_supervised=True)
+
   ds = ds.shard(jax.process_count(), jax.process_index())
+
   ds = ds.cache()
+
   if host_batch_size < num_examples and shuffle:
+
     ds = ds.shuffle(buffer_size=(num_examples // jax.process_count()),
                     seed=seed,
                     reshuffle_each_iteration=reshuffle_each_iteration)
 
@@ -110,6 +116,7 @@ def preprocess_batch(
   ds = ds.repeat()
 
   ds = ds.batch(host_batch_size, drop_remainder=drop_remainder)
+
   ds = ds.map(preprocess_batch,
               num_parallel_calls=tf.data.experimental.AUTOTUNE)
 
@@ -122,6 +129,7 @@ def imagenet_num_examples_and_split(
     split: str
 ) -> Tuple[int, tensorflow_datasets.Split]:
   """Returns the number of examples in the given split of Imagenet."""
+
   if split == "train":
     return 1271167, tensorflow_datasets.Split.TRAIN
   elif split == "valid":
 
@@ -148,6 +156,9 @@ def imagenet_dataset(
     dtype: jnp.dtype = jnp.float32,
     image_size: chex.Shape = (224, 224),
     data_dir: Optional[str] = None,
+    extra_preprocessing_func: Optional[
+        Callable[[jnp.DeviceArray, jnp.DeviceArray],
+                 Tuple[jnp.DeviceArray, jnp.DeviceArray]]] = None,
 ) -> Iterator[Batch]:
   """Standard ImageNet dataset pipeline.
 
@@ -164,6 +175,10 @@ def imagenet_dataset(
     dtype: The returned data type of the images.
     image_size: The image sizes.
     data_dir: If specified, will use this directory to load the dataset from.
+    extra_preprocessing_func: A callable to perform additional data
+      preprocessing if desired. Should take arguments `image` and `label`,
+      consisting of the image and its label (without batch dimension), and
+      return a tuple with the processed versions of these two.
 
   Returns:
     The ImageNet dataset as a tensorflow dataset.
@@ -206,9 +221,12 @@ def imagenet_dataset(
   )
 
   if is_training:
+
     if cache:
       ds = ds.cache()
+
     ds = ds.repeat()
+
     if shuffle:
       ds = ds.shuffle(buffer_size=buffer_size_factor * total_batch_size,
                       seed=shuffle_seed)
 
@@ -238,6 +256,7 @@ def preprocess(
       example: Mapping[str, tf.Tensor],
       seed_: Optional[tf.Tensor] = None
   ) -> Dict[str, tf.Tensor]:
+
     image = _imagenet_preprocess_image(
         image_bytes=example["image"],
         seed=seed_,
@@ -245,19 +264,29 @@ def preprocess(
         image_size=image_size
     )
     label = tf.cast(example["label"], tf.int32)
+
+    if extra_preprocessing_func is not None:
+      image, label = extra_preprocessing_func(image, label)
+
     return {"images": image, "labels": label}
 
   ds = ds.map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
 
   def cast_fn(batch_):
+
     tf_dtype = (tf.bfloat16 if dtype == jnp.bfloat16
                 else tf.dtypes.as_dtype(dtype))
+
     batch_ = dict(**batch_)
+
     batch_["images"] = tf.cast(batch_["images"], tf_dtype)
+
     return batch_
 
   for i, batch_size in enumerate(reversed(batch_dims)):
+
     ds = ds.batch(batch_size, drop_remainder=not is_training)
+
     if i == 0:
       # NOTE: You may be tempted to move the casting earlier on in the pipeline,
       # but for bf16 some operations will end up silently placed on the TPU and
 
@@ -276,23 +305,31 @@ def _imagenet_preprocess_image(
     image_size: chex.Shape,
 ) -> tf.Tensor:
   """Returns processed and resized images for Imagenet."""
+
   if is_training:
     seeds = tf.random.experimental.stateless_split(seed, num=2)
+
     # Random cropping of the image
     image = _decode_and_random_crop(
         image_bytes, seed=seeds[0], image_size=image_size)
+
     # Random left-right flipping
     image = tf.image.stateless_random_flip_left_right(image, seed=seeds[1])
+
   else:
     image = _decode_and_center_crop(image_bytes, image_size=image_size)
 
+
   assert image.dtype == tf.uint8
+
   # NOTE: Bicubic resize (1) casts uint8 to float32 and (2) resizes without
   # clamping overshoots. This means values returned will be outside the range
   # [0.0, 255.0] (e.g. we have observed outputs in the range [-51.1, 336.6]).
   image = tf.image.resize(image, image_size, tf.image.ResizeMethod.BICUBIC)
+
   # Normalize image
   mean = tf.constant(_IMAGENET_MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype)
   std = tf.constant(_IMAGENET_STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype)
+
   return (image - mean * 255) / (std * 255)
 
@@ -307,6 +344,7 @@ def _distorted_bounding_box_crop(
     max_attempts: int,
 ) -> tf.Tensor:
   """Generates cropped_image using one of the bboxes randomly distorted for Imagenet."""
+
   bbox_begin, bbox_size, _ = tf.image.stateless_sample_distorted_bounding_box(
       image_size=jpeg_shape,
       bounding_boxes=bbox,
@@ -322,6 +360,7 @@ def _distorted_bounding_box_crop(
   offset_y, offset_x, _ = tf.unstack(bbox_begin)
   target_height, target_width, _ = tf.unstack(bbox_size)
   crop_window = tf.stack([offset_y, offset_x, target_height, target_width])
+
   return tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
 
@@ -354,8 +393,10 @@ def _decode_and_center_crop(
     image_size: chex.Shape = (224, 224),
 ) -> tf.Tensor:
   """Crops to center of image with padding then scales for Imagenet."""
+
   if jpeg_shape is None:
     jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
+
   image_height = jpeg_shape[0]
   image_width = jpeg_shape[1]
 
@@ -363,12 +404,16 @@ def _decode_and_center_crop(
   # crop that maintains aspect ratio.
   scale = tf.minimum(tf.cast(image_height, tf.float32) / (image_size[0] + 32),
                      tf.cast(image_width, tf.float32) / (image_size[1] + 32))
+
   padded_center_crop_height = tf.cast(scale * image_size[0], tf.int32)
   padded_center_crop_width = tf.cast(scale * image_size[1], tf.int32)
+
   offset_height = ((image_height - padded_center_crop_height) + 1) // 2
   offset_width = ((image_width - padded_center_crop_width) + 1) // 2
+
   crop_window = tf.stack([offset_height, offset_width,
                           padded_center_crop_height, padded_center_crop_width])
+
   return tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)
 
@@ -378,7 +423,9 @@ def _imagenet_distort_color(
     color_ordering: int = 0,
 ) -> tf.Tensor:
   """Randomly distorts colors for Imagenet."""
+
   seeds = tf.random.experimental.stateless_split(seed, num=4)
+
   if color_ordering == 0:
     image = tf.image.stateless_random_brightness(
         image, max_delta=32. / 255., seed=seeds[0])
@@ -388,6 +435,7 @@ def _imagenet_distort_color(
         image, max_delta=0.2, seed=seeds[2])
     image = tf.image.stateless_random_contrast(
         image, lower=0.5, upper=1.5, seed=seeds[3])
+
   elif color_ordering == 1:
     image = tf.image.stateless_random_brightness(
         image, max_delta=32. / 255., seed=seeds[0])
@@ -397,6 +445,7 @@ def _imagenet_distort_color(
         image, lower=0.5, upper=1.5, seed=seeds[2])
     image = tf.image.stateless_random_hue(
         image, max_delta=0.2, seed=seeds[3])
+
   else:
     raise ValueError("color_ordering must be in {0, 1}")

diff --git a/examples/lrelunet101_imagenet/experiment.py b/examples/lrelunet101_imagenet/experiment.py
index e8580cf..6a69420 100644
--- a/examples/lrelunet101_imagenet/experiment.py
+++ b/examples/lrelunet101_imagenet/experiment.py
@@ -357,10 +357,10 @@ def lrelunet_loss(
       logits=logits,
       labels_as_int=batch["labels"],
       params=params,
-      l2_reg=l2_reg,
+      l2_reg=l2_reg if is_training else 0.0,
       haiku_exclude_batch_norm=True,
       haiku_exclude_biases=True,
-      label_smoothing=label_smoothing,
+      label_smoothing=label_smoothing if is_training else 0.0,
       average_loss=average_loss,
   )

diff --git a/examples/resnet50_imagenet/experiment.py b/examples/resnet50_imagenet/experiment.py
index 4f30b9f..8242fd1 100644
--- a/examples/resnet50_imagenet/experiment.py
+++ b/examples/resnet50_imagenet/experiment.py
@@ -81,10 +81,10 @@ def resnet50_loss(
       logits=logits,
       labels_as_int=batch["labels"],
       params=params,
-      l2_reg=l2_reg,
+      l2_reg=l2_reg if is_training else 0.0,
       haiku_exclude_batch_norm=True,
       haiku_exclude_biases=True,
-      label_smoothing=label_smoothing,
+      label_smoothing=label_smoothing if is_training else 0.0,
       average_loss=average_loss,
   )
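
A sketch of how the new extra_preprocessing_func hook might be used, in the
spirit of the Per-location Normalization mentioned in the commit message. The
per_location_normalize helper and the exact imagenet_dataset call pattern are
illustrative assumptions, not part of this patch. Note also that, despite the
jnp.DeviceArray annotations in the new signature, the hook is invoked inside
the tf.data preprocess function, so it receives unbatched tf.Tensor values.

from typing import Tuple

import tensorflow as tf

from examples import datasets


def per_location_normalize(
    image: tf.Tensor,
    label: tf.Tensor,
) -> Tuple[tf.Tensor, tf.Tensor]:
  """Hypothetical per-location normalization for use with DKS/TAT."""
  # Center each spatial location's channel vector...
  mean = tf.reduce_mean(image, axis=-1, keepdims=True)
  centered = image - mean
  # ...and rescale it to unit L2 norm, guarding against division by zero.
  norm = tf.norm(centered, axis=-1, keepdims=True)
  return centered / tf.maximum(norm, 1e-6), label


# Assumed call pattern: split, is_training and batch_dims appear in the hunk
# context above; remaining arguments keep their defaults.
train_batches = datasets.imagenet_dataset(
    split="train",
    is_training=True,
    batch_dims=(128,),
    extra_preprocessing_func=per_location_normalize,
)

Because the hook runs per example, after the crop/flip/resize preprocessing
and before batching, it composes with the existing pipeline without any
change to the batching or casting logic.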