Skip to content

Debugging tool to print information (especially sharding) about jax arrays

License

Notifications You must be signed in to change notification settings

Findus23/jax-array-info

Folders and files

NameName
Last commit message
Last commit date

Latest commit

 

History

16 Commits
 
 
 
 
 
 
 
 
 
 
 
 

Repository files navigation

jax-array-info

This package contains two functions for debugging jax Arrays:

pip install git+https://github.com/Findus23/jax-array-info.git
from jax_array_info import sharding_info, sharding_vis, print_array_stats

sharding_info(arr)

sharding_info(arr) prints general information about a jax or numpy array with special focus on sharding ( supporting SingleDeviceSharding, GSPMDSharding, PositionalSharding, NamedSharding and PmapSharding)

array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
array = jax.device_put(array, NamedSharding(mesh, P(None, "gpus")))
sharding_info(array, "some_array")
╭────────────────── some_array ────────────────╮
│ shape: (128, 128, 128)                       │
│ dtype: float32                               │
│ size: 8.0 MiB                                │
│ NamedSharding: P(None, 'gpus')               │
│ axis 1 is sharded: CPU 0 contains 0:16 (1/8) │
│                    Total size: 128           │
╰──────────────────────────────────────────────╯

sharding_vis(arr)

A modified version of jax.debug.visualize_array_sharding() that also supports arrays with more than 2 dimensions (by ignoring non-sharded dimensions in the visualisation until reaching 2 dimensions)

array = jax.numpy.zeros(shape=(N, N, N), dtype=jax.numpy.float32)
array = jax.device_put(array, NamedSharding(mesh, P(None, "gpus")))
sharding_vis(array)
─────────── showing dims [0, 1] from original shape (128, 128, 128) ────────────
┌───────┬───────┬───────┬───────┬───────┬───────┬───────┬───────┐
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│ CPU 0 │ CPU 1 │ CPU 2 │ CPU 3 │ CPU 4 │ CPU 5 │ CPU 6 │ CPU 7 │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
│       │       │       │       │       │       │       │       │
└───────┴───────┴───────┴───────┴───────┴───────┴───────┴───────┘

print_array_stats()

Shows a nice overview over the all currently allocated arrays ordered by size.

Disclaimer: This uses jax.live_arrays() to get its information. There might be allocated arrays that are missing in this view. Also

arr = jax.numpy.zeros(shape=(16, 16, 16))
arr2 = jax.device_put(jax.numpy.zeros(shape=(2, 16, 4)), NamedSharding(mesh, P(None, "gpus")))

print_array_stats()
             allocated jax arrays              
┏━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━┓
┃ size     ┃ shape        ┃      sharded      ┃
┡━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━┩
│ 16.0 KiB │ (16, 16, 16) │                   │
│ 64.0 B   │ (2, 16, 4)   │ ✔ (512.0 B total) │
├──────────┼──────────────┼───────────────────┤
│ 16.1 KiB │              │                   │
└──────────┴──────────────┴───────────────────┘

Examples

See tests/

About

Debugging tool to print information (especially sharding) about jax arrays

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published

Languages