Sharding Visualisation/Debugging #17641
Findus23 asked this question in Show and tell
Recently I have been working a lot with sharding in JAX, trying to wrap my head around how to apply it in a more complex existing codebase.
I really like the visualisation of `jax.debug.visualize_array_sharding()`, but it has a major limitation: it only supports arrays with at most two dimensions. Since every array in my case has at least three dimensions, I can't use it. So I wrote a fork of this function that works around this by ignoring non-sharded dimensions until a 2D visualisation is possible.
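To illustrate the "drop non-sharded dimensions" idea, here is a minimal sketch. It is not the code from the fork: `reduce_to_2d` is a hypothetical helper, it assumes the array uses a `NamedSharding`, and it arbitrarily drops leading non-sharded dimensions first. It returns a 2D shape plus a matching 2D `NamedSharding` that can then be drawn.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def reduce_to_2d(x):
    """Hypothetical helper: drop non-sharded dims until at most two remain."""
    sharding = x.sharding
    if not isinstance(sharding, NamedSharding):
        raise TypeError("this sketch only handles NamedSharding")
    # A PartitionSpec may omit trailing entries; pad them with None (= not sharded).
    spec = tuple(sharding.spec) + (None,) * (x.ndim - len(sharding.spec))
    keep = list(range(x.ndim))
    # Remove non-sharded dimensions, leading ones first, until the view is 2D.
    for dim in range(x.ndim):
        if len(keep) <= 2:
            break
        if spec[dim] is None:
            keep.remove(dim)
    if len(keep) > 2:
        raise ValueError("more than two sharded dimensions; cannot draw in 2D")
    shape_2d = tuple(x.shape[d] for d in keep)
    sharding_2d = NamedSharding(sharding.mesh, P(*(spec[d] for d in keep)))
    return shape_2d, sharding_2d


# Example: a 3D array sharded along its middle axis over all available devices.
n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))
x = jax.device_put(jnp.zeros((2, 8 * n, 3)), NamedSharding(mesh, P(None, "x", None)))
shape_2d, sharding_2d = reduce_to_2d(x)
```

The reduced pair can then be passed to `jax.debug.visualize_sharding(shape_2d, sharding_2d)`, assuming your JAX version exposes that function and the `rich` package is installed.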
And while thinking about it, I realized that for higher dimensions a simple text representation feels more intuitive to me than the table view. Therefore I created another function that works similarly, but instead shows general information about the array (shape, dtype, nbytes), supports all sharding types, and prints the axes along which the array is sharded.
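A text summary of this kind can be sketched with standard `jax.Array` attributes. The helper `sharding_summary` below is hypothetical and is not the API of jax-array-info; it only reports per-axis mesh axes for `NamedSharding`, while the shard shape line works for any sharding type via `Sharding.shard_shape`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P


def sharding_summary(x) -> str:
    """Hypothetical helper: text description of an array and its sharding."""
    lines = [
        f"shape: {x.shape}",
        f"dtype: {x.dtype}",
        f"nbytes: {x.nbytes}",  # size of the global (logical) array
        f"sharding: {type(x.sharding).__name__}",
    ]
    sharding = x.sharding
    if isinstance(sharding, NamedSharding):
        # Pad omitted trailing PartitionSpec entries with None (= not sharded).
        spec = tuple(sharding.spec) + (None,) * (x.ndim - len(sharding.spec))
        for dim, axes in enumerate(spec):
            if axes is not None:
                lines.append(f"axis {dim} is sharded across mesh axis {axes}")
    # shard_shape is defined on every Sharding subclass.
    lines.append(f"shard shape: {sharding.shard_shape(x.shape)}")
    return "\n".join(lines)


n = len(jax.devices())
mesh = Mesh(np.array(jax.devices()), ("x",))
x = jax.device_put(jnp.zeros((4 * n, 6), dtype=jnp.float32), NamedSharding(mesh, P("x")))
summary = sharding_summary(x)
print(summary)
```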
Keep in mind that the code is a bit experimental and might reflect some misconceptions of mine about sharding, but it has already helped me a lot to understand how array sharding propagates through the calculations.
https://github.com/Findus23/jax-array-info/
You can find a few examples of the possible output here:
https://github.com/Findus23/jax-array-info/blob/main/tests/jaxtest.py