* Allowing graph patterns with multiple broadcasts to be merged without dangling equations (a toy sketch of the underlying idea appears below, just before the file diffs).

* Adding pretty printing of the graph registrations.
* Graph matcher test now compares graph-to-graph topology, rather than the arbitrary equation orderings produced by the manually and automatically tagged functions.

PiperOrigin-RevId: 438839512
botev authored and KfacJaxDev committed Apr 1, 2022
1 parent 15eb3f7 commit 46feb7b
Showing 4 changed files with 136 additions and 56 deletions.
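
Before the file diffs, here is a toy sketch of the delayed-evaluation idea behind the broadcast merging described in the first bullet: consecutive broadcasts are collapsed into a single thunk that is only evaluated if a later equation actually reads its output, so no dangling intermediate equations or values are materialized. This is a standalone illustration, not the library's implementation; the names `env`, `write_delayed`, and `read` are invented for the sketch (the real logic lives in `broadcast_merger` in `tag_graph_matcher.py`).

```python
# Minimal standalone sketch of delayed broadcast evaluation (illustrative only).
import functools

import jax.numpy as jnp
from jax import lax

env = {}


def write_delayed(var, thunk):
  # Store a zero-argument callable instead of a concrete array.
  env[var] = thunk


def read(var):
  # Evaluate the thunk on first use and cache the resulting array.
  value = env[var]
  if callable(value):
    value = value()
    env[var] = value
  return value


x = jnp.ones([3])
# Two consecutive broadcasts collapse into one delayed broadcast_in_dim call.
write_delayed("y", functools.partial(
    lax.broadcast_in_dim, x, shape=(2, 4, 3), broadcast_dimensions=(2,)))
print(read("y").shape)  # (2, 4, 3), computed only because "y" was read.
```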
29 changes: 16 additions & 13 deletions README.md
@@ -50,16 +50,16 @@ import haiku as hk
import jax
import jax.numpy as jnp
import kfac_jax
import optax

# Hyper parameters
NUM_CLASSES = 10
L2_REG = 1e-3
NUM_BATCHES = 100


def make_dataset_iterator(batch_size):
# Dummy dataset, in practice this should be your dataset pipeline
while True:
for _ in range(NUM_BATCHES):
yield jnp.zeros([batch_size, 100]), jnp.ones([batch_size], dtype="int32")


@@ -69,10 +69,10 @@ def softmax_cross_entropy(logits: jnp.ndarray, targets: jnp.ndarray):
assert logits.ndim == targets.ndim + 1

# Tell KFAC-JAX this model represents a classifier
# See https://kfac_jax.readthedocs.io/en/latest/overview.html#supported-losses
# See https://kfac-jax.readthedocs.io/en/latest/overview.html#supported-losses
kfac_jax.register_softmax_cross_entropy_loss(logits, targets)

return optax.softmax_cross_entropy(logits, targets)
log_p = jax.nn.log_softmax(logits, axis=-1)
return - jax.vmap(lambda x, y: x[y])(log_p, targets)


def model_fn(x):
@@ -179,14 +179,17 @@ understood your model.
For the example above this looks like this:

```python
==================================================
Graph parameter registrations: {'mlp/~/linear_0': {'b':
'Auto[dense_with_bias_3]', 'w': 'Auto[dense_with_bias_3]'}, 'mlp/~/linear_1':
{'b': 'Auto[dense_with_bias_2]', 'w': 'Auto[dense_with_bias_2]'},
'mlp/~/linear_2': {'b': 'Auto[dense_with_bias_1]', 'w':
'Auto[dense_with_bias_1]'}, 'mlp/~/linear_3': {'b': 'Auto[dense_with_bias_0]',
'w': 'Auto[dense_with_bias_0]'}}
==================================================
==================================================
Graph parameter registrations:
{'mlp/~/linear_0': {'b': 'Auto[dense_with_bias_3]',
'w': 'Auto[dense_with_bias_3]'},
'mlp/~/linear_1': {'b': 'Auto[dense_with_bias_2]',
'w': 'Auto[dense_with_bias_2]'},
'mlp/~/linear_2': {'b': 'Auto[dense_with_bias_1]',
'w': 'Auto[dense_with_bias_1]'},
'mlp/~/linear_3': {'b': 'Auto[dense_with_bias_0]',
'w': 'Auto[dense_with_bias_0]'}}
==================================================
```

As can be seen from this message, the library has correctly detected all
parameters of the model to be part of dense layers.
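
The registration summary above is emitted through absl logging at INFO level (see the `tag_graph_matcher.py` change further down). A minimal sketch for surfacing it, assuming the README's training setup, is simply raising the verbosity before the optimizer first traces the model:

```python
# A minimal sketch, assuming the absl logging backend used by
# kfac_jax/_src/tag_graph_matcher.py; call this before the optimizer first
# traces the model so the registration summary is printed.
from absl import logging

logging.set_verbosity(logging.INFO)
```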
41 changes: 22 additions & 19 deletions docs/guides.rst
@@ -7,21 +7,20 @@ optimizer:

.. code-block:: python
from typing import Tuple
import haiku as hk
import jax
import jax.numpy as jnp
import kfac_jax
import optax
# Hyper parameters
NUM_CLASSES = 10
L2_REG = 1e-3
NUM_BATCHES = 100
def make_dataset_iterator(batch_size: int):
def make_dataset_iterator(batch_size):
# Dummy dataset, in practice this should be your dataset pipeline
while True:
for _ in range(NUM_BATCHES):
yield jnp.zeros([batch_size, 100]), jnp.ones([batch_size], dtype="int32")
@@ -31,13 +30,13 @@ optimizer:
assert logits.ndim == targets.ndim + 1
# Tell KFAC-JAX this model represents a classifier
# See https://kfac_jax.readthedocs.io/en/latest/overview.html#supported-losses
# See https://kfac-jax.readthedocs.io/en/latest/overview.html#supported-losses
kfac_jax.register_softmax_cross_entropy_loss(logits, targets)
log_p = jax.nn.log_softmax(logits, axis=-1)
return - jax.vmap(lambda x, y: x[y])(log_p, targets)
return optax.softmax_cross_entropy(logits, targets)
def model_fn(x: jnp.ndarray):
def model_fn(x):
"""A Haiku MLP model function - three hidden layer network with tanh."""
return hk.nets.MLP(
output_sizes=(50, 50, 50, NUM_CLASSES),
@@ -50,16 +49,17 @@ optimizer:
hk_model = hk.without_apply_rng(hk.transform(model_fn))
def loss_fn(params: hk.Params, batch: Tuple[jnp.ndarray, jnp.ndarray]):
def loss_fn(model_params, model_batch):
"""The loss function to optimize."""
x, y = batch
logits = hk_model.apply(params, x)
x, y = model_batch
logits = hk_model.apply(model_params, x)
loss = jnp.mean(softmax_cross_entropy(logits, y))
# The optimizer assumes that the function you provide has already added
# the L2 regularizer to its gradients.
return loss + L2_REG * kfac_jax.utils.inner_product(params, params) / 2.0
# Create the optimizer
optimizer = kfac_jax.Optimizer(
value_and_grad_func=jax.value_and_grad(loss_fn),
@@ -147,14 +147,17 @@ by the automatic registration system, in order to ensure that it has correctly
understood your model.
For the example above this looks like this::

==================================================
Graph parameter registrations: {'mlp/~/linear_0': {'b':
'Auto[dense_with_bias_3]', 'w': 'Auto[dense_with_bias_3]'}, 'mlp/~/linear_1':
{'b': 'Auto[dense_with_bias_2]', 'w': 'Auto[dense_with_bias_2]'},
'mlp/~/linear_2': {'b': 'Auto[dense_with_bias_1]', 'w':
'Auto[dense_with_bias_1]'}, 'mlp/~/linear_3': {'b': 'Auto[dense_with_bias_0]',
'w': 'Auto[dense_with_bias_0]'}}
==================================================
==================================================
Graph parameter registrations:
{'mlp/~/linear_0': {'b': 'Auto[dense_with_bias_3]',
'w': 'Auto[dense_with_bias_3]'},
'mlp/~/linear_1': {'b': 'Auto[dense_with_bias_2]',
'w': 'Auto[dense_with_bias_2]'},
'mlp/~/linear_2': {'b': 'Auto[dense_with_bias_1]',
'w': 'Auto[dense_with_bias_1]'},
'mlp/~/linear_3': {'b': 'Auto[dense_with_bias_0]',
'w': 'Auto[dense_with_bias_0]'}}
==================================================

As can be seen from this message, the library has correctly detected all
parameters of the model to be part of dense layers.
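
The loss-function change in the README and guide diffs above swaps `optax.softmax_cross_entropy` for an explicit `log_softmax` plus a `vmap` gather, dropping the optax dependency from the examples. Below is a small standalone check, not part of the commit, that the two agree on integer targets; optax is imported here only for the comparison.

```python
# Standalone equivalence check; optax is an extra dependency used only for
# this comparison and is not required by the updated examples.
import jax
import jax.numpy as jnp
import optax

logits = jnp.array([[2.0, 0.5, -1.0], [0.1, 0.2, 0.3]])
targets = jnp.array([0, 2])

# New form from the diff: gather the log-probability of the target class.
log_p = jax.nn.log_softmax(logits, axis=-1)
new_loss = -jax.vmap(lambda x, y: x[y])(log_p, targets)

# Old form, fed one-hot targets as optax expects.
old_loss = optax.softmax_cross_entropy(logits, jax.nn.one_hot(targets, 3))

assert jnp.allclose(new_loss, old_loss)
```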
43 changes: 35 additions & 8 deletions kfac_jax/_src/tag_graph_matcher.py
@@ -14,6 +14,7 @@
"""K-FAC functionality for auto-detecting layer tags and graph matching."""
import functools
import itertools
import pprint
from typing import Any, Callable, Dict, Iterable, Iterator, List, Mapping, MutableMapping, Optional, Sequence, Set, Tuple, TypeVar, Union

from absl import logging
@@ -395,7 +396,8 @@ def graph(self) -> JaxprGraph:
"""The Jaxpr graph representing the computation of this pattern."""
if self._graph is None:
jnp_args = jax.tree_map(jnp.asarray, self._example_args)
self._graph = make_jax_graph(self._compute_func, jnp_args, 1, self._name)
self._graph = make_jax_graph(
broadcast_merger(self._compute_func), jnp_args, 1, self._name)
return self._graph

def tag_ctor(
@@ -714,6 +716,26 @@ def clean_jaxpr_eqns(
def broadcast_merger(f: utils.Func) -> utils.Func:
"""Transforms ``f`` by merging any consecutive broadcasts in its Jaxpr."""

def read_with_delayed_evaluation(env, var):
if isinstance(var, (list, tuple)):
return jax.tree_map(lambda x: read_with_delayed_evaluation(env, x), var)
elif isinstance(var, core.Literal):
# Literals are values baked into the Jaxpr
return var.val
elif isinstance(var, core.Var):
r = env[var]
if isinstance(r, (jnp.ndarray, np.ndarray)):
return r
elif isinstance(r, Callable):
y = r()
if isinstance(y, list):
assert len(y) == 1
y = y[0]
assert isinstance(y, jnp.ndarray)
env[var] = y
return y
raise NotImplementedError()

@functools.wraps(f)
def merged_func(*func_args: Any) -> Any:
typed_jaxpr, out_avals = jax.make_jaxpr(f, return_shape=True)(*func_args)
@@ -722,7 +744,7 @@ def merged_func(*func_args: Any) -> Any:

# Mapping from variable -> value
env = {}
read = functools.partial(read_env, env)
read = functools.partial(read_with_delayed_evaluation, env)
write = functools.partial(write_env, env)

# Bind args and consts to environment
@@ -746,13 +768,17 @@ def merged_func(*func_args: Any) -> Any:
x, dims = broadcasts_outputs[eqn.invars[0]]
kept_dims = eqn.params["broadcast_dimensions"]
kept_dims = [kept_dims[d] for d in dims]
y = lax.broadcast_in_dim(x, eqn.params["shape"], kept_dims)
write(eqn.outvars, [y])
# In order not to compute any un-needed broadcasting we instead put
# in a function for delayed evaluation.
write(eqn.outvars, [functools.partial(
lax.broadcast_in_dim, x, eqn.params["shape"], kept_dims)])
broadcasts_outputs[eqn.outvars[0]] = (x, kept_dims)
else:
input_values = read(eqn.invars)
out_values = eval_jaxpr_eqn(eqn, input_values)
write(eqn.outvars, out_values)
# In order not to compute any un-needed broadcasting we instead put
# in a function for delayed evaluation.
write(eqn.outvars, [functools.partial(
eval_jaxpr_eqn, eqn, input_values)])
broadcasts_outputs[eqn.outvars[0]] = (
(input_values[0], eqn.params["broadcast_dimensions"]))
else:
@@ -1111,8 +1137,9 @@ def auto_register_tags(

params_labels = [tagged_params.get(p, "Orphan") for p in graph.params_vars]
logging.info("=" * 50)
logging.info("Graph parameter registrations: %s",
str(jax.tree_unflatten(graph.params_tree, params_labels)))
logging.info("Graph parameter registrations:")
logging.info(pprint.pformat(
jax.tree_unflatten(graph.params_tree, params_labels)))
logging.info("=" * 50)

# Construct a function with all of the extra tag registrations
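
A toy sketch of the logging change just above: `pprint.pformat` lays the nested registration dict out across aligned lines, which is what produces the reformatted output shown in the README and guide diffs. The dict below is a shortened stand-in for the one built via `jax.tree_unflatten` in `auto_register_tags`.

```python
# Only the formatting behaviour is illustrated here; the dict is a stand-in.
import pprint

registrations = {
    "mlp/~/linear_0": {"b": "Auto[dense_with_bias_3]",
                       "w": "Auto[dense_with_bias_3]"},
    "mlp/~/linear_1": {"b": "Auto[dense_with_bias_2]",
                       "w": "Auto[dense_with_bias_2]"},
}
print(pprint.pformat(registrations))
```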
79 changes: 63 additions & 16 deletions tests/test_graph_matcher.py
@@ -23,12 +23,34 @@
import kfac_jax
from tests import models

tag_graph_matcher = kfac_jax.tag_graph_matcher


class TestGraphMatcher(parameterized.TestCase):
"""Test class for the functions in `tag_graph_matcher.py`."""

def check_equation_match(self, eqn1, vars_to_vars, vars_to_eqn):
"""Checks that equation is matched in the other graph."""
eqn1_out_vars = [v for v in eqn1.outvars
if not isinstance(v, (jax.core.UnitVar, jax.core.DropVar))]
eqn2_out_vars = [vars_to_vars[v] for v in eqn1_out_vars]
eqns = [vars_to_eqn[v] for v in eqn2_out_vars]
self.assertTrue(all(e == eqns[0] for e in eqns[1:]))
eqn2 = eqns[0]

self.assertEqual(eqn1.primitive, eqn2.primitive)
# For xla_call we skip any detailed check as they are very complicated.
if eqn1.primitive.name != "xla_call":
for k in eqn1.params:
self.assertEqual(eqn1.params[k], eqn2.params[k])
for v1, v2 in zip(eqn1.invars, eqn2.invars):
if isinstance(v1, jax.core.Literal):
self.assertIsInstance(v2, jax.core.Literal)
self.assertEqual(v1.aval, v2.aval)
else:
self.assertEqual(v1.aval.shape, v2.aval.shape)
self.assertEqual(v1.aval.dtype, v2.aval.dtype)
vars_to_vars[v1] = v2
return vars_to_vars

@parameterized.parameters(models.NON_LINEAR_MODELS)
def test_auto_register_tags_jaxpr(
self,
@@ -50,7 +72,8 @@ def test_auto_register_tags_jaxpr(
data[name] = jnp.argmax(data[name], axis=-1)

params = init_func(init_key, data)
func = tag_graph_matcher.auto_register_tags(model_func, (params, data))
func = kfac_jax.tag_graph_matcher.auto_register_tags(
model_func, (params, data))
jaxpr = jax.make_jaxpr(func)(params, data).jaxpr
tagged_func = functools.partial(
model_func,
@@ -65,20 +88,44 @@ def test_auto_register_tags_jaxpr(
# it will always have the same or less number of equations.
self.assertLessEqual(len(jaxpr.eqns), len(tagged_jaxpr.eqns))

for eq, tagged_eq in zip(jaxpr.eqns, tagged_jaxpr.eqns):
eq_in_vars = [v for v in eq.invars
if not isinstance(v, jax.core.UnitVar)]
tagged_in_vars = [v for v in tagged_eq.invars
if not isinstance(v, jax.core.UnitVar)]
self.assertEqual(len(eq_in_vars), len(tagged_in_vars))
self.assertEqual(len(eq.outvars), len(tagged_eq.outvars))
self.assertEqual(eq.primitive, tagged_eq.primitive)
for variable, t_variable in zip(eq_in_vars + eq.outvars,
tagged_in_vars + tagged_eq.outvars):
if isinstance(variable, jax.core.Literal):
self.assertEqual(variable.aval, t_variable.aval)
# Extract all loss tags from both jax expressions
l1_eqns = []
for eqn in jaxpr.eqns:
if isinstance(eqn.primitive, kfac_jax.layers_and_loss_tags.LossTag):
l1_eqns.append(eqn)
l2_eqns = []
vars_to_eqn = {}
for eqn in tagged_jaxpr.eqns:
if isinstance(eqn.primitive, kfac_jax.layers_and_loss_tags.LossTag):
l2_eqns.append(eqn)
for v in eqn.outvars:
vars_to_eqn[v] = eqn
self.assertEqual(len(l1_eqns), len(l2_eqns))

# Match all losses output variables
vars_to_vars = {}
for eqn1, eqn2 in zip(l1_eqns, l2_eqns):
self.assertEqual(len(eqn1.outvars), len(eqn2.outvars))
for v1, v2 in zip(eqn1.outvars, eqn2.outvars):
if isinstance(v1, jax.core.DropVar):
self.assertIsInstance(v2, jax.core.DropVar)
elif isinstance(v1, jax.core.Literal):
self.assertIsInstance(v2, jax.core.Literal)
self.assertEqual(v1.aval, v2.aval)
else:
self.assertEqual(variable.count, t_variable.count)
self.assertEqual(v1.aval.shape, v2.aval.shape)
self.assertEqual(v1.aval.dtype, v2.aval.dtype)
vars_to_vars[v1] = v2

# Match all other equations
for eqn in reversed(jaxpr.eqns):
vars_to_vars = self.check_equation_match(eqn, vars_to_vars, vars_to_eqn)

for v1 in jaxpr.invars:
v2 = vars_to_vars[v1]
self.assertEqual(v1.aval.shape, v2.aval.shape)
self.assertEqual(v1.aval.dtype, v2.aval.dtype)
self.assertEqual(v1.count, v2.count)


if __name__ == "__main__":
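
Finally, a much-simplified illustration of what the reworked test above does: equations from two independently traced graphs are matched by primitive and abstract value rather than by tracer-assigned variable ordering. This sketch omits the variable-to-variable correspondence the real test builds starting from the loss tags.

```python
# Toy comparison of two independently traced Jaxprs (illustrative only).
import jax
import jax.numpy as jnp


def f(x):
  return jnp.sum(jnp.tanh(x) * 2.0)


def g(y):
  return jnp.sum(jnp.tanh(y) * 2.0)


jaxpr_f = jax.make_jaxpr(f)(jnp.ones([3])).jaxpr
jaxpr_g = jax.make_jaxpr(g)(jnp.ones([3])).jaxpr

for e1, e2 in zip(jaxpr_f.eqns, jaxpr_g.eqns):
  # Same primitive and same output shapes/dtypes, regardless of variable ids.
  assert e1.primitive == e2.primitive
  assert all(v1.aval == v2.aval for v1, v2 in zip(e1.outvars, e2.outvars))
```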
