diff --git a/linen_examples/mnist/train.py b/linen_examples/mnist/train.py index 9def11cb1f..be7e4b1ecb 100644 --- a/linen_examples/mnist/train.py +++ b/linen_examples/mnist/train.py @@ -112,7 +112,7 @@ def train_epoch(optimizer, train_ds, batch_size, epoch, rng): perms = perms.reshape((steps_per_epoch, batch_size)) batch_metrics = [] for perm in perms: - batch = {k: v[perm] for k, v in train_ds.items()} + batch = {k: v[perm, ...] for k, v in train_ds.items()} optimizer, metrics = train_step(optimizer, batch) batch_metrics.append(metrics)