Avoid a NumPy bug, triggered by the JAX change
jax-ml/jax#3821.

The idea of the JAX change is in part that DeviceArray.__iter__ should return
DeviceArrays. Before #3821, it returned numpy.ndarrays. One main motivation is
performance: it avoids a host sync. A secondary motivation is type consistency.

However, that caused this line of Flax example code to trigger a NumPy bug,
discussed in this thread:
jax-ml/jax#620 (comment)
Basically, x[i] where x is a numpy.ndarray and i is a JAX DeviceArray _of
length 10 or less_ causes NumPy to interpret i as a non-array sequence (e.g. a
tuple) rather than as an array, leading to an error like "IndexError: too many
indices for array". The workaround employed here is to write x[i, ...] instead
of x[i], which bypasses the NumPy bug.
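
For illustration, a minimal sketch of the failure and the workaround (the
array shapes and values here are made up, and whether x[i] actually fails
depends on the NumPy version in use):

  import numpy as np
  import jax.numpy as jnp

  x = np.arange(12).reshape(3, 4)  # a plain numpy.ndarray
  i = jnp.array([0, 2])            # a JAX array of length 10 or less

  # On affected NumPy versions, x[i] raises "IndexError: too many indices
  # for array" because NumPy interprets i as a tuple of per-axis indices
  # rather than as a single index array.
  batch = x[i, ...]   # the trailing ellipsis forces array-style indexing
  print(batch.shape)  # (2, 4)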

PiperOrigin-RevId: 345140147
mattjj authored and Flax Authors committed Dec 2, 2020
1 parent 0bed2f6 commit 5bdce20
Showing 1 changed file with 1 addition and 1 deletion.

linen_examples/mnist/train.py
@@ -112,7 +112,7 @@ def train_epoch(optimizer, train_ds, batch_size, epoch, rng):
   perms = perms.reshape((steps_per_epoch, batch_size))
   batch_metrics = []
   for perm in perms:
-    batch = {k: v[perm] for k, v in train_ds.items()}
+    batch = {k: v[perm, ...] for k, v in train_ds.items()}
     optimizer, metrics = train_step(optimizer, batch)
     batch_metrics.append(metrics)

