Explanation of BufferAssignment OOM Debugging #19496
JohannesEsslinger asked this question in General
Hi!
My question relates to debugging out-of-memory (OOM) errors in JAX. Unfortunately, there is not much information on this in the documentation or other sources, and a similar thread, which can be found here, has not been addressed yet.
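A minimal sketch of general-purpose tools that can help here, assuming a recent JAX version. It disables GPU memory preallocation via the `XLA_PYTHON_CLIENT_PREALLOCATE` environment variable (documented on JAX's GPU memory allocation page). It dumps the optimized HLO of a compiled function with `Compiled.as_text()`, so the buffer shapes in an OOM report can be matched to the operations that produce them. Finally, it writes a snapshot of the live device buffers with `jax.profiler.save_device_memory_profile`. The function `step` and its shapes are hypothetical placeholders for the function that actually runs out of memory.

```python
import os

# Must be set before JAX initializes its backend. On GPU this makes XLA
# allocate memory on demand instead of preallocating a large fraction up
# front; on CPU it has no effect.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

import jax
import jax.numpy as jnp


def step(params, x):
    # Hypothetical stand-in for the function that runs out of memory.
    return jnp.tanh(x @ params).sum()


params = jnp.ones((100, 50), dtype=jnp.float32)
x = jnp.ones((200, 100), dtype=jnp.float32)

# Compile ahead of time and get the optimized HLO as text. Buffer shapes
# from an OOM report can be searched for in this text.
compiled = jax.jit(jax.grad(step)).lower(params, x).compile()
hlo_text = compiled.as_text()
print(hlo_text[:500])

# Run once, then write a pprof-format snapshot of the buffers that are
# currently alive on the device.
grads = jax.grad(step)(params, x)
grads.block_until_ready()
jax.profiler.save_device_memory_profile("memory.prof")
```

The resulting `memory.prof` file can be inspected with the `pprof` tool (for example `pprof --web memory.prof`). Recent JAX versions also offer `Compiled.memory_analysis()`, which summarizes estimated memory requirements but may return `None` depending on the backend.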
I get the following OOM exception, but I don't know why 15 buffers are allocated (all for the same function call).
The function at source_line 200 is:
where `per_step_input` is a tuple `(dict, ArrayImpl(11, 2), ArrayImpl(11,))`.
Note that 11 is the number of iterations, 2000 the batch size, and 1000 the dimension of a single sample. (I don't know where the 9 comes from.)
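To put these numbers in perspective, assume the buffers are float32 with shape (11, 2000, 1000). This is an assumption, since the actual shapes and dtype are in the OOM log, which is not reproduced here. Each such buffer then takes 11 × 2000 × 1000 × 4 = 88,000,000 bytes (about 84 MiB), so 15 of them come to about 1.23 GiB. A leading axis of length 11 is what `jax.lax.scan` produces when it stacks per-step outputs, one entry per iteration. If the function at source_line 200 runs inside a scan (also an assumption), this minimal sketch shows the effect. It uses `jax.eval_shape`, so nothing is actually allocated. `body` and `buffer_bytes` are hypothetical helpers, not code from the post.

```python
import numpy as np
import jax
import jax.numpy as jnp


def buffer_bytes(shape, dtype):
    # Size of a dense array: number of elements times bytes per element.
    return int(np.prod(shape)) * np.dtype(dtype).itemsize


def body(carry, step_input):
    # Hypothetical scan step: the carry has shape (batch, dim) and the
    # per-step input is a scalar taken from an array of shape (11,).
    new_carry = jnp.tanh(carry * step_input)
    # Returning new_carry as the per-step output makes scan stack it.
    return new_carry, new_carry


def run(init, xs):
    return jax.lax.scan(body, init, xs)


init = jax.ShapeDtypeStruct((2000, 1000), jnp.float32)  # (batch, dim)
xs = jax.ShapeDtypeStruct((11,), jnp.float32)           # one entry per iteration

# eval_shape traces the computation and returns only shapes and dtypes.
final, stacked = jax.eval_shape(run, init, xs)
print(stacked.shape)  # (11, 2000, 1000)
print(buffer_bytes(stacked.shape, stacked.dtype) / 2**20, "MiB")
```

The same stacking happens to residuals that reverse-mode differentiation saves for each iteration of a scan, so taking a gradient through a scan can create several buffers with this leading dimension.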
Do you have any idea what could cause JAX to create 15 buffers, or any suggestions for debugging OOM errors?
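A commonly suggested mitigation applies if the buffers turn out to be residuals saved while differentiating through a scan, which is an assumption. It wraps the scan body in `jax.checkpoint` (also exported as `jax.remat`). JAX then saves only the inputs of the checkpointed body for each iteration and recomputes the intermediates inside it during the backward pass, trading extra compute for lower memory. This minimal sketch uses small hypothetical shapes and only checks that both variants give the same gradient. Whether it helps in the original case depends on what the 15 buffers actually are.

```python
import jax
import jax.numpy as jnp


def body(carry, step_input):
    # Hypothetical scan step with several intermediates that would
    # normally be stored for the backward pass.
    h = jnp.tanh(carry * step_input)
    h = jnp.sin(h) * jnp.cos(h)
    return h, h.sum()


def make_loss(step_fn):
    # Build a scalar loss that scans step_fn over the per-step inputs.
    def loss(init, xs):
        _, per_step = jax.lax.scan(step_fn, init, xs)
        return per_step.sum()
    return loss


init = jnp.ones((20, 10), dtype=jnp.float32)
xs = jnp.linspace(0.1, 1.1, 11, dtype=jnp.float32)  # 11 iterations

# Plain version: intermediates of every iteration are kept as residuals.
g_plain = jax.grad(make_loss(body))(init, xs)
# Checkpointed version: intermediates are recomputed on the backward pass.
g_remat = jax.grad(make_loss(jax.checkpoint(body)))(init, xs)

print(jnp.max(jnp.abs(g_plain - g_remat)))
```

To see which values are saved for the backward pass, the JAX gradient checkpointing guide uses `jax.ad_checkpoint.print_saved_residuals`, e.g. `print_saved_residuals(make_loss(body), init, xs)`.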
Kind regards!