Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor updates to How to think in JAX and Working with pytrees #18968

Merged
merged 1 commit into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 15 additions & 10 deletions docs/tutorials/thinking-in-jax.md
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp);
```

The code blocks are identical aside from replacing `np` with `jnp`, and the results are the same. JAX arrays can often be used directly in place of NumPy arrays for things like plotting.
The code blocks are identical aside from replacing NumPy (`np`) with JAX NumPy (`jnp`), and the results are the same. JAX arrays can often be used directly in place of NumPy arrays for things like plotting.

The arrays themselves are implemented as different Python types:

Expand Down Expand Up @@ -99,7 +99,7 @@ print(y)

- `jax.Array` is the default array implementation in JAX.
- The JAX array is a unified distributed datatype for representing arrays, even with physical storage spanning multiple devices
- Automatic parallelization: You can operate over sharded `jax.Array`s without copying data onto a device using the `jax.jit` transformation. You can also replicate a `jax.Array` to every device on a mesh.
- Automatic parallelization: You can operate over sharded `jax.Array`s without copying data onto a device using the {func}`jax.jit` transformation. You can also replicate a `jax.Array` to every device on a mesh.

Consider this simple example:

Expand Down Expand Up @@ -127,7 +127,7 @@ The `jax.Array` type also helps make parallelism a core feature of JAX.

JAX has built-in support for objects that look like dictionaries (dicts) of arrays, or lists of lists of dicts, or other nested structures — they are called JAX pytrees (also known as nests, or just trees). In the context of machine learning, a pytree can contain model parameters, dataset entries, and reinforcement learning agent observations.

Below is an example of a simple pytree. In JAX, you can use `jax.tree_*`, to extract the flattened leaves from the trees, as demonstrated here:
Below is an example of a simple pytree. In JAX, you can use {func}`jax.tree_util.tree_leaves`, to extract the flattened leaves from the trees, as demonstrated here:

```{code-cell}
example_trees = [
Expand All @@ -153,8 +153,8 @@ You can learn more in the {ref}`working-with-pytrees` tutorial.

**Key concepts:**

- `jax.numpy` is a high-level wrapper that provides a familiar interface.
- `jax.lax` is a lower-level API that is stricter and often more powerful.
- {mod}`jax.numpy` is a high-level wrapper that provides a familiar interface.
- {mod}`jax.lax` is a lower-level API that is stricter and often more powerful.
- All JAX operations are implemented in terms of operations in [XLA](https://www.tensorflow.org/xla/) — the Accelerated Linear Algebra compiler.

If you look at the source of {mod}`jax.numpy`, you'll see that all the operations are eventually expressed in terms of functions defined in {mod}`jax.lax`. You can think of {mod}`jax.lax` as a stricter, but often more powerful, API for working with multi-dimensional arrays.
Expand Down Expand Up @@ -218,7 +218,7 @@ Every JAX operation is eventually expressed in terms of these fundamental XLA op

The fact that all JAX operations are expressed in terms of XLA allows JAX to use the XLA compiler to execute blocks of code very efficiently.

For example, consider this function that normalizes the rows of a 2D matrix, expressed in terms of `jax.numpy` operations:
For example, consider this function that normalizes the rows of a 2D matrix, expressed in terms of {mod}`jax.numpy` operations:

```{code-cell}
import jax.numpy as jnp
Expand Down Expand Up @@ -281,7 +281,7 @@ This is because the function generates an array whose shape is not known at comp

- Variables that you don't want to be traced can be marked as *static*

To use `jax.jit` effectively, it is useful to understand how it works. Let's put a few `print()` statements within a JIT-compiled function and then call the function:
To use {func}`jax.jit` effectively, it is useful to understand how it works. Let's put a few `print()` statements within a JIT-compiled function and then call the function:

```{code-cell}
@jit
Expand All @@ -300,7 +300,7 @@ f(x, y)

Notice that the print statements execute, but rather than printing the data you passed to the function, though, it prints *tracer* objects that stand-in for them.

These tracer objects are what `jax.jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.
These tracer objects are what {func}`jax.jit` uses to extract the sequence of operations specified by the function. Basic tracers are stand-ins that encode the **shape** and **dtype** of the arrays, but are agnostic to the values. This recorded sequence of computations can then be efficiently applied within XLA to new inputs with the same shape and dtype, without having to re-execute the Python code.

When you call the compiled function again on matching inputs, no recompilation is required and nothing is printed because the result is computed in compiled XLA rather than in Python:

Expand All @@ -310,7 +310,7 @@ y2 = np.random.randn(4)
f(x2, y2)
```

The extracted sequence of operations is encoded in a JAX expression, or *jaxpr* for short. You can view the jaxpr using the `jax.make_jaxpr` transformation:
The extracted sequence of operations is encoded in a JAX expression, or *jaxpr* for short. You can view the jaxpr using the {func}`jax.make_jaxpr` transformation:

```python
from jax import make_jaxpr
Expand Down Expand Up @@ -395,7 +395,12 @@ f(x)

Notice that although `x` is traced, `x.shape` is a static value. However, when you use {func}`jnp.array` and {func}`jnp.prod` on this static value, it becomes a traced value, at which point it cannot be used in a function like `reshape()` that requires a static input (recall: array shapes must be static).

A useful pattern is to use `numpy` for operations that should be static (i.e. done at compile-time), and use `jax.numpy` for operations that should be traced (i.e. compiled and executed at run-time). For this function, it might look like this:
A useful pattern is to:

- Use NumPy (`numpy`) for operations that should be static (i.e., done at compile-time); and
- Use JAX NumPy (`jax.numpy`) for operations that should be traced (i.e. compiled and executed at run-time).

For this function, it might look like this:

```{code-cell}
from jax import jit
Expand Down
22 changes: 11 additions & 11 deletions docs/tutorials/working-with-pytrees.md
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def update(params, x, y):
(pytrees-custom-pytree-nodes)=
## Custom pytree nodes

This section explains how in JAX you can extend the set of Python types that will be considered _internal nodes_ in pytrees (pytree nodes) by using {meth}`jax.tree_util.register_pytree_node` with {func}`jax.tree_map`.
This section explains how in JAX you can extend the set of Python types that will be considered _internal nodes_ in pytrees (pytree nodes) by using {func}`jax.tree_util.register_pytree_node` with {func}`jax.tree_map`.

Why would you need this? In the previous examples, pytrees were shown as lists, tuples, and dicts, with everything else as pytree leaves. This is because if you define your own container class, it will be considered to be a pytree leaf unless you _register_ it with JAX. This is also the case even if your container class has trees inside it. For example:

Expand Down Expand Up @@ -186,7 +186,7 @@ except TypeError as e:

As a solution, JAX allows to extend the set of types to be considered internal pytree nodes through a global registry of types. Additionally, the values of registered types are traversed recursively.

First, register a new type using {meth}`jax.tree_util.register_pytree_node`:
First, register a new type using {func}`jax.tree_util.register_pytree_node`:

```{code-cell}
from jax.tree_util import register_pytree_node
Expand Down Expand Up @@ -269,11 +269,11 @@ Notice that the `name` field now appears as a leaf, because all tuple elements a
(pytree-and-jax-transformations)=
## Pytree and JAX's transformations

Many JAX functions, like {meth}`jax.lax.scan`, operate over pytrees of arrays. In addition, all JAX function transformations can be applied to functions that accept as input and produce as output pytrees of arrays.
Many JAX functions, like {func}`jax.lax.scan`, operate over pytrees of arrays. In addition, all JAX function transformations can be applied to functions that accept as input and produce as output pytrees of arrays.

Some JAX function transformations take optional parameters that specify how certain input or output values should be treated (such as the `in_axes` and `out_axes` arguments to {func}`jax,vmap`). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments. In particular, to be able to “match up” leaves in these parameter pytrees with values in the argument pytrees, the parameter pytrees are often constrained to be tree prefixes of the argument pytrees.
Some JAX function transformations take optional parameters that specify how certain input or output values should be treated (such as the `in_axes` and `out_axes` arguments to {func}`jax.vmap`). These parameters can also be pytrees, and their structure must correspond to the pytree structure of the corresponding arguments. In particular, to be able to “match up” leaves in these parameter pytrees with values in the argument pytrees, the parameter pytrees are often constrained to be tree prefixes of the argument pytrees.

For example, if you pass the following input to {func}`jax,vmap` (note that the input arguments to a function are considered a tuple):
For example, if you pass the following input to {func}`jax.vmap` (note that the input arguments to a function are considered a tuple):

```
(a1, {"k1": a2, "k2": a3})
Expand All @@ -287,7 +287,7 @@ then you can use the following `in_axes` pytree to specify that only the `k2` ar

The optional parameter pytree structure must match that of the main input pytree. However, the optional parameters can optionally be specified as a “prefix” pytree, meaning that a single leaf value can be applied to an entire sub-pytree.

For example, if you have the same {func}`jax,vmap` input as above, but wish to only map over the dictionary argument, you can use:
For example, if you have the same {func}`jax.vmap` input as above, but wish to only map over the dictionary argument, you can use:

```
(None, 0) # equivalent to (None, {"k1": 0, "k2": 0})
Expand All @@ -299,7 +299,7 @@ Alternatively, if you want every argument to be mapped, you can write a single l
0
```

This happens to be the default `in_axes` value for {func}`jax,vmap`.
This happens to be the default `in_axes` value for {func}`jax.vmap`.

The same logic applies to other optional parameters that refer to specific input or output values of a transformed function, such as `out_axes` in {func}`jax.vmap`.

Expand All @@ -312,9 +312,9 @@ For built-in pytree node types, the set of keys for any pytree node instance is

JAX has the following `jax.tree_util.*` methods for working with key paths:

- {meth}`jax.tree_util.tree_flatten_with_path`: Works similarly to {meth}`jax.tree_util.tree_flatten`, but returns key paths.
- {meth}`jax.tree_util.tree_map_with_path``: Works similarly to {meth}`jax.tree_util.tree_map`, but the function also takes key paths as arguments.
- {meth}`jax.tree_util.keystr`: Given a general key path, returns a reader-friendly string expression.
- {func}`jax.tree_util.tree_flatten_with_path`: Works similarly to {func}`jax.tree_util.tree_flatten`, but returns key paths.
- {func}`jax.tree_util.tree_map_with_path``: Works similarly to {func}`jax.tree_util.tree_map`, but the function also takes key paths as arguments.
- {func}`jax.tree_util.keystr`: Given a general key path, returns a reader-friendly string expression.

For example, one use case is to print debugging information related to a certain leaf value:

Expand All @@ -336,7 +336,7 @@ To express key paths, JAX provides a few default key types for the built-in pytr
* `DictKey(key: Hashable)`: For dictionaries.
* `GetAttrKey(name: str)`: For `namedtuple`s and preferably custom pytree nodes (more in the next section)

You are free to define your own key types for your custom nodes. They will work with {meth}`jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression.
You are free to define your own key types for your custom nodes. They will work with {func}`jax.tree_util.keystr` as long as their `__str__()` method is also overridden with a reader-friendly expression.

```{code-cell}
for key_path, _ in flattened:
Expand Down