Frozen / Static leaves #67
TLDR; as of version 0.8.0:

```python
import pytreeclass as pytc
import jax

@pytc.autoinit
class Tree(pytc.TreeClass):
    frozen_a: int = pytc.field(on_getattr=[pytc.unfreeze], on_setattr=[pytc.freeze])

    def __call__(self, x):
        return self.frozen_a + x

tree = Tree(frozen_a=1)  # 1 is a non-jaxtype

# can be used in jax transformations
@jax.jit
def f(tree, x):
    return tree(x)

print(f(tree, 1.0))  # 2.0
print(jax.grad(f)(tree, 1.0))  # Tree(frozen_a=#1)
print(jax.tree_util.tree_leaves(tree))  # []
```

More details about the freezing/unfreezing mechanism are below. If you prefer manual masking, using that style the end user only has to unmask before calling, while still having access to the masked values. You can do something like this:

Style 1: with no init body

```python
import pytreeclass as pytc

class Tree(pytc.TreeClass):
    in_features: int = pytc.field(callbacks=[pytc.freeze])
```

Style 2: with init body

```python
class Tree(pytc.TreeClass):
    def __init__(self, in_features: int):
        # Some logic using in_features
        # ...
        # Lastly, freeze it
        self.in_features = pytc.freeze(in_features)

    def __call__(self, x: float):
        return x * self.in_features

t1 = Tree(2)

@jax.value_and_grad
def jax_func(tree: Tree):
    tree = jax.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)
    return tree(1.0)

jax_func(t1)
# (2.0, Tree(in_features=#2))  # `#` is the frozen marker
```

For background, an earlier version of pytreeclass had static-field-like behaviour, but this has three problems. First, even if these fields are constants, each one has to be declared with a static-field marker, for example:

```python
class Conv(Module):
    """General N-dimensional convolution."""

    num_spatial_dims: int = static_field()
    weight: Array
    bias: Optional[Array]
    in_channels: int = static_field()
    out_channels: int = static_field()
    kernel_size: Tuple[int, ...] = static_field()
    stride: Tuple[int, ...] = static_field()
    padding: Tuple[Tuple[int, int], ...] = static_field()
    dilation: Tuple[int, ...] = static_field()
    groups: int = static_field()
    use_bias: bool = static_field()

    def __init__(
        self,
        num_spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[int, Sequence[int]],
        stride: Union[int, Sequence[int]] = 1,
        padding: Union[int, Sequence[int], Sequence[Tuple[int, int]]] = 0,
        dilation: Union[int, Sequence[int]] = 1,
        groups: int = 1,
        use_bias: bool = True,
        *,
        key: PRNGKey,
        **kwargs,
    ):
        ...
```

This gets worse as you write more and more code. Lastly, the current design eliminates static-field logic during the flattening/unflattening of a tree, leading to faster flattening/unflattening for non-masked trees and simplifying the code.
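To make the freezing mechanism concrete, here is a minimal, library-free sketch of the idea (the `Frozen`, `toy_flatten`, and `toy_unflatten` names are hypothetical illustrations, not the pytreeclass API): wrapping a value moves it out of the dynamic leaves and into the static part of the tree definition, so transformations never see it.

```python
from dataclasses import dataclass
from typing import Any

@dataclass(frozen=True)  # frozen dataclass => hashable, usable as static metadata
class Frozen:
    value: Any

def toy_flatten(tree: dict):
    """Split a flat dict 'pytree' into dynamic leaves and static metadata."""
    leaves, static = [], {}
    for key, val in tree.items():
        if isinstance(val, Frozen):
            static[key] = val          # frozen wrapper travels with the treedef
        else:
            leaves.append((key, val))  # ordinary leaf, visible to transformations
    return leaves, static

def toy_unflatten(leaves, static):
    """Rebuild the dict, unwrapping frozen values on reconstruction."""
    tree = {k: v.value for k, v in static.items()}
    tree.update(dict(leaves))
    return tree

leaves, static = toy_flatten({"a": Frozen(1), "b": 2.0})
print([v for _, v in leaves])          # [2.0] -- the frozen leaf is hidden
print(toy_unflatten(leaves, static))   # {'a': 1, 'b': 2.0}
```

This mirrors why `jax.tree_util.tree_leaves(tree)` returned `[]` in the TLDR example: a fully frozen tree has no dynamic leaves left.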
So if I am understanding this correctly, the unfreezing step needs to be called prior to any call to a pytc.TreeClass containing static/frozen variables. So if I have a class within a class within a class, all containing frozen variables (or a class that contains numerous other classes which use frozen variables), every call to the methods of that class must be preceded by this unfreezing. That doesn't seem ideal once you get to much more complicated models, or when you wish to build a library of functions, as every class would need to be "wrapped" to hide this from the user.
For a deeply nested instance with frozen attributes all over the place, you only need to write it once (usually inside your loss function), something like this:

```python
import pytreeclass as pytc
import jax

class A(pytc.TreeClass):
    a: int = pytc.freeze(1)
    b: float = 2.0

    def __call__(self, x):
        return self.a * x + self.b

class B(pytc.TreeClass):
    c: int = pytc.freeze(1)
    d: A = A()

    def __call__(self, x):
        return self.c * x + self.d(x)

b = B()
# B(c=#1, d=A(a=#1, b=2.0))

@jax.jit
@jax.value_and_grad
def loss_func(b: B):
    b = jax.tree_map(pytc.unfreeze, b, is_leaf=pytc.is_frozen)
    return b(1.0)

loss_func(b)
# (Array(4., dtype=float32, weak_type=True),
#  B(c=#1, d=A(a=#1, b=f32[](μ=1.00, σ=0.00, ∈[1.00,1.00]))))
```

For comparison, here is a rough benchmark against equinox:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
import pytreeclass as pytc

class TreeEqx(eqx.Module):
    a: int = eqx.static_field(default=1)
    b: jax.Array = jnp.array(1.)

class TreePyTC(pytc.TreeClass):
    a: int = pytc.freeze(1)
    b: jax.Array = jnp.array(1.)

tree = TreePyTC()

@jax.jit
def some_func(t):
    t = jax.tree_map(pytc.unfreeze, t, is_leaf=pytc.is_frozen)
    return t.a + t.b

%timeit some_func(tree)
# 12.1 µs ± 836 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

tree = TreeEqx()

@eqx.filter_jit
def some_func(t):
    return t.a + t.b

%timeit some_func(tree)
# 26.7 µs ± 5.86 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)
```

Let me know if you have any questions.
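A rough intuition for the difference in the two call patterns above, as a stdlib-only toy sketch (the `partition`, `Frozen`, and `unfreeze` helpers here are hypothetical illustrations, not either library's real implementation): the filter style inspects every leaf at call time to decide what is traceable, whereas the freeze style pre-tags leaves by wrapping them, so the call site only has to unwrap.

```python
# Filter style: partition at call time by inspecting each leaf
def partition(tree: dict, is_dynamic):
    dynamic = {k: v for k, v in tree.items() if is_dynamic(v)}
    static = {k: v for k, v in tree.items() if not is_dynamic(v)}
    return dynamic, static

# Freeze style: leaves are pre-tagged by wrapping; "unfreeze" just unwraps
class Frozen:
    def __init__(self, value):
        self.value = value

def unfreeze(v):
    return v.value if isinstance(v, Frozen) else v

tree_filter_style = {"a": 1, "b": 1.0}
tree_freeze_style = {"a": Frozen(1), "b": 1.0}

dyn, stat = partition(tree_filter_style, lambda v: isinstance(v, float))
unfrozen = {k: unfreeze(v) for k, v in tree_freeze_style.items()}
print(dyn, stat)   # {'b': 1.0} {'a': 1}
print(unfrozen)    # {'a': 1, 'b': 1.0}
```

Both strategies end up hiding `a` from transformations; they differ in where the decision is recorded (a predicate at the call site versus a marker on the value itself).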
Ah OK, I understand. That is obviously more manageable. I wonder if adding a wrapper function/decorator to hide this from users might be useful? The equinox decorated functions are very useful in this respect of hiding the complexity away. I can see cases where somebody wants to use your model and try a different loss function, or incorporate your model/NN into a pipeline with others, and doesn't realise this behaviour. The ability to wrap your model such that another user doesn't even need to think about this would be very useful.
You are right; fortunately, it's easy to do just that:

```python
import functools as ft

def unfreeze_func(func):
    @ft.wraps(func)
    def wrapper(tree, *a, **k):
        tree = jax.tree_map(pytc.unfreeze, tree, is_leaf=pytc.is_frozen)
        return func(tree, *a, **k)
    return wrapper

@jax.jit
@jax.value_and_grad
@unfreeze_func
def loss_func(b: B):
    return b(1.0)
```
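The same decorator pattern works for any "normalize the first argument before calling" situation. A self-contained, stdlib-only sketch of the idea (the toy `Frozen` wrapper and `unfreeze_leaves` helper are hypothetical stand-ins, not pytreeclass functions):

```python
import functools as ft

class Frozen:
    """Toy stand-in for a frozen-leaf wrapper."""
    def __init__(self, value):
        self.value = value

def unfreeze_leaves(tree: dict) -> dict:
    # unwrap any Frozen values before the wrapped function sees the tree
    return {k: (v.value if isinstance(v, Frozen) else v) for k, v in tree.items()}

def unfreeze_func(func):
    @ft.wraps(func)
    def wrapper(tree, *a, **k):
        return func(unfreeze_leaves(tree), *a, **k)
    return wrapper

@unfreeze_func
def loss(tree, x):
    return tree["a"] * x + tree["b"]

print(loss({"a": Frozen(2), "b": 1.0}, 3.0))  # 7.0
```

The decorator composes cleanly with other transformations because `ft.wraps` preserves the wrapped function's name and docstring, and the caller never has to know the tree contained frozen values.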
Reading the documentation, I understand that you can freeze variables by using a mask based on name or type.
Is it possible to set a variable to "frozen" within the class definition, i.e. in the way Equinox has its static_field option?
While I understand the concept behind being able to mask out a set of variables contained in a PyTree (or a PyTree of PyTrees), there are lots of situations where you know, when creating a new class, that certain variables will only ever be constant. Furthermore, as models become much more complicated (or if others use elements of your model), it becomes more cumbersome to have to mask these out, and others have to know to do this.
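The name- or type-based masking mentioned above can be sketched library-free (the `tree_mask_by_name`/`tree_mask_by_type` helpers and the `Frozen` wrapper here are hypothetical illustrations, not the pytreeclass API): masking just wraps the selected entries so a later flatten step can treat them as static.

```python
class Frozen:
    """Toy frozen-leaf wrapper (stand-in for a real frozen marker)."""
    def __init__(self, value):
        self.value = value

def tree_mask_by_name(tree: dict, names: set) -> dict:
    # wrap entries whose key is in `names`
    return {k: (Frozen(v) if k in names else v) for k, v in tree.items()}

def tree_mask_by_type(tree: dict, types: tuple) -> dict:
    # wrap entries whose value matches one of `types`
    return {k: (Frozen(v) if isinstance(v, types) else v) for k, v in tree.items()}

masked = tree_mask_by_name({"in_features": 2, "weight": 0.5}, {"in_features"})
print(isinstance(masked["in_features"], Frozen))  # True
print(masked["weight"])                           # 0.5
```

Declaring the freeze in the class definition (Style 1 and Style 2 earlier in this thread) records the same decision once, at definition time, instead of at every masking site.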