Notebook erroring out when running many iterations #928
Comments
@hawkinsp could this be related to async execution? @raphpapercup which backend is this on (CPU or GPU)? Could you try running
If this is running on CPU, I can believe it is related to async execution; you could easily run out of memory. We probably need to limit how far ahead the Python code can run. Could you please verify which backend you were using? Thanks!
Currently on CPU and GPU there is no limit to how many operations the host can enqueue on the device stream. On GPU this doesn't usually cause a problem because the allocator is logically synchronized to the tail of the compute stream, and so we can free and reuse memory for operations enqueued on the stream. On CPU, the allocator is logically synchronized to the head of the compute stream, which means that the allocator cannot reuse buffers between operations enqueued on the stream. This means that the memory usage is proportional to the number of enqueued operations, which can rapidly blow up.

Add a semaphore class and use it to set a moderate limit on the depth of the queue (32). The existing "synchronous" mode, used on TPU at present, is a special case of this support where the queue depth is 1.

This may help with jax-ml/jax#928.

PiperOrigin-RevId: 257606960
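To make the pacing idea in the commit message concrete, here is a minimal pure-Python sketch of a semaphore-bounded work queue. It is not the actual XLA implementation; `PacedStream`, `enqueue`, `wait`, and `peak_in_flight` are hypothetical names chosen for illustration, and a single worker thread stands in for the device stream. The host blocks in `enqueue` once `max_depth` operations are pending, so the number of live operations (and, by analogy, buffers) stays bounded.

```python
import queue
import threading


class PacedStream:
    """Toy stand-in for a device stream with a bounded number of pending ops."""

    def __init__(self, max_depth=32):
        # One semaphore slot per operation allowed to be pending at once.
        self._slots = threading.Semaphore(max_depth)
        self._ops = queue.Queue()
        self._lock = threading.Lock()
        self._in_flight = 0
        self.peak_in_flight = 0
        # A single worker thread plays the role of the device.
        self._worker = threading.Thread(target=self._run, daemon=True)
        self._worker.start()

    def enqueue(self, fn):
        # Blocks the host thread once max_depth operations are pending.
        self._slots.acquire()
        with self._lock:
            self._in_flight += 1
            self.peak_in_flight = max(self.peak_in_flight, self._in_flight)
        self._ops.put(fn)

    def _run(self):
        while True:
            fn = self._ops.get()
            fn()  # toy assumption: fn does not raise
            with self._lock:
                self._in_flight -= 1
            # Free the slot only after the operation has finished.
            self._slots.release()
            self._ops.task_done()

    def wait(self):
        # Block until every enqueued operation has completed.
        self._ops.join()


stream = PacedStream(max_depth=4)
results = []
for i in range(100):
    stream.enqueue(lambda i=i: results.append(i))
stream.wait()
```

With `max_depth=1` the host can only have one operation pending at a time, which mirrors the "synchronous" special case mentioned above.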
I believe this was fixed by the stream pacing mechanism referenced above.
When running the following notebook: https://github.com/ericjang/nf-jax/blob/c6636a010bb744a48185eb3f622d32336e929990/nf-tutorial-jax.ipynb, I have found that if you run all 1e4 iterations in a row and then evaluate y.max(), the notebook crashes. However, if you run, say, 1000 iterations, then evaluate y.max(), then run another 1000 iterations, then evaluate y.max() again, and so on until you reach 1e4 iterations, the notebook does not crash and things work as expected.
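The pattern that avoids the crash can be sketched as a loop that forces a synchronization every so often, which is what evaluating y.max() does implicitly. This is only an illustration: `run_with_periodic_sync`, `step`, and `sync_every` are hypothetical names, not code from the notebook, and the helper assumes the state is a JAX array (which exposes `block_until_ready()`); any other value is simply passed through.

```python
def run_with_periodic_sync(step, state, num_iters, sync_every=1000):
    """Apply `step` num_iters times, waiting for pending work every sync_every steps."""
    for i in range(1, num_iters + 1):
        state = step(state)
        if i % sync_every == 0:
            # JAX arrays expose block_until_ready(); it returns once the
            # asynchronously dispatched computation has finished.
            wait = getattr(state, "block_until_ready", None)
            if wait is not None:
                wait()
    return state
```

In the notebook's setting, `step` would be the (assumed) jitted update function and `state` its parameters, called as `run_with_periodic_sync(step, state, 10_000)`.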