
Notebook erroring out when running many iterations #928

Closed
raphpapercup opened this issue Jun 26, 2019 · 3 comments
Labels
bug Something isn't working question Questions for the JAX team

Comments

@raphpapercup

When running the following notebook: https://github.com/ericjang/nf-jax/blob/c6636a010bb744a48185eb3f622d32336e929990/nf-tutorial-jax.ipynb, I have found that if you run all 1e4 iterations in a row and then evaluate y.max(), the notebook crashes. However, if you instead run, say, 1000 iterations, then y.max(), then another 1000 iterations, then y.max(), and so on until reaching 1e4 iterations, the notebook does not crash and everything works as expected.
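For context, the failing pattern can be sketched roughly like this (the `step` function, shapes, and iteration body are illustrative stand-ins, not code from the notebook):

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(y):
    # Hypothetical stand-in for one iteration of work in the notebook.
    return y * 0.999 + 0.001

y = jnp.ones((1000,))
for _ in range(10_000):   # all 1e4 iterations in a row
    y = step(y)           # enqueues work on the device; does not wait
print(y.max())            # forces all queued work to complete
```

Because JAX dispatches asynchronously, the loop itself returns almost immediately; the work only completes when something (like `y.max()`) forces the result.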

@mattjj
Collaborator

mattjj commented Jun 26, 2019

@hawkinsp could this be related to async execution?

@raphpapercup which backend is this on (CPU or GPU)? Could you try running y.block_until_ready() every 1000 iterations, instead of y.max(), just to help us diagnose the bug? See the async dispatch docs for an explanation of why it might be related.
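The suggested diagnostic might look something like this (again with a hypothetical `step` function standing in for the notebook's iteration):

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(y):
    # Hypothetical stand-in for one iteration of work.
    return y * 0.999 + 0.001

y = jnp.ones((1000,))
for i in range(10_000):
    y = step(y)
    if (i + 1) % 1000 == 0:
        # Wait for the device to drain its queue without adding any
        # extra computation (unlike y.max(), which enqueues a reduction).
        y.block_until_ready()
```

If this version does not crash, it points at unbounded async dispatch (too many enqueued operations) rather than the `y.max()` reduction itself.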

@hawkinsp
Collaborator

If this is running on CPU, I can well believe it is related to async execution; you could easily run out of memory. We probably need to limit how far ahead the Python code can run relative to the device. Could you please verify which backend you were using? Thanks!

@hawkinsp hawkinsp added bug Something isn't working question Questions for the JAX team labels Jun 26, 2019
mahak pushed a commit to mahak/tensorflow that referenced this issue Jul 11, 2019
Currently on CPU and GPU there is no limit to how many operations the host can enqueue on the device stream.

On GPU this doesn't usually cause a problem because the allocator is logically synchronized to the tail of the compute stream, and so we can free and reuse memory for operations enqueued on the stream.

On CPU, the allocator is logically synchronized to the head of the compute stream, which means that the allocator cannot reuse buffers between operations enqueued on the stream. This means that the memory usage is proportional to the number of enqueued operations, which can rapidly blow up.

Add a semaphore class and use it to set a moderate limit on the depth of the queue (32). The existing "synchronous" mode, used on TPU at present, is a special case of this support where the queue depth is 1.

This may help with jax-ml/jax#928.

PiperOrigin-RevId: 257606960
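As a rough illustration of the mechanism (a toy Python model with hypothetical names, not the actual C++ change in the runtime), a semaphore bounding the queue depth works like this:

```python
import threading
import queue

class BoundedDispatcher:
    """Toy model: the host may run at most `depth` operations
    ahead of the device "stream" (here, a worker thread)."""
    def __init__(self, depth=32):
        self._sem = threading.Semaphore(depth)  # queue-depth limit
        self._q = queue.Queue()
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def _run(self):
        while True:
            op = self._q.get()
            if op is None:
                return
            op()                  # "execute" the enqueued operation
            self._sem.release()   # free a slot: host may enqueue more

    def enqueue(self, op):
        self._sem.acquire()       # blocks once `depth` ops are in flight
        self._q.put(op)

    def close(self):
        self._q.put(None)
        self._worker.join()

results = []
d = BoundedDispatcher(depth=32)
for i in range(10_000):
    d.enqueue(lambda i=i: results.append(i))
d.close()
```

With `depth=1` this degenerates into the fully synchronous mode described above for TPU: the host cannot enqueue an operation until the previous one has finished.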
@hawkinsp
Collaborator

I believe this was fixed by the stream pacing mechanism referenced above.
