Frozen / Static leaves #67

Open
adam-hartshorne opened this issue May 20, 2023 · 5 comments

@adam-hartshorne

adam-hartshorne commented May 20, 2023

Reading the documentation, I understand that you can freeze variables by using a mask based upon name or type.

Is it possible to mark a variable as "frozen" within the class definition, i.e. in the way Equinox has the static_field option?

While I understand the concept behind being able to mask out a set of variables contained in a PyTree (or PyTree of PyTrees), there are lots of situations where you know when creating a new class, that certain variables will only ever be constant. Furthermore, as models become much more complicated (or if others may utilise elements of your model) it becomes more cumbersome to have to mask these out / others have to know to do this.

@ASEM000 ASEM000 added documentation Improvements or additions to documentation question Further information is requested labels May 20, 2023
@ASEM000 ASEM000 self-assigned this May 20, 2023
@ASEM000
Owner

ASEM000 commented May 20, 2023

TL;DR: as of version 0.8.0, use:

import pytreeclass as pytc
import jax


@pytc.autoinit
class Tree(pytc.TreeClass):
    frozen_a: int = pytc.field(on_getattr=[pytc.unfreeze], on_setattr=[pytc.freeze])

    def __call__(self, x):
        return self.frozen_a + x


tree = Tree(frozen_a=1)  # 1 is non-jaxtype
# can be used in jax transformations


@jax.jit
def f(tree, x):
    return tree(x)


print(f(tree, 1.0))  # 2.0
print(jax.grad(f)(tree, 1.0))  # Tree(frozen_a=#1)
print(jax.tree_util.tree_leaves(tree))  # []

More details about the freezing/unfreezing mechanism:

If you prefer manual masking, you can apply pytc.freeze to the value directly. In that case you must pass is_leaf=pytc.is_frozen whenever you want to interact with this value through tree_map.

With this style, the end user only has to unmask before calling, and still has access to the masked values via is_leaf=pytc.is_frozen.

You can do something like this:

Style 1: with no init body. Here, callbacks is a list of functions applied to in_features before it is set on the instance.

import pytreeclass as pytc

class Tree(pytc.TreeClass):
    in_features: int = pytc.field(callbacks=[pytc.freeze])

Style 2: with init body

class Tree(pytc.TreeClass):
    def __init__(self, in_features: int):
        # Some logic using in_features
        # ...

        # Lastly you freeze it
        self.in_features = pytc.freeze(in_features)

    def __call__(self, x:float):
        return x * self.in_features

t1 = Tree(2)

@jax.value_and_grad
def jax_func(tree:Tree):
    tree = jax.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)
    return tree(1.0)

jax_func(t1)
# (2.0, Tree(in_features=#2)) # ->`#` is frozen marker

For background, an earlier version of pytreeclass had static field-like behaviour, but this has three problems:

1. Even when these fields are constants, marking them with static_field means you can no longer filter your model on that always-non-trainable field using jax.tree_map.
2. .at uses jax.tree_map under the hood. If I let the user designate a permanently static field, the design becomes asymmetric. For example, if a is a static field of model nn, then nn.a works while nn.at['a'].get() does not work at all.
3. static_field leads to repetitive code, because each such field has to be declared twice: once as a field and again in the init. For example (from the Equinox Conv code):

class Conv(Module):
    """General N-dimensional convolution."""

    num_spatial_dims: int = static_field()
    weight: Array
    bias: Optional[Array]
    in_channels: int = static_field()
    out_channels: int = static_field()
    kernel_size: Tuple[int, ...] = static_field()
    stride: Tuple[int, ...] = static_field()
    padding: Tuple[Tuple[int, int], ...] = static_field()
    dilation: Tuple[int, ...] = static_field()
    groups: int = static_field()
    use_bias: bool = static_field()

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[int, Sequence[int], Sequence[Tuple[int, int]]] = 0,
        dilation: Union[int, Sequence[int]] = 1,
        groups: int = 1,
        use_bias: bool = True,
        *,
        key: PRNGKey,
        **kwargs,
    ):

This gets worse as you write more and more code.

Lastly, pytc.freeze simply wraps a value in a pytree whose flattening rule yields no leaves, so you can use pytc.freeze on any pytree (there is no special treatment inside a TreeClass).

This design removes static-field logic from the flattening/unflattening of a tree, which makes flattening/unflattening faster for non-masked trees and simplifies the code.
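
To make the "pytree with no leaves" idea concrete, here is a minimal sketch using only jax.tree_util. The names Frozen, is_frozen and unfreeze below are hypothetical stand-ins written for illustration; this is not the pytreeclass implementation, only one way such a wrapper could behave.

```python
import jax
import jax.tree_util as jtu


@jtu.register_pytree_node_class
class Frozen:
    """Hypothetical wrapper (not pytreeclass code): flattens to zero leaves."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        # No children, so JAX has nothing to trace or differentiate here.
        # The wrapped value travels as static aux data (it should be hashable).
        return (), self.value

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(aux)


def is_frozen(x):
    return isinstance(x, Frozen)


def unfreeze(x):
    # Unwrap frozen values, pass everything else through unchanged.
    return x.value if is_frozen(x) else x


params = {"scale": Frozen(2), "bias": 0.5}

# Without is_leaf, the frozen value is invisible to tree utilities.
visible = jtu.tree_leaves(params)
# With is_leaf, the wrapper itself is reported as a leaf and can be filtered on.
with_frozen = jtu.tree_leaves(params, is_leaf=is_frozen)
# Unmasking: replace each wrapper by the value it holds.
thawed = jtu.tree_map(unfreeze, params, is_leaf=is_frozen)


def loss(p):
    p = jtu.tree_map(unfreeze, p, is_leaf=is_frozen)
    return p["scale"] * p["bias"]


# Only "bias" is differentiated; "scale" comes back still wrapped.
value, grads = jax.value_and_grad(loss)(params)
```

In this sketch the frozen entry never reaches the gradient computation, yet it stays reachable through is_leaf, which is the symmetry argued for above.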
Let me know if this answers your question.

@adam-hartshorne
Author

adam-hartshorne commented May 22, 2023

So if I am understanding this correctly,

tree = jax.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)

needs to be called before any call to a pytc.TreeClass containing static / frozen variables. So if I have a class within a class within a class, all containing frozen variables (or a class that contains numerous other classes which use frozen variables), every call to a method of that class must do this unfreezing first.

That doesn't seem ideal, when you start to get much more complicated models or wish to build a library of functions (as every class would need to be "wrapped" to hide this from the user).

@ASEM000
Owner

ASEM000 commented May 22, 2023

For a deeply nested instance with frozen attributes all over the place, you only need to write it once (usually inside your loss function), something like this:

from typing import Any
import pytreeclass as pytc
import jax


class A(pytc.TreeClass):
    a: int = pytc.freeze(1)
    b: float = 2.0

    def __call__(self, x):
        return self.a * x + self.b


class B(pytc.TreeClass):
    c: int = pytc.freeze(1)
    d: A = A()

    def __call__(self, x):
        return self.c * x + self.d(x)

b = B()
# B(c=#1, d=A(a=#1, b=2))

@jax.jit
@jax.value_and_grad
def loss_func(b: B):
    b = jax.tree_map(pytc.unfreeze, b, is_leaf=pytc.is_frozen)
    return b(1.0)


loss_func(b)
# (Array(4., dtype=float32, weak_type=True),
#  B(c=#1, d=A(a=#1, b=f32[](μ=1.00, σ=0.00, ∈[1.00,1.00]))))

For comparison, Equinox's filter-decorated functions do something similar under the hood, in two steps: first Equinox splits the tree into trainable and non-trainable parts before the JAX boundary, then it combines them again inside the JAX function on each call. The pytreeclass scheme should be faster because it takes only one step.

import equinox as eqx 
import jax 
import pytreeclass as pytc
import jax.numpy as jnp

class TreeEqx(eqx.Module):
    a: int = eqx.static_field(default=1)
    b: jax.Array = jnp.array(1.)

class TreePyTC(pytc.TreeClass):
    a: int = pytc.freeze(1)
    b: jax.Array = jnp.array(1.)

tree = TreePyTC()

@jax.jit
def some_func(t):
    t = jax.tree_map(pytc.unfreeze, t, is_leaf=pytc.is_frozen)
    return t.a + t.b

%timeit some_func(tree)
# 12.1 µs ± 836 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

tree = TreeEqx()

@eqx.filter_jit
def some_func(t):
    return t.a + t.b

%timeit some_func(tree)
# 26.7 µs ± 5.86 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
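
The two-step split/merge scheme described above can be sketched with jax.tree_util alone. The helpers partition, combine and filtered_call are hypothetical names for this illustration; they are a simplified sketch of the idea, not Equinox's actual code, which handles many more cases.

```python
import jax
import jax.tree_util as jtu


def is_none(x):
    return x is None


def partition(tree, is_dynamic):
    # Step 1 (outside jit): split the tree into two halves of the same shape,
    # using None as a placeholder where a leaf belongs to the other half.
    dynamic = jtu.tree_map(lambda x: x if is_dynamic(x) else None, tree)
    static = jtu.tree_map(lambda x: None if is_dynamic(x) else x, tree)
    return dynamic, static


def combine(dynamic, static):
    # Step 2 (inside jit): merge the halves, taking whichever side is not None.
    return jtu.tree_map(
        lambda d, s: s if d is None else d, dynamic, static, is_leaf=is_none
    )


def filtered_call(fn, tree, is_dynamic):
    dynamic, static = partition(tree, is_dynamic)

    # Simplification: a fresh jitted function is built on every call, so this
    # sketch shows the data flow only and does not reuse compiled code.
    @jax.jit
    def inner(dyn):
        return fn(combine(dyn, static))

    return inner(dynamic)


tree = {"a": 1, "b": 2.0}
is_float = lambda x: isinstance(x, float)

dynamic, static = partition(tree, is_float)
result = filtered_call(lambda t: t["a"] + t["b"], tree, is_float)
```

Here only the float leaf crosses the jit boundary as a traced value, while the integer is closed over as a constant and merged back in on the inside.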

Let me know if you have any questions.

@adam-hartshorne
Author

Ah, OK, I understand. That is obviously more manageable.

I wonder if adding a wrapper function/decorator to hide this from users might be useful? The equinox decorated functions are very useful in this respect of hiding the complexity away.

I can see occurrences where somebody might want to use your model and try a different loss function, or incorporate your model / NN into a pipeline of others and they don't realise this behaviour. The ability to wrap your model such that another user doesn't even need to think about this jax.tree_map(pytc.unfreeze, b, is_leaf=pytc.is_frozen) might prove helpful in stopping obvious mistakes.

@ASEM000
Owner

ASEM000 commented May 22, 2023

You are right; fortunately, it's easy to do just that.

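import functools as ft


def unfreeze_func(func):
    # Unfreeze the first argument (the tree) before calling the wrapped function.
    @ft.wraps(func)
    def wrapper(tree, *a, **k):
        tree = jax.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)
        return func(tree, *a, **k)

    return wrapper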


@jax.jit
@jax.value_and_grad
@unfreeze_func
def loss_func(b: B):
    return b(1.0)
