v0.8.0
Additions:
- Add `on_getattr` in `field` to apply functions on `__getattr__`.
Breaking changes:
- Rename `callbacks` in `field` to `on_setattr` to match `attrs` and better reflect its functionality.
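To make the hook semantics concrete, here is a minimal pure-Python sketch. It assumes each hook list is applied in order, with each function receiving the previous function's output (the `assert_int` helper used in these notes returns its input, which fits that model). `HookedField` is a hypothetical descriptor for illustration only, not pytreeclass's `field` implementation.

```python
from typing import Any, Callable, Sequence


class HookedField:
    """Hypothetical descriptor (illustration only): runs `on_setattr`
    functions on assignment and `on_getattr` functions on every read."""

    def __init__(
        self,
        on_setattr: Sequence[Callable[[Any], Any]] = (),
        on_getattr: Sequence[Callable[[Any], Any]] = (),
    ):
        self.on_setattr = list(on_setattr)
        self.on_getattr = list(on_getattr)

    def __set_name__(self, owner, name):
        # the raw value lives in the instance __dict__ under a private key
        self.private_name = "_" + name

    def __set__(self, instance, value):
        # apply setter hooks in order, each receiving the previous output
        for func in self.on_setattr:
            value = func(value)
        instance.__dict__[self.private_name] = value

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        value = instance.__dict__[self.private_name]
        # apply getter hooks in order on every attribute access
        for func in self.on_getattr:
            value = func(value)
        return value


def assert_int(x):
    assert isinstance(x, int), "must be an int"
    return x


class Tree:
    a = HookedField(on_setattr=[assert_int], on_getattr=[assert_int])

    def __init__(self, a):
        self.a = a  # goes through HookedField.__set__


tree = Tree(a=1)
print(tree.a)  # 1

# simulate a value that bypassed __set__: only the getter hook can catch it
tree.__dict__["_a"] = 1.0
try:
    tree.a
except AssertionError as err:
    print(err)  # must be an int
```

Because the getter hooks run on every read, a value that reached the instance without passing through the setter is still checked at the point of use.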
These changes enable:
- Stricter data validation on instance values. In the following example, `on_setattr` ensures the value is of a certain type (e.g. integer) during initialization, and `on_getattr` ensures the value is of a certain type (e.g. integer) whenever it is accessed.

```python
import pytreeclass as pytc
import jax


def assert_int(x):
    assert isinstance(x, int), "must be an int"
    return x


@pytc.autoinit
class Tree(pytc.TreeClass):
    a: int = pytc.field(on_getattr=[assert_int], on_setattr=[assert_int])

    def __call__(self, x):
        # accessing `self.a` runs `assert_int` (via `on_getattr`), so `a`
        # is guaranteed to be an int before it is used in the computation
        a: int = self.a
        return a + x


tree = Tree(a=1)
print(tree(1.0))  # 2.0
tree = jax.tree_util.tree_map(lambda x: x + 0.0, tree)  # make `a` a float
tree(1.0)  # AssertionError: must be an int
```
- Frozen field without using `tree_mask`/`tree_unmask`

The pattern below freezes the value on `__setattr__` and unfreezes it whenever it is accessed, which ensures that `jax` transformations do not see the value:

```python
import pytreeclass as pytc
import jax


@pytc.autoinit
class Tree(pytc.TreeClass):
    frozen_a: int = pytc.field(on_getattr=[pytc.unfreeze], on_setattr=[pytc.freeze])

    def __call__(self, x):
        return self.frozen_a + x


tree = Tree(frozen_a=1)  # 1 is a non-jaxtype
# `tree` can still be passed through jax transformations


@jax.jit
def f(tree, x):
    return tree(x)


f(tree, 1.0)  # 2.0
grads = jax.grad(f)(tree, 1.0)  # Tree(frozen_a=#1)
```
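For intuition on why a frozen value is invisible to `jax.jit` and `jax.grad`, here is a minimal sketch in plain JAX. It assumes freezing amounts to moving the value out of the pytree's leaves and into static auxiliary data; `Frozen` is a hypothetical stand-in, not how `pytc.freeze` is actually implemented.

```python
import jax


class Frozen:
    """Hypothetical wrapper (illustration only): its value travels as pytree
    aux data, so JAX sees no leaves inside it."""

    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"#{self.value!r}"


# flatten -> (no children, aux data); unflatten(aux data, children) -> Frozen
jax.tree_util.register_pytree_node(
    Frozen,
    lambda frozen: ((), frozen.value),
    lambda value, _children: Frozen(value),
)


def f(params, x):
    # "unfreeze" on access by reading the wrapped value
    return params["frozen_a"].value + params["w"] * x


params = {"frozen_a": Frozen(1), "w": 2.0}
print(jax.tree_util.tree_leaves(params))  # [2.0]
print(jax.jit(f)(params, 1.0))  # 3.0
grads = jax.grad(f)(params, 1.0)  # gradient only for "w"; "frozen_a" passes through
```

Since the wrapped value is part of the tree structure rather than a leaf, it is treated as static by `jit` and never receives a gradient.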
Compared with other libraries that implement `static_field`, this freeze/unfreeze pattern has lower overhead and does not alter the `tree_flatten`/`tree_unflatten` methods of the tree.

- Easier way to create a buffer (non-trainable array)

Just use `jax.lax.stop_gradient` in `on_getattr`:

```python
import pytreeclass as pytc
import jax
import jax.numpy as jnp


def assert_array(x):
    assert isinstance(x, jax.Array)
    return x


@pytc.autoinit
class Tree(pytc.TreeClass):
    buffer: jax.Array = pytc.field(on_getattr=[jax.lax.stop_gradient], on_setattr=[assert_array])

    def __call__(self, x):
        return self.buffer**x


tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
tree(2.0)  # Array([1., 4., 9.], dtype=float32)


@jax.jit
def f(tree, x):
    return jnp.sum(tree(x))


f(tree, 1.0)  # Array(6., dtype=float32)
print(jax.grad(f)(tree, 1.0))  # Tree(buffer=[0. 0. 0.])
```
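The buffer behaviour can be checked in plain JAX without pytreeclass. This minimal sketch applies `jax.lax.stop_gradient` to the array at the point where it is read, mirroring what the `on_getattr` hook does on attribute access; the function name `f` and the use of `argnums=1` are choices made for this illustration.

```python
import jax
import jax.numpy as jnp


def f(buffer, x):
    # stop_gradient on read: the buffer contributes values but no gradient
    return jnp.sum(jax.lax.stop_gradient(buffer) ** x)


buffer = jnp.array([1.0, 2.0, 3.0])
print(f(buffer, 1.0))  # 6.0
print(jax.grad(f)(buffer, 1.0))  # [0. 0. 0.]

# gradients still flow through the other input: d/dx sum(buffer**x)
dx = jax.grad(f, argnums=1)(buffer, 1.0)
```

Only the buffer is excluded from differentiation; `x` still receives its gradient, which is the usual meaning of a non-trainable buffer.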