Skip to content

Commit

Permalink
[attrs] add linearize and vjp support
Browse files Browse the repository at this point in the history
  • Loading branch information
mattjj committed Feb 24, 2024
1 parent 67572d3 commit b0b88d8
Show file tree
Hide file tree
Showing 3 changed files with 256 additions and 9 deletions.
5 changes: 2 additions & 3 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2084,9 +2084,8 @@ def linearize(fun: Callable, *primals, has_aux: bool = False
jaxtree_fun, out_tree = flatten_fun_nokwargs2(f, in_tree)
else:
jaxtree_fun, out_tree = flatten_fun_nokwargs(f, in_tree)
out_primals, out_pvals, jaxpr, consts, *maybe_aux = ad.linearize(jaxtree_fun,
*primals_flat,
has_aux=has_aux)
out_primals, out_pvals, jaxpr, consts, *maybe_aux = ad.linearize(
jaxtree_fun, *primals_flat, has_aux=has_aux)
if has_aux:
out_tree, aux_tree = out_tree()
else:
Expand Down
92 changes: 87 additions & 5 deletions jax/experimental/attrs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,12 @@
from jax._src.api_util import flatten_fun_nokwargs
from jax._src.interpreters import ad
from jax._src.interpreters import partial_eval as pe
from jax._src.tree_util import tree_flatten, tree_unflatten
from jax._src.util import unzip2
from jax._src.tree_util import (tree_flatten, tree_unflatten, tree_structure,
treedef_tuple)
from jax._src.util import unzip2, safe_map, safe_zip, split_list

map, unsafe_map = safe_map, map
zip, unsafe_zip = safe_zip, zip

JaxVal = Any

Expand Down Expand Up @@ -84,8 +88,8 @@ def _setattr_staging(trace, tracer, *, obj, attr):
def jvp(f, primals, tangents, attr_tangents):
attrs, attr_tangents = unzip2(((o, a), t) for o, a, t in attr_tangents)
attr_primals = tuple(jax_getattr(o, a) for o, a in attrs)
primals_flat, in_tree = tree_flatten((attr_primals, primals))
tangents_flat, in_tree_ = tree_flatten((attr_tangents, tangents))
primals_flat, in_tree = tree_flatten((attr_primals, *primals))
tangents_flat, in_tree_ = tree_flatten((attr_tangents, *tangents))
if in_tree != in_tree_: raise Exception
f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), in_tree)
out_primals_flat, out_tangents_flat, tangent_attrs_out = _jvp(f_).call_wrapped(
Expand All @@ -95,7 +99,7 @@ def jvp(f, primals, tangents, attr_tangents):
return out_primals, out_tangents, tangent_attrs_out

@lu.transformation
def _set_attrs(attrs, attr_vals, args):
def _set_attrs(attrs, attr_vals, *args):
for (o, a), x in zip(attrs, attr_vals):
jax_setattr(o, a, x)
yield (yield args, {})
Expand Down Expand Up @@ -134,3 +138,81 @@ def _setattr_jvp(trace, tracer, *, obj, attr):
trace.main.attrs_tracked.append((obj, attr))
setattr(obj, attr, tracer)
ad.JVPTrace.process_setattr = _setattr_jvp


def linearize(f, *primals, attrs: list[tuple[Any, str]] = []):
attr_primals = [jax_getattr(o, a) for o, a in attrs]
attr_avals = [core.raise_to_shaped(core.get_aval(p)) for p in attr_primals]
primals_flat, in_tree = tree_flatten(primals)
tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children()))
f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree)
primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize(
f_, *attr_primals, *primals_flat)
f_lin = _lin_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()),
attrs, attrs_out)
return tree_unflatten(out_tree(), primal_out), f_lin

def _linearize(traceable: lu.WrappedFun, *primals):
jvpfun, attrs = _split_attrs(_jvp(traceable))
in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
+ tuple(pe.PartialVal.unknown(core.get_aval(p).at_least_vspace())
for p in primals))
_, in_tree = tree_flatten((primals, primals))
jvpfun_flat, out_tree = flatten_fun_nokwargs(jvpfun, in_tree)
jaxpr, out_pvals, consts = pe.trace_to_jaxpr_nounits(jvpfun_flat, in_pvals)
out_primals_pvals, out_tangents_pvals, out_tangent_attr_pvals = \
tree_unflatten(out_tree(), out_pvals)
out_primals_consts = [pval.get_known() for pval in out_primals_pvals]
return (out_primals_consts, [*out_tangents_pvals, *out_tangent_attr_pvals],
jaxpr, consts, attrs())

@lu.transformation_with_aux
def _split_attrs(*args, **kwargs):
primals, tangents, tangent_attrs = yield args, kwargs
attrs, tangent_attr_vals = unzip2(((o, a), t) for o, a, t in tangent_attrs)
yield (primals, tangents, tangent_attr_vals), attrs

