Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Avoid a NumPy bug, triggered by the JAX change
jax-ml/jax#3821. The idea of the JAX change is in part that DeviceArray.__iter__ should return DeviceArrays. Before #3821, it returned numpy.ndarrays. One main motivation is performance: it avoids a host sync. A secondary motivation is type consistency. However, that caused this line of Flax example code to trigger a NumPy bug, discussed in this thread: jax-ml/jax#620 (comment) Basically, x[i] where x is a numpy.ndarray and i is a JAX DeviceArray _of length 10 or less_ causes NumPy to interperet i as a non-array sequence (e.g. a tuple) rather than as an array, leading to an error like "IndexError: too many indices for array". The workaround employed here is to write x[i, ...] instead of x[i], which bypasses the NumPy bug. PiperOrigin-RevId: 345140147
- Loading branch information