Hi,
I'm currently porting some old pre-`jax.Array` `pjit` data- and tensor-parallel code to work with newer JAX versions. I read this section from the Array migration guide, and for data batches I can wrap the batch in `host_local_array_to_global_array` to recover the correct global mesh sharding, as `pjit` previously only required the batch to be sharded over the submesh local to the host.

However, if I do this and call `jax.debug.visualize_array_sharding(mybatch)` on the first host, I get the following visualization (assuming 4-way DP and 8-way TP):

To me this signifies that the data batch has now been split across hosts, even though it comes from the first host. Does this mean that data loading should now happen on a single host only, with the batch then sharded across the other hosts, or am I misunderstanding the visualization?
Thanks!
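
For reference, the setup described above can be reproduced with a minimal sketch like the one below. It is not the original code: the mesh axis names (`data`, `model`), the batch shape, and the choice to put every device on the `data` axis are assumptions made so that the snippet also runs on a single CPU device. It prints the host-local shape, the shape of the resulting global array, and the shards that the local devices actually hold.

```python
import numpy as np
import jax
from jax.experimental import multihost_utils
from jax.sharding import Mesh, PartitionSpec as P

# Assumed mesh layout: every device sits on the "data" axis and the "model"
# axis has size 1, so the sketch also runs on a single CPU device.
devices = np.array(jax.devices()).reshape(-1, 1)
mesh = Mesh(devices, axis_names=("data", "model"))

# Hypothetical per-host batch: 2 rows per local device, 16 features.
rows = 2 * jax.local_device_count()
local_batch = np.arange(rows * 16, dtype=np.float32).reshape(rows, 16)

# Wrap the host-local batch into a global jax.Array sharded over "data".
global_batch = multihost_utils.host_local_array_to_global_array(
    local_batch, mesh, P("data", None)
)

print("local shape :", local_batch.shape)
print("global shape:", global_batch.shape)

# Only the shards that live on this process's devices are listed here.
for shard in global_batch.addressable_shards:
    print(shard.device, shard.index)
```

On a single process the two shapes match. In a multi-process run, the global shape describes the array across all hosts, while `addressable_shards` lists only the pieces held by the local devices, which may help when interpreting the sharding visualization.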