Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid a NumPy bug, triggered by the JAX change #701

Merged
merged 1 commit into from
Dec 2, 2020

Conversation

copybara-service[bot]
Copy link

Avoid a NumPy bug, triggered by the JAX change
jax-ml/jax#3821.

The idea of the JAX change is in part that DeviceArray.iter should return
DeviceArrays. Before #3821, it returned numpy.ndarrays. One main motivation is
performance: it avoids a host sync. A secondary motivation is type consistency.

However, that caused this line of Flax example code to trigger a NumPy bug,
discussed in this thread:
jax-ml/jax#620 (comment)
Basically, x[i] where x is a numpy.ndarray and i is a JAX DeviceArray of
length 10 or less
causes NumPy to interperet i as a non-array sequence (e.g. a
tuple) rather than as an array, leading to an error like "IndexError: too many
indices for array". The workaround employed here is to write x[i, ...] instead
of x[i], which bypasses the NumPy bug.

@codecov-io
Copy link

codecov-io commented Dec 2, 2020

Codecov Report

Merging #701 (5bdce20) into master (0bed2f6) will not change coverage.
The diff coverage is n/a.

Impacted file tree graph

@@           Coverage Diff           @@
##           master     #701   +/-   ##
=======================================
  Coverage   80.66%   80.66%           
=======================================
  Files          56       56           
  Lines        4267     4267           
=======================================
  Hits         3442     3442           
  Misses        825      825           

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 0bed2f6...5bdce20. Read the comment docs.

jax-ml/jax#3821.

The idea of the JAX change is in part that DeviceArray.__iter__ should return
DeviceArrays. Before #3821, it returned numpy.ndarrays. One main motivation is
performance: it avoids a host sync. A secondary motivation is type consistency.

However, that caused this line of Flax example code to trigger a NumPy bug,
discussed in this thread:
jax-ml/jax#620 (comment)
Basically, x[i] where x is a numpy.ndarray and i is a JAX DeviceArray _of
length 10 or less_ causes NumPy to interperet i as a non-array sequence (e.g. a
tuple) rather than as an array, leading to an error like "IndexError: too many
indices for array". The workaround employed here is to write x[i, ...] instead
of x[i], which bypasses the NumPy bug.

PiperOrigin-RevId: 345160314
@copybara-service copybara-service bot merged commit 69b9d44 into master Dec 2, 2020
@copybara-service copybara-service bot deleted the test_345140147 branch December 2, 2020 05:17
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants