Path API #26
Hey, I really appreciate you reaching out! Yeah, I remember chatting over email a while ago about this stuff when both PyTreeClass and Zodiax were in early development. Apologies for taking so long to get to this; I have had too many things on my plate lately and have wanted to engage with this properly. This is really cool and a bit 🪄 magic 🪄 to me right now; it seems like you've solved the …

I'm curious as to how this would extend to nested classes, where pytree leaves are other pytrees, dictionaries, etc. My guess looking at this example is that you would need to nest the path, something like:

```python
model = model.set('telescope.optics.aberrations.coefficients', new_coefficients)
```

or, to operate on multiple specific leaves at the same time:

```python
model = model.multiply(['leaf1', 'leaf2'], 2)
```

I'd love your thoughts on operations like this within the PyTreeClass framework!
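For illustration, the dotted-path update sketched above can be mimicked in plain Python. This is a hypothetical helper (`set_by_path` and the toy classes are made up for this sketch, not the PyTreeClass or Zodiax API), showing an immutable-style update through nested attributes:

```python
from copy import deepcopy

def set_by_path(obj, path, value):
    # Walk a dotted attribute path such as "optics.coefficients" and
    # return a deep copy of `obj` with only that attribute replaced,
    # leaving the original object untouched (immutable-style update).
    new = deepcopy(obj)
    *parents, last = path.split(".")
    target = new
    for name in parents:
        target = getattr(target, name)
    setattr(target, last, value)
    return new

class Optics:
    def __init__(self):
        self.coefficients = [1.0, 2.0]

class Model:
    def __init__(self):
        self.optics = Optics()

model = Model()
new_model = set_by_path(model, "optics.coefficients", [3.0, 4.0])
print(new_model.optics.coefficients)  # [3.0, 4.0]
print(model.optics.coefficients)      # [1.0, 2.0] (original unchanged)
```

A real path API would also need to handle dict keys and sequence indices, but the attribute-walking idea is the same.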
I also just had a look at your benchmarks against Equinox. I'm very curious as to why PyTreeClass outperforms Equinox and the others! I would have thought the underlying computations would be virtually the same once compiled to XLA.
You can use a tuple of keys:

```python
tree = {"a": {"b": 1, "c": 2}, "d": 3}
where = ("a", "b")
print(pytc.AtIndexer(tree, where).get())
# {'a': {'b': 1, 'c': None}, 'd': None}
```

Here each tuple item represents a matching key for each level, so `("a", "b")` means match `tree["a"]["b"]`.

You can do this by using a tuple of tuples:

```python
tree = {"a": {"b": 1}, "c": 2, "d": 3}
where = (("a", "c"),)
print(pytc.AtIndexer(tree, where).get())
# {'a': {'b': 1}, 'c': 2, 'd': None}
```

Here the first item in the `where` tuple represents the first level, and the two-item tuple means match any of them. For more, check here and here to extend it for custom matching strategies. Note that, because I use …
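To make the matching semantics concrete, here is a minimal plain-Python sketch of the same tuple-path lookup over nested dicts (`path_get` and `_none_like` are hypothetical names invented for this sketch; the real `AtIndexer` is more general and works on arbitrary pytrees):

```python
def _none_like(v):
    # replace every leaf under v with None, keeping the dict structure
    return {k: _none_like(u) for k, u in v.items()} if isinstance(v, dict) else None

def path_get(tree, where, depth=0):
    # each item of `where` matches keys at successive depths;
    # a tuple item matches any of its keys at that level
    if depth == len(where):
        return tree
    keys = where[depth] if isinstance(where[depth], tuple) else (where[depth],)
    out = {}
    for k, v in tree.items():
        if k in keys:
            out[k] = path_get(v, where, depth + 1) if isinstance(v, dict) else v
        else:
            out[k] = _none_like(v)
    return out

print(path_get({"a": {"b": 1, "c": 2}, "d": 3}, ("a", "b")))
# {'a': {'b': 1, 'c': None}, 'd': None}
print(path_get({"a": {"b": 1}, "c": 2, "d": 3}, (("a", "c"),)))
# {'a': {'b': 1}, 'c': 2, 'd': None}
```

Unmatched subtrees are filled with `None` rather than dropped, which keeps the result's structure identical to the input, the same behaviour as the `AtIndexer` examples above.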
For example, consider the example for a pytree of … The following is the complete benchmark code I used and the results on Colab CPU. Feel free to let me know if you have any questions.

```python
import jax
import jax.numpy as jnp
import pytreeclass as pytc
import equinox as eqx
import optax


class PyTCLinear(pytc.TreeClass):
    def __init__(self, in_dim: int, out_dim: int, key: jax.random.KeyArray, name: str):
        self.name = name
        self.weight = jax.random.normal(key, (in_dim, out_dim))
        self.bias = jax.numpy.array(0.0)

    def __call__(self, x: jax.Array):
        return x @ self.weight + self.bias


class EqxLinear(eqx.Module):
    name: str
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_dim: int, out_dim: int, key: jax.random.KeyArray, name: str):
        self.name = name
        self.weight = jax.random.normal(key, (in_dim, out_dim))
        self.bias = jax.numpy.array(0.0)

    def __call__(self, x: jax.Array):
        return x @ self.weight + self.bias


def sequential_linears(layers, x):
    *layers, last = layers
    for layer in layers:
        x = layer(x)
        x = jax.nn.relu(x)
    return last(x)


x = jnp.linspace(100, 1)[:, None]
y = x**2
key = jax.random.PRNGKey(0)
optim = optax.adam(1e-3)


@jax.value_and_grad
def pytc_loss_func(layers, x, y):
    layers = pytc.tree_unmask(layers)
    y = sequential_linears(layers, x)
    return jnp.mean((x - y) ** 2)


@jax.jit
def pytc_train_step(layers, optim_state, x, y):
    loss, grads = pytc_loss_func(layers, x, y)
    updates, optim_state = optim.update(grads, optim_state)
    layers = optax.apply_updates(layers, updates)
    return layers, optim_state, loss


@eqx.filter_value_and_grad
def eqx_loss_func(layers, x, y):
    y = sequential_linears(layers, x)
    return jnp.mean((x - y) ** 2)


@eqx.filter_jit
def eqx_train_step(layers, optim_state, x, y):
    loss, grads = eqx_loss_func(layers, x, y)
    updates, optim_state = optim.update(grads, optim_state)
    layers = eqx.apply_updates(layers, updates)
    return layers, optim_state, loss


def pytc_train(layers, optim_state, x, y, epochs=100):
    for _ in range(epochs):
        layers, optim_state, loss = pytc_train_step(layers, optim_state, x, y)
    return layers, loss


def eqx_train(layers, optim_state, x, y, epochs=100):
    for _ in range(epochs):
        layers, optim_state, loss = eqx_train_step(layers, optim_state, x, y)
    return layers, loss


for linear_count in [10, 100]:
    pytc_linears = [PyTCLinear(1, 1, key=jax.random.PRNGKey(i), name=f"linear_{i}") for i in range(linear_count)]
    # mask non-differentiable parameters
    pytc_linears = pytc.tree_mask(pytc_linears)
    pytc_optim_state = optim.init(pytc_linears)

    eqx_linears = [EqxLinear(1, 1, key=jax.random.PRNGKey(i), name=f"linear_{i}") for i in range(linear_count)]
    eqx_optim_state = optim.init(eqx.filter(eqx_linears, eqx.is_array))

    pytc_linears, pytc_loss = pytc_train(pytc_linears, pytc_optim_state, x, y, epochs=1000)
    eqx_linears, eqx_loss = eqx_train(eqx_linears, eqx_optim_state, x, y, epochs=1000)
    assert pytc_loss == eqx_loss

    # IPython timing magics
    time_pytc = %timeit -o pytc_train(pytc_linears, pytc_optim_state, x, y, epochs=100)
    time_eqx = %timeit -o eqx_train(eqx_linears, eqx_optim_state, x, y, epochs=100)
    print(f"Eqx/PyTc: {time_eqx.average/time_pytc.average} for {linear_count} layers")
```
Thanks for the quick response! Yeah, it looks like the … Interesting, thanks for this example! So if I've understood correctly, you've made accessing attributes require fewer tree operations, resulting in speed gains not from the underlying computation but from more efficient operations on the pytree itself? Would this result in a generic speed increase for updating pytrees generally? This has been a point in the back of my mind for a while, since in my work it's often necessary to update single values in large pytrees, resulting in slow-downs from loading the pytree into a new instance. I've always thought my looping/nesting implementation of accessing leaves could be improved.

Another thing we do that breaks from the more general ML use cases is taking gradients with respect to either a single, or a small subset of, parameters (since we are modelling physical objects, we know that things like the diameter of our telescope aren't going to change 😝). This used to be a default(ish) behaviour of Equinox's …

I also see that you've implemented an equivalent of …

One of the primary goals of Zodiax is to make working with the objects require as little knowledge as possible. As my target is a science audience, most of the users have little to no programming background and are mostly self-taught (no wonder most scientific code has quirks, mine included 🙃). That was a large motivator for Zodiax: to hide users from needing to learn about …

Thanks again for these responses, and sorry for hammering you with so many questions; it seems like we're thinking about the same stuff!
No, …

If you mean the speed-up over Equinox, then this is purely due to the design of … see here.

There is a meaningful difference between the two for your use case: …

I think this is an XY problem. Can you give me an MWE of your desired functionality?
Ah yes, I suspected that using the JAX … Yeah, it seems like overall …

Does this mean that JAX will understand these array leaves as 'static' under jit transformations? They obviously can't be treated as actually static since they aren't hashable, but does this lead to a meaningful difference under the hood? We have always thought that speeds might be improved by preventing JAX from tracing these arrays that can be considered 'static'. What are your thoughts on a filtered-transformation equivalent for …?
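As a quick illustration of the hashability point raised here: `jax.jit` only accepts hashable values for static arguments, so an array leaf cannot literally be marked static. A minimal sketch using only core JAX (the function names are invented for the example):

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)
def scale(x, c):
    # `c` is static: its value is baked into the compiled program,
    # and a new compilation is triggered for each distinct value
    return x * c

print(scale(jnp.array(3.0), 2.0))  # 6.0 -- Python floats are hashable

try:
    # arrays are unhashable, so they cannot be static arguments
    scale(jnp.array(3.0), jnp.array(2.0))
except Exception as e:
    print("rejected:", type(e).__name__)
```

This is why "static" array leaves have to be handled some other way, e.g. by wrapping them in a hashable container, rather than via `static_argnums`.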
Yes, it more or less is. This is not necessarily needed functionality, but more an idea I've been toying with as a way to define and update some units parameter of leaves or pytrees. Here's an MWE:

```python
from jax import Array
import pytreeclass as pytc


class Foo(pytc.TreeClass):
    leaf: Array
    unit: str = pytc.static_leaf()  # something similar to this

    def __call__(self, new_unit: str):
        new_value = convert(self.leaf, self.unit, new_unit)  # some simple unit conversion function
        return self.at[('leaf', 'unit')].set((new_value, new_unit))
```

The idea here is that since the leaves are already treated as 'static', you avoid the need to mask/unmask in functions. The other option would be a filtered-transform type operation:

```python
from jax import Array
import pytreeclass as pytc


class Foo(pytc.TreeClass):
    leaf: Array
    unit: str

    def __call__(self, new_unit: str):
        new_value = convert(self.leaf, self.unit, new_unit)  # some simple unit conversion function
        return self.at[('leaf', 'unit')].set((new_value, new_unit))


param = 'leaf'

@filter_jit(param)
@filter_grad(param)
def f(pytree):
    return pytree('nm').leaf**2
```

The idea here is that the filtered transformation automatically masks all leaves except for the specified parameters. Generally, the goal is to have 'mask-by-path' behavior, as this is more in line with the goals of scientific models.
It's easier with `pytc.tree_mask`:

```python
import pytreeclass as pytc
import jax.numpy as jnp
import jax


class Tree(pytc.TreeClass):
    def __init__(self):
        self.weight = jnp.array([1.0])
        self.bias = jnp.array([1.0])
        self.name = "tree"  # non-jax type


tree = Tree()

# mask all leaves except weight
mask = tree.at[...].set(True).at["weight"].set(False)
masked_tree = pytc.tree_mask(tree, mask)


@jax.jit
@jax.grad
def loss_func(tree, x: jax.Array):
    tree = pytc.tree_unmask(tree)
    return jnp.mean(tree.weight * x**2)


grads = loss_func(masked_tree, jnp.array([2.0]))
# Tree(weight=[4.], bias=#[1.], name=#tree)

new_tree = jax.tree_map(lambda x, g: x + g, masked_tree, grads)
new_tree = pytc.tree_unmask(new_tree)
# Tree(weight=[5.], bias=[1.], name=tree)
```
Yes, …

It's harder using the Equinox style because you have to 1) filter pytrees to work with libraries like optax, and then 2) filter over functions. This is a common usability issue in Equinox. In contrast, using the above code sample, the …
If you want to take on the burden of defining static nodes inside the class definition, then the following pattern works:

```python
import jax
import pytreeclass as pytc


# changing a "frozen" leaf value without tree_mask/tree_unmask on the user side
class Foo(pytc.TreeClass):
    def __init__(self, leaf: jax.Array, unit: str):
        self.leaf = leaf
        self.unit = pytc.freeze(unit)  # non-jax type

    def change_unit(self, new_unit: str) -> "Foo":
        return self.at["unit"].set(pytc.freeze(new_unit), is_leaf=pytc.is_frozen)


foo = Foo(jax.numpy.array([1.0]), "m")
# Foo(leaf=f32[1](μ=1.00, σ=0.00, ∈[1.00,1.00]), unit=#m)

new_foo = foo.change_unit("cm")
# Foo(leaf=f32[1](μ=1.00, σ=0.00, ∈[1.00,1.00]), unit=#cm)
```

Your MWE is a strong case for using pytreeclass over equinox, because in equinox if you define … Please let me know what you think.
Okay, so I've had some time to go through the PyTreeClass docs and... wow, you have implemented some seriously awesome functionality! Automatic data type and shape checking, …

I can envision a better version of Zodiax that circumvents the need for wrapping specific Equinox functions, using a single default …:

```python
params = ['param1', 'param2']

@jax.jit
@jax.grad
@zdx.filter(params)
def loss(pytree):
    return pytree()**2
```

Ideally this would result in gradients only with respect to the specified parameters.

```python
import jax
import jax.numpy as np
from functools import wraps
import pytreeclass as pytc
import equinox as eqx
import zodiax as zdx


# Functional wrapper to build the mask using __getattr__, not super important right now
def set(pytree, parameters, values):
    if values is None:
        values = [None]
    if isinstance(parameters, str):
        parameters = [parameters]
    new_parameters, new_values = zdx.base._format(parameters, values)
    leaves_fn = lambda pytree: zdx.base._get_leaves(pytree, new_parameters)
    return eqx.tree_at(leaves_fn, pytree, new_values, is_leaf=lambda leaf: leaf is None)


# Simple class
class Foo(pytc.TreeClass):
    x: None
    y: None

    def __init__(self, x, y):
        self.x = np.asarray(x, float)
        self.y = np.asarray(y, float)

    def __call__(self):
        return self.x * self.y


foo = Foo(1, 2)
foo
>>> Foo(x=f32[](μ=1.00, σ=0.00, ∈[1.00,1.00]), y=f32[](μ=2.00, σ=0.00, ∈[2.00,2.00]))
```

And now some code to automate the masking/unmasking:

```python
def filter(params):
    def decorator(func):
        def wrapper(pytree, *args, **kwargs):
            # Generate mask
            false_mask = pytree.at[...].set(True)
            bool_mask = set(false_mask, params, False)
            masked_tree = pytc.tree_mask(pytree, bool_mask)

            # Generate unmasked func
            def unmask(masked):
                unmasked = pytc.tree_unmask(masked, bool_mask)
                return func(unmasked, *args, **kwargs)

            # Mask and call wrapped function
            return unmask(masked_tree)
        return wrapper
    return decorator


params = 'x'

@jax.jit
@jax.value_and_grad
@filter(params)
def loss_fn(pytree):
    return np.abs(pytree())**2


loss, grads = loss_fn(foo)
print(loss)
print(grads)
>>> 4.0
>>> Foo(x=8.0, y=4.0)
```

This is still taking gradients with respect to all parameters. I assume this is simply to do with my construction of the wrapper function? Although it seems like this construction might not work either way, since the parameter passed into the jax transformations will always be the unmasked pytree, with the masking and unmasking happening downstream. I would prefer not to need to write specific wrappers for individual transformations like …

What are your thoughts on a generalized filter function that can wrap any pytree input based on a mask generated by a list of parameters? Another question I have is when the callbacks of …
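One reason a wrapper like the above still produces gradients for every leaf is that `jax.grad` differentiates with respect to the tree actually passed in; re-masking inside the traced function does not sever the gradient path. A hedged sketch of an alternative that does stop those gradients, using `jax.lax.stop_gradient` on the leaves to be frozen (`filter_by_mask` is a hypothetical name invented here, not a zodiax or pytreeclass API):

```python
import jax
import jax.numpy as jnp

def filter_by_mask(mask):
    # mask: pytree of bools matching the input tree;
    # True = differentiate, False = freeze via stop_gradient
    def decorator(func):
        def wrapper(tree, *args, **kwargs):
            tree = jax.tree_util.tree_map(
                lambda leaf, m: leaf if m else jax.lax.stop_gradient(leaf),
                tree, mask,
            )
            return func(tree, *args, **kwargs)
        return wrapper
    return decorator

params = {"x": jnp.array(1.0), "y": jnp.array(2.0)}
mask = {"x": True, "y": False}  # only differentiate w.r.t. "x"

@jax.grad
@filter_by_mask(mask)
def loss_fn(tree):
    return tree["x"] * tree["y"]

grads = loss_fn(params)
print(grads)  # {'x': 2.0, 'y': 0.0} (as arrays)
```

Note the frozen leaves still appear in the gradient pytree, just with zero values, so this does not by itself hide non-jax types from tracing the way masking does.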
I am glad you like it :)

I think my point will be clear if you write an MWE using zodiax with …
Yeah, so my 'ideal' functionality would be something like this:

```python
import jax
import jax.numpy as np
import pytreeclass as pytc
import zodiax as zdx
import optax


# Simple class
class Foo(pytc.TreeClass):
    x: float
    y: float

    def __init__(self, x, y):
        self.x = np.asarray(x, float)
        self.y = np.asarray(y, float)

    def model(self):
        return self.x * self.y


# Construct object
tree = Foo(1, 2)

# Define parameters to optimise
params = 'x'

# Use zodiax to automate mapping of optimisers to leaves
optim, opt_state = zdx.get_optimiser(tree, params, optax.adam(1))


# Define loss and apply a 'filter' to handle the masking/unmasking
@jax.jit
@jax.value_and_grad
@zdx.filter(params)
def loss_fn(pytree):
    return np.abs(pytree.model())**2


loss, grads = loss_fn(tree)

# apply the update
updates, opt_state = optim.update(grads, opt_state)
updated_tree = zdx.apply_updates(tree, updates)
```

I can automate the mapping of optax optimisers to leaves, so don't worry about that; the main thing is the …

From my testing it seems like that may not be possible, since the masking and unmasking need to wrap around the calls of …:

```python
@zdx.filter_jit(params)
@zdx.filter_value_and_grad(params)
def loss_fn(pytree):
    return np.abs(pytree.model())**2
```

But the former example is definitely more ideal, since the masking/unmasking would only be done once. Do you think something like that is possible?

Yeah, I definitely like the idea of being able to add and remove leaves; however, I think there wouldn't be too many use-cases for Zodiax classes. It might even be worth having two base classes for Zodiax, a 'frozen' and an 'unfrozen' one, so users can choose which they prefer. If I were to implement this 'frozen' version, can you speak to how it would integrate with the rest of PyTreeClass? If everything still works, I would consider making something like this.
For this use case, I think you can do it in a simpler way.

Method 1: using `optax.masked`:

```python
import jax
import pytreeclass as pytc
import jax.numpy as jnp
import optax


@pytc.autoinit
class Tree(pytc.TreeClass):
    a: float = jnp.array(1.0)
    b: float = jnp.array(2.0)
    name: str = "tree"  # non-jax type for demonstration

    def model(self):
        return self.a + self.b


tree = pytc.tree_mask(Tree())

optim = optax.chain(
    # update only "a"
    optax.masked(optax.sgd(2.0), mask=tree.at[...].set(False).at["a"].set(True)),
    # no updates to the rest of the tree
    optax.masked(optax.sgd(0.0), mask=tree.at[...].set(True).at["a"].set(False)),
)
optim_state = optim.init(tree)


def loss_func(tree):
    ypred = tree.model()
    return jnp.mean((ypred - y) ** 2)  # y is the target, defined elsewhere


@jax.jit
def train_step(tree, optim_state):
    grads = jax.grad(loss_func)(tree)
    updates, optim_state = optim.update(grads, optim_state)
    tree = optax.apply_updates(tree, updates)
    return tree, optim_state


def train(tree, optim_state, num_steps):
    for _ in range(num_steps):
        tree, optim_state = train_step(tree, optim_state)
    return tree, optim_state


tree, optim_state = train(tree, optim_state, num_steps=10)
# tree = pytc.tree_unmask(tree)
print(tree)
# Tree(a=148686.31, b=2.0, name=tree)

%timeit train(tree, optim_state, num_steps=1000)
# 9.2 ms ± 134 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

Some notes: …
Method 2:

```python
import jax
import pytreeclass as pytc
import jax.numpy as jnp
import optax


@pytc.autoinit
class Tree(pytc.TreeClass):
    a: float = jnp.array(1.0)
    b: float = jnp.array(2.0)
    name: str = "tree"  # non-jax type for demonstration

    def model(self):
        self = pytc.tree_unmask(self)
        return self.a + self.b


optim = optax.sgd(2.0)
tree = pytc.tree_mask(Tree(), mask=Tree(a=False, b=True, name=True))
optim_state = optim.init(tree)


def loss_func(tree):
    ypred = tree.model()
    return jnp.mean((ypred - y) ** 2)  # y is the target, defined elsewhere


@jax.jit
def train_step(tree, optim_state):
    grads = jax.grad(loss_func)(tree)
    updates, optim_state = optim.update(grads, optim_state)
    tree = optax.apply_updates(tree, updates)
    return tree, optim_state


def train(tree, optim_state, num_steps):
    for _ in range(num_steps):
        tree, optim_state = train_step(tree, optim_state)
    return tree, optim_state


tree, optim_state = train(tree, optim_state, num_steps=10)
print(tree)
# Tree(a=148686.31, b=#2.0, name=#tree)

%timeit train(tree, optim_state, num_steps=1000)
# 5.91 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```

Note: …
So, I wanted to point out that with these two examples the following: …

From the …

Masking can be used to hide/freeze a leaf. This can be useful to either 1) pass a non-jax type or 2) not update a leaf value. optax can achieve 2) but not 1); however, I would recommend using …
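The mechanism behind hiding a non-jax leaf can be sketched with core JAX alone: wrap the value in a node whose flatten rule reports no dynamic children, so jit sees it as static aux data. `Frozen` here is a toy stand-in for what `pytc.freeze`/`tree_mask` do, not the actual implementation:

```python
import jax
import jax.numpy as jnp

class Frozen:
    # wraps a value so flattening treats it as static aux data;
    # aux data must be hashable and comparable for jit caching
    def __init__(self, value):
        self.value = value
    def __hash__(self):
        return hash(self.value)
    def __eq__(self, other):
        return isinstance(other, Frozen) and self.value == other.value

jax.tree_util.register_pytree_node(
    Frozen,
    lambda w: ((), w.value),     # no dynamic children; value is aux data
    lambda aux, _: Frozen(aux),  # rebuild from aux data
)

tree = {"w": jnp.array(1.0), "name": Frozen("tree")}
print(jax.tree_util.tree_leaves(tree))  # only the array; "name" is hidden

@jax.jit
def double(t):
    # jit accepts the string-carrying tree because "name" is
    # part of the (static) tree structure, not a traced leaf
    return {"w": t["w"] * 2, "name": t["name"]}

print(double(tree)["w"])  # 2.0
```

This also explains why frozen leaves are invisible to `optax` and `jax.grad`: they are simply not leaves any more.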
But you can change these attributes immutably in …:

```python
import pytreeclass as pytc


@pytc.autoinit
class Tree(pytc.TreeClass):
    a: float = 1.0

    def add_one(self):
        self.a += 1.0


tree = Tree()
tree.add_one()  # AttributeError: Cannot set attribute value=2.0 to `key='a'` ...

_, new_tree = tree.at["add_one"]()
print(new_tree)
# Tree(a=2.0)
```

Note that …
Check the new release; the new features are inspired by our discussion here.
Wow, these are some awesome new features; you're a legend! I've done some testing of using PyTreeClass on the back-end, but I've been busy with a re-factor of a downstream library for some time. I probably won't have time to put towards this in the near future, but there are definitely features I will integrate soon!
I came across this library when I referenced your `tree_at` issue with Equinox as one of the motives for PyTreeClass's existence. Since `PyTreeClass` now supports equinox or any pytree, and adds functional/composable lens-like setters/getters, why don't you give it another shot and let me know if it solves your problem?

Looking at this library's implementation, it looks very similar to the first version of `PyTreeClass`, which had some limitations; however, as `jax` now supports the path API, I think it is better to migrate to their API. For example: