Skip to content

Commit

Permalink
- Modifying examples to only use label smoothing and L2 reg loss when…
Browse files Browse the repository at this point in the history
… training

- Adding extra_preprocessing_func argument to some datasets to support things like Per-location Normalization with DKS/TAT
- Updating classifier_mnist example to use classifier_loss_and_stats
- Adding whitespace to improve readability

PiperOrigin-RevId: 509883028
  • Loading branch information
james-martens authored and KfacJaxDev committed Feb 15, 2023
1 parent d6f14ad commit a1550e3
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 22 deletions.
12 changes: 7 additions & 5 deletions examples/autoencoder_mnist/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def autoencoder_loss(
average_loss: bool = True,
) -> Tuple[chex.Array, Dict[str, chex.Array]]:
"""Evaluates the loss of the autoencoder."""
del is_training # not used

if isinstance(batch, Mapping):
batch = batch["images"]

Expand All @@ -69,16 +69,18 @@ def autoencoder_loss(
cross_entropy = jnp.sum(losses.sigmoid_cross_entropy(logits, batch), axis=-1)
averaged_cross_entropy = jnp.mean(cross_entropy)

params_l2 = losses.l2_regularizer(params, False, False)
loss = averaged_cross_entropy if average_loss else cross_entropy
regularized_loss = loss + l2_reg * params_l2

l2_reg_val = losses.l2_regularizer(params, False, False)
if is_training:
loss = loss + l2_reg * l2_reg_val

error = nn.sigmoid(logits) - batch.reshape([batch.shape[0], -1])
mean_squared_error = jnp.mean(jnp.sum(error * error, axis=1), axis=0)

return regularized_loss, dict(
return loss, dict(
cross_entropy=averaged_cross_entropy,
regualrizer=params_l2,
l2_reg_val=l2_reg_val,
mean_squared_error=mean_squared_error,
)

Expand Down
24 changes: 12 additions & 12 deletions examples/classifier_mnist/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,21 +58,21 @@ def classifier_loss(
average_loss: bool = True,
) -> Tuple[chex.Array, Dict[str, chex.Array]]:
"""Evaluates the loss of the classifier network."""
del is_training # not used

logits = convolutional_classifier().apply(params, batch["images"])

cross_entropy = losses.softmax_cross_entropy(logits, batch["labels"])

if average_loss:
cross_entropy = jnp.mean(cross_entropy)

params_l2 = losses.l2_regularizer(params, False, False)
regularized_loss = cross_entropy + l2_reg * params_l2

accuracy = losses.top_k_accuracy(logits, batch["labels"], 1)

return regularized_loss, dict(accuracy=accuracy)
loss, stats = losses.classifier_loss_and_stats(
logits=logits,
labels_as_int=batch["labels"],
params=params,
l2_reg=l2_reg if is_training else 0.0,
haiku_exclude_batch_norm=False,
haiku_exclude_biases=False,
average_loss=average_loss,
top_k_stats=(1,),
)

return loss, stats


class ClassifierMnistExperiment(training.MnistExperiment):
Expand Down
51 changes: 50 additions & 1 deletion examples/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""
import types
from typing import Dict, Iterator, Mapping, Optional, Tuple, TypeVar
from typing import Callable, Dict, Iterator, Mapping, Optional, Tuple, TypeVar

import chex
import jax
Expand Down Expand Up @@ -69,6 +69,7 @@ def mnist_dataset(
Returns:
The MNIST dataset as a tensorflow dataset.
"""

# Set for multi devices vs single device
num_devices = jax.device_count() if multi_device else 1
num_local_devices = jax.local_device_count() if multi_device else 1
Expand All @@ -77,6 +78,7 @@ def mnist_dataset(
host_batch_shape = [num_local_devices, device_batch_size]
else:
host_batch_shape = [device_batch_size]

host_batch_size = num_local_devices * device_batch_size

num_examples = tfds.builder("mnist").info.splits[split].num_examples
Expand All @@ -100,16 +102,21 @@ def preprocess_batch(
return dict(images=images)

ds = tfds.load(name="mnist", split=split, as_supervised=True)

ds = ds.shard(jax.process_count(), jax.process_index())

ds = ds.cache()

if host_batch_size < num_examples and shuffle:

ds = ds.shuffle(buffer_size=(num_examples // jax.process_count()),
seed=seed,
reshuffle_each_iteration=reshuffle_each_iteration)
if repeat:
ds = ds.repeat()

ds = ds.batch(host_batch_size, drop_remainder=drop_remainder)

ds = ds.map(preprocess_batch,
num_parallel_calls=tf.data.experimental.AUTOTUNE)

Expand All @@ -122,6 +129,7 @@ def imagenet_num_examples_and_split(
split: str
) -> Tuple[int, tensorflow_datasets.Split]:
"""Returns the number of examples in the given split of Imagenet."""

if split == "train":
return 1271167, tensorflow_datasets.Split.TRAIN
elif split == "valid":
Expand All @@ -148,6 +156,9 @@ def imagenet_dataset(
dtype: jnp.dtype = jnp.float32,
image_size: chex.Shape = (224, 224),
data_dir: Optional[str] = None,
extra_preprocessing_func: Optional[
Callable[[jnp.DeviceArray, jnp.DeviceArray],
Tuple[jnp.DeviceArray, jnp.DeviceArray]]] = None,
) -> Iterator[Batch]:
"""Standard ImageNet dataset pipeline.
Expand All @@ -164,6 +175,10 @@ def imagenet_dataset(
dtype: The returned data type of the images.
image_size: The image sizes.
data_dir: If specified, will use this directory to load the dataset from.
extra_preprocessing_func: A callable to perform addition data preprocessing
if desired. Should take arguments `image` and `label` consisting of the
image and its label (without batch dimension), and return a tuple
consisting of the processed version of these two.
Returns:
The ImageNet dataset as a tensorflow dataset.
Expand Down Expand Up @@ -206,9 +221,12 @@ def imagenet_dataset(
)

if is_training:

if cache:
ds = ds.cache()

ds = ds.repeat()

if shuffle:
ds = ds.shuffle(buffer_size=buffer_size_factor * total_batch_size,
seed=shuffle_seed)
Expand Down Expand Up @@ -238,26 +256,37 @@ def preprocess(
example: Mapping[str, tf.Tensor],
seed_: Optional[tf.Tensor] = None
) -> Dict[str, tf.Tensor]:

image = _imagenet_preprocess_image(
image_bytes=example["image"],
seed=seed_,
is_training=is_training,
image_size=image_size
)
label = tf.cast(example["label"], tf.int32)

if extra_preprocessing_func is not None:
image, label = extra_preprocessing_func(image, label)

return {"images": image, "labels": label}

ds = ds.map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)

def cast_fn(batch_):

tf_dtype = (tf.bfloat16 if dtype == jnp.bfloat16
else tf.dtypes.as_dtype(dtype))

batch_ = dict(**batch_)

batch_["images"] = tf.cast(batch_["images"], tf_dtype)

return batch_

for i, batch_size in enumerate(reversed(batch_dims)):

ds = ds.batch(batch_size, drop_remainder=not is_training)

if i == 0:
# NOTE: You may be tempted to move the casting earlier on in the pipeline,
# but for bf16 some operations will end up silently placed on the TPU and
Expand All @@ -276,23 +305,31 @@ def _imagenet_preprocess_image(
image_size: chex.Shape,
) -> tf.Tensor:
"""Returns processed and resized images for Imagenet."""

if is_training:
seeds = tf.random.experimental.stateless_split(seed, num=2)

# Random cropping of the image
image = _decode_and_random_crop(
image_bytes, seed=seeds[0], image_size=image_size)

# Random left-right flipping
image = tf.image.stateless_random_flip_left_right(image, seed=seeds[1])

else:
image = _decode_and_center_crop(image_bytes, image_size=image_size)

assert image.dtype == tf.uint8

# NOTE: Bicubic resize (1) casts uint8 to float32 and (2) resizes without
# clamping overshoots. This means values returned will be outside the range
# [0.0, 255.0] (e.g. we have observed outputs in the range [-51.1, 336.6]).
image = tf.image.resize(image, image_size, tf.image.ResizeMethod.BICUBIC)

# Normalize image
mean = tf.constant(_IMAGENET_MEAN_RGB, shape=[1, 1, 3], dtype=image.dtype)
std = tf.constant(_IMAGENET_STDDEV_RGB, shape=[1, 1, 3], dtype=image.dtype)

return (image - mean * 255) / (std * 255)


Expand All @@ -307,6 +344,7 @@ def _distorted_bounding_box_crop(
max_attempts: int,
) -> tf.Tensor:
"""Generates cropped_image using one of the bboxes randomly distorted for Imagenet."""

bbox_begin, bbox_size, _ = tf.image.stateless_sample_distorted_bounding_box(
image_size=jpeg_shape,
bounding_boxes=bbox,
Expand All @@ -322,6 +360,7 @@ def _distorted_bounding_box_crop(
offset_y, offset_x, _ = tf.unstack(bbox_begin)
target_height, target_width, _ = tf.unstack(bbox_size)
crop_window = tf.stack([offset_y, offset_x, target_height, target_width])

return tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)


Expand Down Expand Up @@ -354,21 +393,27 @@ def _decode_and_center_crop(
image_size: chex.Shape = (224, 224),
) -> tf.Tensor:
"""Crops to center of image with padding then scales for Imagenet."""

if jpeg_shape is None:
jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)

image_height = jpeg_shape[0]
image_width = jpeg_shape[1]

# Pad the image with at least 32px on the short edge and take a
# crop that maintains aspect ratio.
scale = tf.minimum(tf.cast(image_height, tf.float32) / (image_size[0] + 32),
tf.cast(image_width, tf.float32) / (image_size[1] + 32))

padded_center_crop_height = tf.cast(scale * image_size[0], tf.int32)
padded_center_crop_width = tf.cast(scale * image_size[1], tf.int32)

offset_height = ((image_height - padded_center_crop_height) + 1) // 2
offset_width = ((image_width - padded_center_crop_width) + 1) // 2

crop_window = tf.stack([offset_height, offset_width,
padded_center_crop_height, padded_center_crop_width])

return tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3)


Expand All @@ -378,7 +423,9 @@ def _imagenet_distort_color(
color_ordering: int = 0,
) -> tf.Tensor:
"""Randomly distorts colors for Imagenet."""

seeds = tf.random.experimental.stateless_split(seed, num=4)

if color_ordering == 0:
image = tf.image.stateless_random_brightness(
image, max_delta=32. / 255., seed=seeds[0])
Expand All @@ -388,6 +435,7 @@ def _imagenet_distort_color(
image, max_delta=0.2, seed=seeds[2])
image = tf.image.stateless_random_contrast(
image, lower=0.5, upper=1.5, seed=seeds[3])

elif color_ordering == 1:
image = tf.image.stateless_random_brightness(
image, max_delta=32. / 255., seed=seeds[0])
Expand All @@ -397,6 +445,7 @@ def _imagenet_distort_color(
image, lower=0.5, upper=1.5, seed=seeds[2])
image = tf.image.stateless_random_hue(
image, max_delta=0.2, seed=seeds[3])

else:
raise ValueError("color_ordering must be in {0, 1}")

Expand Down
4 changes: 2 additions & 2 deletions examples/lrelunet101_imagenet/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,10 @@ def lrelunet_loss(
logits=logits,
labels_as_int=batch["labels"],
params=params,
l2_reg=l2_reg,
l2_reg=l2_reg if is_training else 0.0,
haiku_exclude_batch_norm=True,
haiku_exclude_biases=True,
label_smoothing=label_smoothing,
label_smoothing=label_smoothing if is_training else 0.0,
average_loss=average_loss,
)

Expand Down
4 changes: 2 additions & 2 deletions examples/resnet50_imagenet/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ def resnet50_loss(
logits=logits,
labels_as_int=batch["labels"],
params=params,
l2_reg=l2_reg,
l2_reg=l2_reg if is_training else 0.0,
haiku_exclude_batch_norm=True,
haiku_exclude_biases=True,
label_smoothing=label_smoothing,
label_smoothing=label_smoothing if is_training else 0.0,
average_loss=average_loss,
)

Expand Down

0 comments on commit a1550e3

Please sign in to comment.