I need to run several independent computations. Each runs faster on the GPU than on the CPU, but all of them involve some data transfer between CPU and GPU, and the kernels themselves are often rather small.
If I were to program this directly in CUDA, I would put these computations on different CUDA streams and run them concurrently, so that the kernel from one job can run while the data for a different job is being transferred.
Is there any way I can get some control over this in JAX? For instance, could I run each job in a different thread, and assign and use a different stream on each of those threads? In PyTorch I can use something like this, but I couldn't find any documentation for JAX, or even a discussion of it so far. This makes me wonder whether I misunderstand something about how JAX works internally, and if so I'd be very curious what that is.
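To make the question concrete, here is a minimal sketch of the thread-per-job arrangement described above. It only spawns Python threads; it does not assign a CUDA stream to each thread, and whether the transfers and kernels from different jobs actually overlap is exactly the part I don't know. The names `small_kernel`, `run_job`, and `run_jobs_threaded` are hypothetical, and the computation is just a stand-in.

```python
from concurrent.futures import ThreadPoolExecutor

import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def small_kernel(x):
    # Stand-in for one job's (small) device computation.
    return jnp.tanh(x @ x.T).sum(axis=1)


def run_job(host_array):
    # One independent job: host -> device transfer, compute,
    # then device -> host transfer of the result.
    x = jax.device_put(host_array)
    y = small_kernel(x)
    return np.asarray(jax.device_get(y))


def run_jobs_threaded(host_arrays, max_workers=4):
    # Assumption: one Python thread per in-flight job. Nothing here
    # selects a stream; scheduling on the device is left to JAX/XLA.
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        return list(pool.map(run_job, host_arrays))


rng = np.random.default_rng(0)
jobs = [rng.standard_normal((64, 32)).astype(np.float32) for _ in range(8)]
results = run_jobs_threaded(jobs)
```

Each entry of `results` is a NumPy array on the host, in the same order as `jobs`. What I'm missing is the equivalent of binding each of these threads to its own stream.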