
v0.8.0

@ASEM000 released this 06 Sep 10:50
e9deb51

V0.8

Additions:

  • Add on_getattr to field to apply functions to the value on __getattr__ (i.e. whenever the attribute is accessed).

Breaking changes:

  • Rename callbacks in field to on_setattr, matching attrs and better reflecting when the functions run (on __setattr__).
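A plain-Python sketch of what the two hooks mean, assuming each list of functions is applied left to right and its result is what gets stored (on_setattr) or returned (on_getattr). CallbackField and Layer are hypothetical names used only for illustration; this is not how pytreeclass implements field, just a descriptor with the same observable behavior.

```python
from typing import Any, Callable, Sequence


class CallbackField:
    """Hypothetical descriptor mimicking on_setattr / on_getattr semantics."""

    def __init__(
        self,
        on_setattr: Sequence[Callable[[Any], Any]] = (),
        on_getattr: Sequence[Callable[[Any], Any]] = (),
    ):
        self.on_setattr = tuple(on_setattr)
        self.on_getattr = tuple(on_getattr)

    def __set_name__(self, owner: type, name: str) -> None:
        self.storage = "_" + name  # the raw value lives under a private name

    def __get__(self, instance: Any, owner: type = None) -> Any:
        if instance is None:
            return self
        value = getattr(instance, self.storage)
        for fn in self.on_getattr:  # run left to right on every access
            value = fn(value)
        return value

    def __set__(self, instance: Any, value: Any) -> None:
        for fn in self.on_setattr:  # run left to right on assignment
            value = fn(value)
        setattr(instance, self.storage, value)


class Layer:
    # converted to float when assigned; clipped to [0, 1] whenever read
    rate = CallbackField(
        on_setattr=[float],
        on_getattr=[lambda v: min(max(v, 0.0), 1.0)],
    )

    def __init__(self, rate: float):
        self.rate = rate


layer = Layer(2)
print(layer._rate)  # 2.0 (stored value after on_setattr)
print(layer.rate)   # 1.0 (clipped by on_getattr at access time)

layer._rate = -4.0  # writing the private slot directly skips on_setattr
print(layer.rate)   # 0.0 (on_getattr still applies)
```

The last two lines show why an access-time hook is useful: a value that reaches the instance without passing through on_setattr is still checked or transformed when it is read.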

These changes enable:

  1. stricter data validation on instance values, as in the following example:

    on_setattr ensures the value is of a certain type (e.g. integer) during initialization, and on_getattr ensures the value is of a certain type (e.g. integer) whenever it is accessed.

    import pytreeclass as pytc
    import jax
    
    def assert_int(x):
        assert isinstance(x, int), "must be an int"
        return x
    
    @pytc.autoinit
    class Tree(pytc.TreeClass):
        a: int = pytc.field(on_getattr=[assert_int], on_setattr=[assert_int])
    
        def __call__(self, x):
            # ensure `a` is an int before using it in computation (on_getattr calls `assert_int`)
            a: int = self.a
            return a + x
    
    tree = Tree(a=1)
    print(tree(1.0))  # 2.0
    tree = jax.tree_util.tree_map(lambda x: x + 0.0, tree)  # make `a` a float
    tree(1.0)  # AssertionError: must be an int
  2. Frozen field without using tree_mask/tree_unmask

    The following shows a pattern where the value is frozen on __setattr__ and unfrozen whenever it is accessed, so JAX transformations never see the raw value. The following example showcases this functionality:

    import pytreeclass as pytc
    import jax
    
    @pytc.autoinit
    class Tree(pytc.TreeClass):
        frozen_a: int = pytc.field(on_getattr=[pytc.unfreeze], on_setattr=[pytc.freeze])
    
        def __call__(self, x):
            return self.frozen_a + x
    
    tree = Tree(frozen_a=1)  # 1 is non-jaxtype
    # can be used in jax transformations
    
    @jax.jit
    def f(tree, x):
        return tree(x)
    
    f(tree, 1.0)  # 2.0
    
    grads = jax.grad(f)(tree, 1.0)  # Tree(frozen_a=#1)

    Compared with other libraries that implement static_field, this pattern has lower overhead and does not alter the tree_flatten/tree_unflatten methods of the tree.

  3. Easier way to create a buffer (non-trainable array)

    Just apply jax.lax.stop_gradient in on_getattr, so no gradient flows back to the buffer when it is read.

    import pytreeclass as pytc
    import jax
    import jax.numpy as jnp
    
    def assert_array(x):
        assert isinstance(x, jax.Array)
        return x
    
    @pytc.autoinit
    class Tree(pytc.TreeClass):
        buffer: jax.Array = pytc.field(on_getattr=[jax.lax.stop_gradient], on_setattr=[assert_array])
    
        def __call__(self, x):
            return self.buffer**x
        
    tree = Tree(buffer=jnp.array([1.0, 2.0, 3.0]))
    tree(2.0)  # Array([1., 4., 9.], dtype=float32)
    @jax.jit
    def f(tree, x):
        return jnp.sum(tree(x))
    
    f(tree, 1.0)  # Array(6., dtype=float32)
    print(jax.grad(f)(tree, 1.0))  # Tree(buffer=[0. 0. 0.])
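To make the freeze/unfreeze idea from item 2 concrete without JAX, here is a toy sketch. It assumes, as the item 2 example implies, that a frozen value is wrapped in something the flattener treats as having no leaves. Frozen, freeze, unfreeze, leaves and ToyTree are hypothetical stand-ins, not pytc.freeze/pytc.unfreeze, and leaves only imitates how a flattener such as jax.tree_util.tree_leaves would skip such a wrapper.

```python
from typing import Any


class Frozen:
    """Hypothetical wrapper marking a value as hidden from the flattener."""

    __slots__ = ("value",)

    def __init__(self, value: Any):
        self.value = value


def freeze(x: Any) -> Any:
    return x if isinstance(x, Frozen) else Frozen(x)  # idempotent


def unfreeze(x: Any) -> Any:
    return x.value if isinstance(x, Frozen) else x


def leaves(tree: Any) -> list:
    """Toy stand-in for jax.tree_util.tree_leaves over dicts, lists, tuples."""
    if isinstance(tree, Frozen):
        return []  # a frozen value contributes no leaves
    if isinstance(tree, dict):
        return [leaf for v in tree.values() for leaf in leaves(v)]
    if isinstance(tree, (list, tuple)):
        return [leaf for v in tree for leaf in leaves(v)]
    return [tree]


class ToyTree:
    """Hypothetical container: fields named frozen_* are frozen on set and
    unfrozen on get, mirroring on_setattr=[freeze], on_getattr=[unfreeze]."""

    def __init__(self, **fields: Any):
        self.fields = {
            k: freeze(v) if k.startswith("frozen_") else v for k, v in fields.items()
        }

    def __getattr__(self, name: str) -> Any:
        # only reached when normal lookup fails, i.e. for field names
        fields = self.__dict__.get("fields", {})
        if name in fields:
            return unfreeze(fields[name])
        raise AttributeError(name)


tree = ToyTree(frozen_a=1, b=2.0)
print(leaves(tree.fields))  # [2.0]: the "transformation" only sees b
print(tree.frozen_a + 1.0)  # 2.0: attribute access sees the raw value
```

The design point is the same as in item 2: the stored form is opaque to whatever walks the tree, while attribute access hands user code the plain value, so the flatten/unflatten logic itself never needs to change.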