Examples of computation graphs visualised using jaxpr-viz.
The collapse_primitives setting used for each example is shown directly after its code
(see here for an explanation).
@jax.jit
def func1(first, second):
    """Return the sum, over all elements, of ``first + sin(second) * 3.0``."""
    # Keep the sin -> mul -> add -> reduce order so the traced graph is unchanged.
    scaled_sine = jnp.sin(second) * 3.0
    return jnp.sum(first + scaled_sine)
collapse_primitives=False
@jax.jit
def one_of_three(index, arg):
    """Apply one of three offsets to ``arg``, chosen by ``index`` via jax.lax.switch.

    index 0 -> arg + 1.0, index 1 -> arg - 2.0, index 2 -> arg + 3.0.
    Per the jax.lax.switch documentation, out-of-range indices are clamped
    into the valid branch range.
    """
    def plus_one(x):
        return x + 1.0

    def minus_two(x):
        return x - 2.0

    def plus_three(x):
        return x + 3.0

    return jax.lax.switch(index, [plus_one, minus_two, plus_three], arg)
collapse_primitives=False
@jax.jit
def func7(arg):
    """Return ``arg + 3.0`` when ``arg >= 0.0``, otherwise ``arg - 3.0`` (via jax.lax.cond)."""
    def add_three(x_true):
        return x_true + 3.0

    def subtract_three(x_false):
        return x_false - 3.0

    # Predicate is traced before the cond, matching the original trace order.
    non_negative = arg >= 0.0
    return jax.lax.cond(non_negative, add_three, subtract_three, arg)
collapse_primitives=True
@jax.jit
def func8(arg1, arg2):
    """Pick a value from ``arg2`` depending on the sign of ``arg1`` (via jax.lax.cond).

    arg1 >= 0.0 -> ``arg2[0]``; otherwise ``jnp.array([1]) + arg2[1]``.

    NOTE(review): jax.lax.cond requires both branches to return identical
    shapes/dtypes. That holds only for some ``arg2`` shapes (e.g. ``(k, 1)``
    with k >= 2); a 1-D ``arg2`` would give mismatched outputs. Confirm the
    intended input shape.
    """
    def take_first(x_true):
        return x_true[0]

    def offset_second(x_false):
        return jnp.array([1]) + x_false[1]

    return jax.lax.cond(arg1 >= 0.0, take_first, offset_second, arg2)
collapse_primitives=True
@jax.jit
def func10(arg, n):
    """Start from ``arg + 1`` and add ``3 + arg`` (elementwise) ``n`` times via fori_loop."""
    # jnp.ones(arg.shape) (default float dtype) rather than ones_like, as in the
    # original, so the carry dtype stays consistent across iterations.
    ones = jnp.ones(arg.shape)

    def step(_, acc):
        return acc + ones * 3.0 + arg

    initial = arg + ones
    return jax.lax.fori_loop(0, n, step, initial)
collapse_primitives=False
@jax.jit
def func11(arr, extra):
    """Scan over ``arr``, accumulating ``arr[i] * 1.0 + extra`` into a running total.

    Returns the ``(final_total, stacked_outputs)`` pair from jax.lax.scan, where
    each stacked output is the running total *before* that step's update.
    Starts from 0.0. Presumably ``arr`` is 1-D so the carry stays scalar (TODO confirm).
    """
    weights = jnp.ones(arr.shape)

    def step(running, pair):
        value, weight = pair
        updated = running + value * weight + extra
        return updated, running

    return jax.lax.scan(step, 0.0, (arr, weights))