def _lin_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs):
in_tree, out_tree = io_tree
def f_lin(*tangents, attr_tangents):
if set(attr_tangents) - set(in_attrs): raise Exception
tangents_, in_tree_ = tree_flatten(tangents)
assert in_tree == in_tree_
attr_tangents_ = [attr_tangents.get(a, ad.Zero(aval))
for a, aval in zip(in_attrs, attr_avals)]
out = core.eval_jaxpr(jaxpr, consts, *attr_tangents_, *tangents_)
out_ = iter(out)
out = [p.get_known() if p.is_known() else next(out_) for p in out_pvals]
assert next(out_, None) is None
tangents_out, attr_tangents_out = split_list(out, [len(out)-len(out_attrs)])
out_ct = tree_unflatten(out_tree, tangents_out)
return out_ct, dict(zip(out_attrs, attr_tangents_out))
return f_lin


def vjp(f, *primals, attrs: list[tuple[Any, str]] = []):
attr_primals = [jax_getattr(o, a) for o, a in attrs]
primals_flat, in_tree = tree_flatten(primals)
tree = treedef_tuple((tree_structure(attr_primals), *in_tree.children()))
f_, out_tree = flatten_fun_nokwargs(_set_attrs(lu.wrap_init(f), attrs), tree)
primal_out, out_pvals, jaxpr, consts, attrs_out = _linearize(
f_, *attr_primals, *primals_flat)
attr_avals = [core.raise_to_shaped(core.get_aval(jax_getattr(o, a))).at_least_vspace()
for o, a in attrs_out]
f_vjp = _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, (in_tree, out_tree()),
attrs, attrs_out)
return tree_unflatten(out_tree(), primal_out), f_vjp

def _vjp_wrap(jaxpr, consts, out_pvals, attr_avals, io_tree, in_attrs, out_attrs):
in_tree, out_tree = io_tree
dummies = [ad.UndefinedPrimal(v.aval) for v in jaxpr.invars]
def f_vjp(out_ct, *, attr_cotangents: dict[tuple[Any, str], JaxVal] = {}):
out_cts, out_tree_ = tree_flatten(out_ct)
assert out_tree == out_tree_
attr_cts = [attr_cotangents.get(a, ad.Zero(aval))
for a, aval in zip(out_attrs, attr_avals)]
out = ad.backward_pass(jaxpr, (), (), consts, dummies, (*out_cts, *attr_cts))
in_attr_bars, arg_cts = split_list(out, [len(in_attrs)])
args_ct = tree_unflatten(in_tree, map(ad.instantiate_zeros, arg_cts))
return args_ct, dict(zip(in_attrs, in_attr_bars))
return f_vjp
168 changes: 167 additions & 1 deletion tests/attrs_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,10 @@
@dataclass
class Thing:
x: float
__hash__ = object.__hash__
__eq__ = object.__eq__

attrs.register(Thing)
attrs.register(Thing) # enables passing as arg into jitted function

class AttrsTest(jtu.JaxTestCase):

Expand Down Expand Up @@ -366,6 +368,170 @@ def g_ref(x, x_dot, y, y_dot):
self.assertAllClose(w_ddot, w_ddot_, check_dtypes=False)
self.assertAllClose(z_ddot, z_ddot_, check_dtypes=False)

class AttrsLinTest(jtu.JaxTestCase):

@parameterized.parameters([True, False])
def test_attr_output(self, jit):
thing = Thing(1.0)

def f(x, _):
y = jnp.sin(x)
jax_setattr(thing, 'x', y)

if jit:
f = jax.jit(f)

out, f_lin = attrs.linearize(f, 3.0, 4.0)
self.assertIsNone(out)
self.assertAllClose(thing.x, jnp.sin(3.0), check_dtypes=False)

out_dot, attr_tangents = f_lin(1.0, 2.0, attr_tangents={})
self.assertIsNone(out_dot)
self.assertAllClose(thing.x, jnp.sin(3.0)) # didn't change
self.assertLen(attr_tangents, 1)
self.assertAllClose(attr_tangents[(thing, 'x')], jnp.cos(3.0),
check_dtypes=False)

@parameterized.parameters([True, False])
def test_attr_input(self, jit):
thing = Thing(1.0)

def f():
x = jax_getattr(thing, 'x')
return jnp.sin(x)

if jit:
f = jax.jit(f)

out, f_lin = attrs.linearize(f, attrs=[(thing, 'x')])
self.assertAllClose(out, jnp.sin(1.0), check_dtypes=False)

out_dot, attr_tangents = f_lin(attr_tangents={(thing, 'x'): 2.0})
self.assertAllClose(out_dot, 2. * jnp.cos(1.0), check_dtypes=False)
self.assertLen(attr_tangents, 1)
self.assertAllClose(attr_tangents[(thing, 'x')], 2.0, check_dtypes=False)

@parameterized.parameters([True, False])
def test_attr_inout(self, jit):
thing1 = Thing(1.0)
thing2 = Thing(2.0)

def f(x, y):
z = jax_getattr(thing1, 'x')
w = jax_getattr(thing2, 'x')
out = jnp.sin(x * y * z * w)
jax_setattr(thing1, 'x', out)
jax_setattr(thing2, 'x', 2 * out)
return 3 * out, 4 * out

if jit:
f = jax.jit(f)

