JAX Computation Graph Visualisation Tool
JAX has built-in functionality to visualise the HLO graph it generates, but I've found this rather low-level for some use cases.
The intention of this package is to visualise how sub-functions are connected in JAX programs. It does this by converting the JaxPr representation into a pydot graph. See here for examples.
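To see the JaxPr representation this conversion starts from, JAX's own `jax.make_jaxpr` can trace a function on example inputs. This is a minimal sketch, not part of jpviz. It assumes a recent JAX version in which a nested `jit` call appears as a single equation whose params hold the callee's sub-jaxpr under a `"jaxpr"` key. The primitive is printed as `pjit` or `jit` depending on the version.

```python
import jax
import jax.numpy as jnp


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    return foo(x) - 1


# Trace `bar` with a concrete example input; the result is a ClosedJaxpr.
closed = jax.make_jaxpr(bar)(jnp.arange(10))
print(closed)

# Each equation records one primitive. Nested jit calls keep their own
# sub-jaxpr in the equation's params, which is what lets a tool draw them
# as separate sub-graphs.
for eqn in closed.jaxpr.eqns:
    print(eqn.primitive.name, "jaxpr" in eqn.params)
```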
NOTE: This project is still at an early stage and may not support all JAX functionality (or combinations thereof). If you spot some strange behaviour, please create a GitHub issue.
Install with pip:
pip install jpviz
Depending on your system, you may also need to install Graphviz.
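pydot renders images by calling the Graphviz `dot` executable, so one quick way to check whether Graphviz is already available is to look for `dot` on the `PATH`. This is a small sketch using only the Python standard library. It assumes a standard Graphviz install that puts `dot` on the `PATH`.

```python
import shutil

# Look up the Graphviz `dot` executable that pydot calls to render images.
dot_path = shutil.which("dot")
if dot_path is None:
    print("Graphviz 'dot' not found on PATH; install Graphviz to render images")
else:
    print(f"Graphviz found at {dot_path}")
```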
Jaxpr-viz can be used to visualise jit compiled (and nested) functions. It wraps jit compiled functions, and the wrapped function returns a pydot graph when called with concrete values.
For example, this simple computation graph
import jax
import jax.numpy as jnp
import jpviz
@jax.jit
def foo(x):
    return 2 * x

@jax.jit
def bar(x):
    x = foo(x)
    return x - 1
# Wrap function and call with concrete arguments
# here dot_graph is a pydot object
dot_graph = jpviz.draw(bar)(jnp.arange(10))
# This renders the graph to a png file
dot_graph.write_png("computation_graph.png")
produces this image
Pydot has a number of options for rendering graphs; see here.
NOTE: For sub-functions to show as nodes/sub-graphs, they need to be marked with @jax.jit; otherwise they will just be merged into their parent graph.
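The merging can be observed directly in the jaxpr. In this sketch, `count_nested_calls` is a hypothetical helper, not part of jpviz. It counts equations that carry their own sub-jaxpr under a `"jaxpr"` param. A plain Python helper is inlined into its caller during tracing, while a `@jax.jit` helper stays a separate nested call, so the jitted version should report exactly one more nested call.

```python
import jax
import jax.numpy as jnp


def count_nested_calls(jaxpr):
    """Count equations (recursively) that carry a nested sub-jaxpr."""
    total = 0
    for eqn in jaxpr.eqns:
        sub = eqn.params.get("jaxpr")
        if sub is not None:
            total += 1 + count_nested_calls(getattr(sub, "jaxpr", sub))
    return total


def plain_foo(x):  # not jitted: inlined into the caller's jaxpr
    return 2 * x


@jax.jit
def jit_foo(x):  # jitted: kept as its own nested sub-jaxpr
    return 2 * x


@jax.jit
def bar_plain(x):
    return plain_foo(x) - 1


@jax.jit
def bar_jit(x):
    return jit_foo(x) - 1


x = jnp.arange(10)
n_plain = count_nested_calls(jax.make_jaxpr(bar_plain)(x).jaxpr)
n_jit = count_nested_calls(jax.make_jaxpr(bar_jit)(x).jaxpr)
print(n_plain, n_jit)
```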
To show the rendered graph in a Jupyter notebook, you can use the helper function view_pydot:
...
dot_graph = jpviz.draw(bar)(jnp.arange(10))
jpviz.view_pydot(dot_graph)
By default, functions that are composed only of primitive functions are collapsed into a single node (like foo in the above example).
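One way to decide whether a function is "composed only of primitives" is to check that none of its equations nests another jaxpr. This is a hypothetical sketch of such a check and may not match jpviz's actual logic. `is_primitive_only` and `jit_body` are illustrative names. Sub-jaxprs are detected by duck typing (an object with `.eqns`, or a wrapper whose `.jaxpr` has `.eqns`), which also covers tuples of branches.

```python
import jax
import jax.numpy as jnp


def _is_sub_jaxpr(value):
    # Jaxpr objects expose `.eqns`; a ClosedJaxpr wraps one under `.jaxpr`.
    if isinstance(value, (tuple, list)):
        return any(_is_sub_jaxpr(v) for v in value)
    return hasattr(value, "eqns") or hasattr(getattr(value, "jaxpr", None), "eqns")


def is_primitive_only(jaxpr):
    """True if no equation in `jaxpr` nests another jaxpr."""
    return not any(
        _is_sub_jaxpr(v) for eqn in jaxpr.eqns for v in eqn.params.values()
    )


def jit_body(fun, *args):
    """Return the body Jaxpr of a jitted function traced on `args`."""
    (call,) = jax.make_jaxpr(fun)(*args).jaxpr.eqns  # the single jit equation
    return call.params["jaxpr"].jaxpr


@jax.jit
def foo(x):
    return 2 * x


@jax.jit
def bar(x):
    return foo(x) - 1


x = jnp.arange(10)
print(is_primitive_only(jit_body(foo, x)))  # foo's body holds only primitives
print(is_primitive_only(jit_body(bar, x)))  # bar's body nests a call to foo
```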
The full computation graph can be rendered by setting the collapse_primitives flag to False. For the above example,
...
dot_graph = jpviz.draw(bar, collapse_primitives=False)(jnp.arange(10))
...
produces
By default, type information is included in the node labels. This can be hidden by setting the show_avals flag to False:
...
dot_graph = jpviz.draw(bar, show_avals=False)(jnp.arange(10))
...
produces
NOTE: The labels of the nodes don't currently correspond to argument/variable names in the original Python code. Since JAX unpacks arguments and outputs into tuples, the labels instead correspond to the positions of the arguments and outputs.
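Because labels follow positions, it can help to see how JAX orders the flattened inputs and outputs. This sketch uses only standard JAX, and `affine` is a hypothetical example function. Dict entries are flattened in sorted key order, and a tuple result becomes separate outputs. The printed shapes and dtypes are JAX's abstract values (avals), presumably the type information that the show_avals flag controls.

```python
import jax
import jax.numpy as jnp


def affine(params, x):
    # A dict argument and a tuple result: both are pytrees.
    y = params["w"] * x + params["b"]
    return y, jnp.sum(y)


params = {"w": jnp.float32(2.0), "b": jnp.float32(1.0)}
x = jnp.arange(3, dtype=jnp.float32)

closed = jax.make_jaxpr(affine)(params, x)

# Inputs are flattened to leaves; dict keys are visited in sorted order,
# so the jaxpr inputs are (params["b"], params["w"], x).
leaves, _ = jax.tree_util.tree_flatten((params, x))
print("param leaves:", float(leaves[0]), float(leaves[1]))
for position, aval in enumerate(closed.in_avals):
    print(f"input {position}: shape={aval.shape}, dtype={aval.dtype}")

# The tuple result is likewise flattened into separate outputs.
print("outputs:", len(closed.jaxpr.outvars))
```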
See here for more examples of rendered computation graphs.
Developer notes can be found here.