def f_ref(x, y, z, w):
out = jnp.sin(x * y * z * w)
return (3 * out, 4 * out), (out, 2 * out)

out, f_lin = attrs.linearize(f, 3., 4., attrs=[(thing1, 'x'), (thing2, 'x')])
expected = (3 * jnp.sin(1. * 2. * 3. * 4.),
4 * jnp.sin(1. * 2. * 3. * 4.))
self.assertAllClose(out, expected, check_dtypes=False)
self.assertAllClose(thing1.x, jnp.sin(1. * 2. * 3. * 4.))
self.assertAllClose(thing2.x, 2 * jnp.sin(1. * 2. * 3. * 4.))

(out_ref, state_out_ref), f_lin_ref = jax.linearize(f_ref, 3., 4., 1., 2.)
self.assertAllClose(out, out_ref, check_dtypes=False)
self.assertAllClose((thing1.x, thing2.x), state_out_ref, check_dtypes=False)

out_dot, attr_tangents = f_lin(1., 2.,
attr_tangents={(thing1, 'x'): 5.,
(thing2, 'x'): 6.})
self.assertAllClose(thing1.x, jnp.sin(1. * 2. * 3. * 4.))
self.assertAllClose(thing2.x, 2 * jnp.sin(1. * 2. * 3. * 4.))
(out_dot_ref, state_dot_ref) = f_lin_ref(1., 2., 5., 6.)
self.assertAllClose(out_dot, out_dot_ref, check_dtypes=False)
self.assertLen(attr_tangents, 2)
self.assertAllClose(attr_tangents[(thing1, 'x')], state_dot_ref[0],
check_dtypes=False)
self.assertAllClose(attr_tangents[(thing2, 'x')], state_dot_ref[1],
check_dtypes=False)

class AttrsVJPTest(jtu.JaxTestCase):

@parameterized.parameters([True, False])
def test_attr_input(self, jit):
thing = Thing(1.0)

def f():
x = jax_getattr(thing, 'x')
return jnp.sin(x)

if jit:
f = jax.jit(f)

out, f_vjp = attrs.vjp(f, attrs=[(thing, 'x')])
self.assertAllClose(out, jnp.sin(1.0), check_dtypes=False)

arg_cts, attr_cotangents = f_vjp(1.0)
self.assertEqual(arg_cts, ())
self.assertLen(attr_cotangents, 1)
self.assertAllClose(attr_cotangents[(thing, 'x')], jnp.cos(1.0),
check_dtypes=False)

@parameterized.parameters([True, False])
def test_attr_output(self, jit):
thing = Thing(1.0)

def f(x, _):
y = jnp.sin(x)
jax_setattr(thing, 'x', y)

if jit:
f = jax.jit(f)

out, f_vjp = attrs.vjp(f, 3.0, 4.0)
self.assertIsNone(out)
self.assertAllClose(thing.x, jnp.sin(3.0), check_dtypes=False)

arg_cts, attr_cotangents = f_vjp(None, attr_cotangents={(thing, 'x'): 2.0})
self.assertAllClose(arg_cts, (2 * jnp.cos(3.0), 0.), check_dtypes=False)
self.assertLen(attr_cotangents, 0)

@parameterized.parameters([True, False])
def test_attr_inout(self, jit):
thing1 = Thing(1.0)
thing2 = Thing(2.0)

def f(x, y):
z = jax_getattr(thing1, 'x')
w = jax_getattr(thing2, 'x')
out = jnp.sin(x * y * z * w)
jax_setattr(thing1, 'x', out)
jax_setattr(thing2, 'x', 2 * out)
return 3 * out, 4 * out

if jit:
f = jax.jit(f)

def f_ref(x, y, z, w):
out = jnp.sin(x * y * z * w)
return (3 * out, 4 * out), (out, 2 * out)

out, f_vjp = attrs.vjp(f, 3., 4., attrs=[(thing1, 'x'), (thing2, 'x')])
(out_ref, state_out_ref), f_vjp_ref = jax.vjp(f_ref, 3., 4., 1., 2.)
self.assertAllClose(out, out_ref, check_dtypes=False)
self.assertAllClose((thing1.x, thing2.x), state_out_ref, check_dtypes=False)

in_bar, attr_cotangents = f_vjp((1., 2.),
attr_cotangents={(thing1, 'x'): 5.,
(thing2, 'x'): 6.})
in_bar_ref_ = f_vjp_ref(((1., 2.), (5., 6.)))
in_bar_ref, attr_cotangents_ref = in_bar_ref_[:2], in_bar_ref_[2:]
self.assertAllClose(in_bar, in_bar_ref, check_dtypes=False)
self.assertLen(attr_cotangents, 2)
self.assertAllClose(attr_cotangents[(thing1, 'x')], attr_cotangents_ref[0],
check_dtypes=False)
self.assertAllClose(attr_cotangents[(thing2, 'x')], attr_cotangents_ref[1],
check_dtypes=False)

if __name__ == '__main__':
absltest.main(testLoader=jtu.JaxTestLoader())

0 comments on commit b0b88d8

Please sign in to comment.