
Path API #26

Open
ASEM000 opened this issue May 24, 2023 · 14 comments
Labels: enhancement (New feature or request), question (Further information is requested)

Comments

@ASEM000

ASEM000 commented May 24, 2023

I came across this library when I referenced your tree_at issue with Equinox as one of the motives for PyTreeClass's existence. Since PyTreeClass now supports Equinox (or any pytree) and adds functional, composable, lens-like setters/getters, why don't you give it another shot and let me know if it solves your problem?

Looking at this library's implementation, it looks very similar to the first version of PyTreeClass, which had some limitations; however, as JAX now supports a path API, I think it's better to migrate to that API.

For example:

import equinox as eqx
import pytreeclass as pytc
import jax


class Tree(eqx.Module):
    weight: jax.Array = jax.numpy.array([-1, 2, 3])
    bias: jax.Array = jax.numpy.array([1])
    counter: int = 1

    @property
    def at(self):
        return pytc.AtIndexer(self, ())


tree = Tree()

tree = (
    tree.at["counter"]
    .set(1)  # set counter to 1
    .at[jax.tree_util.tree_map(lambda x: x < 0, tree)]
    .set(0)  # set negative values to 0
    .at["bias"].set(100)  # set bias to 100
)

print(tree.weight)
# [0 2 3]
print(tree.bias)
# 100
print(tree.counter)
# 1
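For readers unfamiliar with the JAX path API mentioned above, here is a minimal sketch using only jax.tree_util. As an assumption, a plain dict stands in for the Equinox module so the snippet runs without equinox or pytreeclass. Every leaf comes paired with its key path, and a boolean mask can be built from those paths rather than from the leaf values:

```python
import jax
import jax.numpy as jnp

# A plain dict stands in for the Tree module above (assumption for a self-contained example)
tree = {"weight": jnp.array([-1, 2, 3]), "bias": jnp.array([1]), "counter": 1}

# Each leaf is paired with its key path, e.g. (DictKey(key='weight'),)
leaves_with_paths, treedef = jax.tree_util.tree_flatten_with_path(tree)
for path, leaf in leaves_with_paths:
    print(jax.tree_util.keystr(path), leaf)

# A path-based boolean mask: True only for the "weight" leaf
weight_mask = jax.tree_util.tree_map_with_path(lambda path, _: path[0].key == "weight", tree)

# Apply an update only where the mask is True (here: zero out negative weights)
new_tree = jax.tree_util.tree_map(
    lambda x, m: jnp.where(x < 0, 0, x) if m else x, tree, weight_mask
)
```

Note that dict keys are flattened in sorted order, so the paths print as ['bias'], ['counter'], ['weight'].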

@LouisDesdoigts
Owner

LouisDesdoigts commented Aug 21, 2023

Hey, I really appreciate you reaching out! Yeah, I remember chatting over email a while ago about this stuff when both PyTreeClass and Zodiax were in early development. Apologies for taking so long to get to this; I have had too many things on my plate lately and wanted to engage with this properly.

This is really cool and a bit 🪄 magic 🪄 to me right now, it seems like you've solved the eqx.tree_at syntax problem like we were describing in the original Equinox issue! I've also toyed with the idea of using the Jax path API but haven't got anywhere. I really like how you can index via a tree_map, I can see some good use cases for that!

I'm curious as to how this would extend to nested classes, where pytree leaves are other pytrees, dictionaries, etc. My guess, looking at this example, is that you would need to nest the .at[] calls? The nesting functionality is crucial to Zodiax, since the focus is scientific class structures that are typically hierarchical and modular. A common need is to set parameters deep in the pytree, so we use a syntax like this:

model = model.set('telescope.optics.aberrations.coefficients', new_coefficients)

or to operate on multiple specific leaves at the same time:

model = model.multiply(['leaf1', 'leaf2'], 2)

I'd love your thoughts on operations like this within the PyTreeClass framework!

@LouisDesdoigts LouisDesdoigts self-assigned this Aug 21, 2023
@LouisDesdoigts LouisDesdoigts added enhancement New feature or request question Further information is requested labels Aug 21, 2023
@LouisDesdoigts
Owner

I also just had a look at your benchmarks against Equinox, I'm very curious as to why PyTreeClass outperforms Equinox and others! I would have thought the underlying computations would be virtually the same once compiled to XLA.

@ASEM000
Author

ASEM000 commented Aug 22, 2023

> I'm curious as to how this would extend to nested classes, where pytree leaves are other pytrees, dictionaries, etc. My guess, looking at this example, is that you would need to nest the .at[] calls?

You can pass where to AtIndexer directly, something like this:

tree = {"a": {"b": 1, "c": 2}, "d": 3}
where = ("a", "b")
print(pytc.AtIndexer(tree, where).get())
# {'a': {'b': 1, 'c': None}, 'd': None}

Here each item of the where tuple is the key to match at one level, so ("a", "b") means match tree["a"]["b"].

> or to operate on multiple specific leaves at the same time:

You can do this by using a tuple of tuples:

tree = {"a": {"b": 1}, "c": 2, "d": 3}
where = (("a", "c"),)
print(pytc.AtIndexer(tree, where).get())
# {'a': {'b': 1}, 'c': 2, 'd': None}

Here the first item in the where tuple corresponds to the first level, and the two-item tuple means "match any of these keys".

For set you can use .set(..., is_leaf=...), and for multiply you can use .apply(lambda x: x * number).

For more, see here and here for how to extend it with custom matching strategies.

Note that, because I use tree_map_with_path, the above works with any registered nested pytree.
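To make the where semantics described above concrete, here is a minimal sketch in plain JAX of how such a mask could be built with jax.tree_util.tree_map_with_path, plus set/apply helpers that mirror the Zodiax-style set and multiply calls from earlier in the thread. This is not PyTreeClass's implementation: path_mask, set_where and apply_where are hypothetical names, and only dict, attribute and sequence keys are handled.

```python
import jax
from jax.tree_util import DictKey, GetAttrKey, SequenceKey, tree_map_with_path


def entry_key(entry):
    """Turn a JAX key-path entry into a plain Python key (str or int)."""
    if isinstance(entry, DictKey):
        return entry.key
    if isinstance(entry, GetAttrKey):
        return entry.name
    if isinstance(entry, SequenceKey):
        return entry.idx
    return entry


def path_mask(tree, where):
    """True for leaves whose path starts with `where`.

    Each item of `where` matches one level; a tuple item means "any of these keys".
    """
    def match(path, _leaf):
        if len(path) < len(where):
            return False
        for entry, level in zip(path, where):
            options = level if isinstance(level, tuple) else (level,)
            if entry_key(entry) not in options:
                return False
        return True

    return tree_map_with_path(match, tree)


def set_where(tree, where, value):
    # Replace every selected leaf with `value` (functional: returns a new tree)
    mask = path_mask(tree, where)
    return jax.tree_util.tree_map(lambda x, m: value if m else x, tree, mask)


def apply_where(tree, where, func):
    # Apply `func` to every selected leaf, e.g. lambda x: x * 2 for "multiply"
    mask = path_mask(tree, where)
    return jax.tree_util.tree_map(lambda x, m: func(x) if m else x, tree, mask)
```

With these, path_mask({"a": {"b": 1, "c": 2}, "d": 3}, ("a", "b")) selects only tree["a"]["b"], and apply_where(tree, (("leaf1", "leaf2"),), lambda x: x * 2) plays the role of model.multiply(['leaf1', 'leaf2'], 2).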

@ASEM000
Author

ASEM000 commented Aug 22, 2023

> I also just had a look at your benchmarks against Equinox, I'm very curious as to why PyTreeClass outperforms Equinox and others! I would have thought the underlying computations would be virtually the same once compiled to XLA.

PyTreeClass is faster because it adopts a different pattern that achieves the same functionality with less Python-level looping. In Equinox, that looping happens for each filtered transformation (inside tree_map) and inside the flattening step of eqx.Module.

For example, consider a pytree of m leaves and a training loop of n epochs, and assume we use filter_{jit,grad,vmap}. This results in O(2nm) work for each filtered transformation (O(nm) for each tree_map in eqx.partition and eqx.combine). Note that if the pytree is an eqx.Module, there is additional overhead from the logic and Python looping in the module's flatten step.

For the PyTreeClass pattern, traversal is only needed to undo the freezing/masking of tree leaves, so this leads to ~O(nm). Note that because I filter on pytrees rather than on functions, I don't need to wrap JAX transformations. And by using tree_mask/tree_unmask, I don't need to change the function signature or partition the tree into dynamic/static parts.
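Written out as a rough count of Python-level leaf visits per training run, using n and m as above plus k, the number of filtered transformations applied per step (k is introduced here only for bookkeeping; it is an assumption of this sketch, not a measured quantity):

```latex
C_{\text{eqx}} \approx \sum_{j=1}^{k} n\big(\underbrace{m}_{\texttt{partition}} + \underbrace{m}_{\texttt{combine}}\big) = 2knm,
\qquad
C_{\text{pytc}} \approx n \cdot \underbrace{m}_{\texttt{tree\_unmask}} = nm
```

Both counts are linear in nm, so by this argument the difference is a constant factor of Python-side work done around the compiled XLA computation (before adding the eqx.Module flatten overhead), not a difference in the compiled computation itself.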

The following is the complete benchmark code I used, and the results on a Colab CPU.

Feel free to let me know if you have any questions.

import jax
import jax.numpy as jnp
import pytreeclass as pytc
import equinox as eqx
import optax 


class PyTCLinear(pytc.TreeClass):
    def __init__(self, in_dim: int, out_dim: int, key: jax.Array, name: str):
        self.name = name
        self.weight = jax.random.normal(key, (in_dim, out_dim))
        self.bias = jax.numpy.array(0.0)

    def __call__(self, x: jax.Array):
        return x @ self.weight + self.bias


class EqxLinear(eqx.Module):
    name: str
    weight: jax.Array
    bias: jax.Array

    def __init__(self, in_dim: int, out_dim: int, key: jax.Array, name: str):
        self.name = name
        self.weight = jax.random.normal(key, (in_dim, out_dim))
        self.bias = jax.numpy.array(0.0)

    def __call__(self, x: jax.Array):
        return x @ self.weight + self.bias



def sequential_linears(layers, x):
    *layers, last = layers
    for layer in layers:
        x = layer(x)
        x = jax.nn.relu(x)
    return last(x)


x = jnp.linspace(100, 1)[:, None]
y = x**2
key = jax.random.PRNGKey(0)
optim = optax.adam(1e-3)


@jax.value_and_grad
def pytc_loss_func(layers, x, y):
    layers = pytc.tree_unmask(layers)
    ypred = sequential_linears(layers, x)  # do not shadow the target y
    return jnp.mean((y - ypred) ** 2)

@jax.jit
def pytc_train_step(layers, optim_state, x,y):
    loss, grads = pytc_loss_func(layers,x,y)
    updates, optim_state= optim.update(grads, optim_state)
    layers = optax.apply_updates(layers, updates)
    return layers, optim_state, loss

@eqx.filter_value_and_grad
def eqx_loss_func(layers, x, y):
    ypred = sequential_linears(layers, x)  # do not shadow the target y
    return jnp.mean((y - ypred) ** 2)

@eqx.filter_jit
def eqx_train_step(layers, optim_state, x,y):
    loss, grads = eqx_loss_func(layers,x,y)
    updates, optim_state= optim.update(grads, optim_state)
    layers = eqx.apply_updates(layers, updates)
    return layers, optim_state, loss


def pytc_train(layers, optim_state, x,y, epochs=100):
    for _ in range(epochs):
        layers, optim_state, loss = pytc_train_step(layers,optim_state, x,y)
    return layers, loss

def eqx_train(layers,optim_state, x,y, epochs=100):
    for _ in range(epochs):
        layers, optim_state, loss = eqx_train_step(layers,optim_state, x,y)
    return layers, loss


for linear_count in [10,100]:
    pytc_linears = [PyTCLinear(1,1, key=jax.random.PRNGKey(i), name=f"linear_{i}") for i in range(linear_count)]
    # mask non-differentiable parameters
    pytc_linears = pytc.tree_mask(pytc_linears)
    pytc_optim_state = optim.init(pytc_linears)


    eqx_linears = [EqxLinear(1,1, key=jax.random.PRNGKey(i), name=f"linear_{i}") for i in range(linear_count)]
    eqx_optim_state = optim.init(eqx.filter(eqx_linears, eqx.is_array))


    pytc_linears , pytc_loss = pytc_train(pytc_linears, pytc_optim_state, x,y, epochs=1000)
    eqx_linears, eqx_loss = eqx_train(eqx_linears, eqx_optim_state, x,y, epochs=1000)

    assert jnp.allclose(pytc_loss, eqx_loss)

    time_pytc = %timeit -o pytc_train(pytc_linears, pytc_optim_state, x,y, epochs=100)
    time_eqx = %timeit -o eqx_train(eqx_linears, eqx_optim_state, x,y, epochs=100)
    print(f"Eqx/PyTc: {time_eqx.average/time_pytc.average} for {linear_count} layers")
WARNING:jax._src.xla_bridge:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
35.3 ms ± 9.48 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
186 ms ± 52.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Eqx/PyTc: 5.290240573795192 for 10 layers
698 ms ± 127 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
1.74 s ± 231 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
Eqx/PyTc: 2.495322682506591 for 100 layers

@LouisDesdoigts
Owner

Thanks for the quick response! Yeah, it looks like pytc.AtIndexer has a very similar syntax to Zodiax 'paths'. Does your AtIndexer call __getattr__ under the hood? Due to the highly nested nature of the classes I work with, I commonly 'raise' attributes stored in dictionaries up through my classes, so that I can access attributes like optics.layers.aberrations as optics.aberrations. I assume this would similarly work through the AtIndexer?

Interesting, thanks for this example! So if I've understood correctly, you've made attribute access require fewer tree operations, so the speed gains come not from the underlying computation but from more efficient operations on the pytree itself? Would this result in a general speed increase for updating pytrees? This has been at the back of my mind for a while, since in my work it's often necessary to update single values in large pytrees, resulting in slowdowns from loading the pytree into a new instance. I've always thought my looping/nesting implementation of accessing leaves could be improved.

Another thing we do that breaks from the more general ML use cases is take gradients with respect to either a single parameter or a small subset of parameters (since we are modelling physical objects, we know that things like the diameter of our telescope aren't going to change 😝). This used to be a default(ish) behaviour of Equinox's filter_grad etc., but it was removed and I had to implement it with eqx.partition and eqx.combine, see here. From what you're saying, it sounds like this might be improvable? I assume that to achieve the same thing in PyTreeClass you would use a similar pattern with tree_mask and tree_unmask; can you speak to the efficiency of PyTreeClass vs Equinox here?

I also see that you've implemented an equivalent of eqx.field(static=True) by using tree_mask and tree_unmask inside the functions themselves. I'm wondering if there would be a way to get a similar syntax (specified in class definitions) without needing to modify the functions operating on the pytrees? Furthermore, is it possible, with a syntax like this, to update these 'static' (typically non-JAX-type) parameters during calculations? I've been toying with the idea of adding a 'unit' string parameter to my classes so that users can work in units that keep them away from floating-point precision errors (optical sciences work with wavelength ranges from 1e-9 to 1e9), so being able to control units with a static non-JAX type, without requiring users to apply extra transformations in their functions, has been a long-term goal.

One of the primary goals of Zodiax is to make working with the objects require as little knowledge as possible. As my target is a science audience, most of the users have little to no programming background and are mostly self-taught (no wonder most scientific code has quirks, mine included 🙃). That was a large motivator for Zodiax: to spare users from needing to learn about lambda functions, tree mapping, filtering, etc. This is why all operations in Zodiax have a pure 'path-based' interface. Granted, there is likely a loss in computational efficiency, but the gain in being able to get users up to speed with as little knowledge as possible is crucial. I'm definitely interested in some of the gains of using PyTreeClass under the hood, as it seems like you've addressed quite a few things I've been thinking about!

Thanks again for these responses, and sorry for hammering you with so many questions, it seems like we're thinking about the same stuff!

@ASEM000
Author

ASEM000 commented Aug 22, 2023

> Yeah, it looks like pytc.AtIndexer has a very similar syntax to Zodiax 'paths'. Does your AtIndexer call __getattr__ under the hood?

No, AtIndexer works with the path entries produced by jax.tree_util.tree_flatten_with_path; this enables integer, string, and user-defined matching strategies in addition to boolean pytrees. The file that defines them is here. Each where entry is converted, by dispatch, into a class with __eq__ defined for matching in jax.tree_util.tree_map_with_path. This yields a mask that is later used to get/set/apply on leaves.
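As an illustration of matching by __eq__ dispatch, here is a minimal sketch of a user-defined matcher. RegexKey and mask_from are hypothetical names, not PyTreeClass classes: the matcher's __eq__ compares a regular expression against a JAX path entry, and the mask is built with jax.tree_util.tree_map_with_path by comparing each leaf's first path entry against the matcher.

```python
import re

from jax.tree_util import DictKey, GetAttrKey, SequenceKey, tree_map_with_path


class RegexKey:
    """Hypothetical matcher: equal to any path entry whose name fully matches a regex."""

    def __init__(self, pattern: str):
        self.pattern = re.compile(pattern)

    def __eq__(self, entry):
        # Pull the plain key out of the JAX path entry
        if isinstance(entry, DictKey):
            name = entry.key
        elif isinstance(entry, GetAttrKey):
            name = entry.name
        elif isinstance(entry, SequenceKey):
            name = entry.idx
        else:
            return NotImplemented
        return isinstance(name, str) and self.pattern.fullmatch(name) is not None


def mask_from(tree, matcher):
    """Boolean mask: True where the first path entry of a leaf == matcher."""
    return tree_map_with_path(lambda path, _: matcher == path[0], tree)


params = {"weight_1": 1.0, "weight_2": 2.0, "bias": 0.0}
mask = mask_from(params, RegexKey(r"weight_\d+"))
```

The resulting mask can then drive a get/set/apply step in the same way as a boolean pytree.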

> Interesting, thanks for this example! So if I've understood correctly, you've made attribute access require fewer tree operations, so the speed gains come not from the underlying computation but from more efficient operations on the pytree itself?

If you mean the speed-up relative to Equinox, then this is purely due to the design of TreeClass and traversing only when needed. I think it's OK to trade speed for convenience in inexpensive parts of the program (e.g. inside a jitted train step).

> Another thing we do that breaks from the more general ML use cases is take gradients with respect to either a single parameter or a small subset of parameters (since we are modelling physical objects, we know that things like the diameter of our telescope aren't going to change 😝).

see here

> I also see that you've implemented an equivalent of eqx.field(static=True) by using tree_mask and tree_unmask inside the functions themselves.

There is a meaningful difference between the two for your use case:

  • First, eqx.field(static=True) is defined as part of the class definition, whereas tree_mask wraps a leaf in a wrapper that yields no leaves when flattened. It is not possible to make a field static after class creation, but you can mask a leaf with a frozen wrapper at any time (see https://pytreeclass.readthedocs.io/en/latest/notebooks/dealing_with_non_jaxtype.html#). Similarly, you cannot undo a static field after class creation, but you can undo masking a leaf with tree_unmask.
    A static field makes sense if you don't assume a masking/filtering mechanism, as is the case with flax's PyTreeNode class.

  • Second, eqx.field(static=True) is meant for hashable values. Using static=True on a jax array is considered a bug, but masking an array with tree_mask wraps it in a frozen wrapper that can handle arrays.

  • Third, from a design point of view, a static field contradicts the mental model of both Equinox and PyTreeClass: it moves the burden of handling non-JAX types (e.g. strings) to the user, instead of leaving it to tree_mask or a filtered transformation.

  • Fourth, performance: for a tree of m >> 1 leaves with a single static field, the eqx.Module flatten step involves a loop over the m leaves with logic to eliminate the static field. This does not happen for a PyTreeClass TreeClass.
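The structural difference in the first point can be sketched in plain JAX. Frozen, freeze_where, unfreeze and to_cm below are hypothetical names (this is not PyTreeClass's actual wrapper): the wrapped value lives in the pytree's auxiliary data, so the wrapper contributes no leaves, it can be applied or removed at any time after construction, and under jax.jit the wrapped string arrives as a plain Python value rather than a tracer. Because aux data is compared with == and hashed for jit caching, this simple version assumes hashable values such as strings; arrays would need a custom hash rule.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Frozen:
    """Hypothetical frozen-leaf wrapper: the value is aux data, so it flattens to no leaves."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), self.value  # no children -> no leaves

    @classmethod
    def tree_unflatten(cls, value, children):
        return cls(value)


def is_frozen(x):
    return isinstance(x, Frozen)


def freeze_where(tree, mask):
    # mask has the same structure as tree; True means "wrap this leaf"
    return jax.tree_util.tree_map(lambda x, m: Frozen(x) if m else x, tree, mask)


def unfreeze(tree):
    # is_leaf stops at Frozen nodes so they can be unwrapped
    return jax.tree_util.tree_map(lambda x: x.value if is_frozen(x) else x, tree, is_leaf=is_frozen)


@jax.jit
def to_cm(tree):
    # tree["unit"].value is a plain Python str here, so ordinary Python branching works
    factor = 100.0 if tree["unit"].value == "m" else 1.0
    return tree["leaf"] * factor


tree = {"leaf": jnp.array([1.0]), "unit": "m"}
masked = freeze_where(tree, {"leaf": False, "unit": True})
print(jax.tree_util.tree_leaves(masked))  # only the array leaf remains
print(to_cm(masked))
```

Passing the unmasked tree to to_cm would fail, because "m" is not a valid JAX type; the wrapper is what moves it into the static part of the tree.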

> I'm wondering if there would be a way to get a similar syntax (specified in class definitions) without needing to modify the functions operating on the pytrees? Furthermore, is it possible, with a syntax like this, to update these 'static' (typically non-JAX-type) parameters during calculations? I've been toying with the idea of adding a 'unit' string parameter to my classes so that users can work in units that keep them away from floating-point precision errors (optical sciences work with wavelength ranges from 1e-9 to 1e9), so being able to control units with a static non-JAX type, without requiring users to apply extra transformations in their functions, has been a long-term goal.

I think this is an XY problem. Can you give me an MWE of your desired functionality?

@LouisDesdoigts
Owner

Ah yes, I suspected that using the JAX KeyPaths might end up breaking the parameter raising, which is a bit of an issue for our use case. I wonder if it's possible to build the correct path under the hood from the accessors directly. This would incur some overhead, though, so it's not an ideal solution. What are your thoughts on this?
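One way to build the path from the accessors, sketched under assumptions: follow the user-facing dotted accessor with getattr (so raised attributes resolve as usual), then locate the resolved object in jax.tree_util.tree_flatten_with_path by identity to recover its true key path. Optics, resolve_path and set_by_accessor are hypothetical names, and the identity lookup assumes the target object appears only once in the tree. The extra flatten per lookup is the overhead mentioned above.

```python
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_flatten_with_path, tree_map_with_path


@register_pytree_node_class
class Optics:
    """Toy class that 'raises' entries of its layers dict to attributes."""

    def __init__(self, layers: dict):
        self.layers = layers

    def __getattr__(self, name):
        # Only called when normal lookup fails, e.g. optics.aberrations
        layers = self.__dict__.get("layers", {})
        if name in layers:
            return layers[name]
        raise AttributeError(name)

    def tree_flatten(self):
        return (self.layers,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)


def resolve_path(tree, accessor: str):
    """Return (key_path, target) for a dotted accessor such as "aberrations"."""
    target = tree
    for name in accessor.split("."):
        target = getattr(target, name)
    # Treat the resolved object as a leaf so that whole subtrees can be targeted too
    for path, leaf in tree_flatten_with_path(tree, is_leaf=lambda x: x is target)[0]:
        if leaf is target:
            return path, target
    raise KeyError(accessor)


def set_by_accessor(tree, accessor: str, value):
    # Functional update of the object reached through the (possibly raised) accessor
    path, target = resolve_path(tree, accessor)
    return tree_map_with_path(
        lambda p, x: value if p == path else x, tree, is_leaf=lambda x: x is target
    )


optics = Optics({"aberrations": jnp.zeros(3), "pupil": jnp.ones(2)})
new_optics = set_by_accessor(optics, "aberrations", jnp.arange(3.0))
```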

Yeah it seems like overall PyTreeClass operates similarly to Equinox in regards to freezing parameters. The behavior in both is somewhat 'inverse' to our desired functionality. Rather than freezing specified leaves, it makes more sense for us to only take gradients with respect to specific leaves. I was able to make a small wrapper that achieves this functionality using Equinox, but it seems like it would be more complex for nested classes in PyTreeClass.

> Using static=True on a jax array is considered a bug, but masking an array with tree_mask wraps it in a frozen wrapper that can handle arrays.

Does this mean that Jax will understand these array leaves as 'static' under jit transformations? They obviously can't be treated as actually static since they aren't hashable, but does this lead to a meaningful difference under the hood? We have always thought that speeds might be improved by preventing Jax from tracing these arrays that can be considered 'static'.

What are your thoughts on a filtered-transformation equivalent for jit and grad with a where mask that automatically masks out every leaf not in the where function? I've built an equivalent using Equinox here. Do you suspect there would be any meaningful difference under the hood once compiled to XLA, as opposed to the partition -> combine pattern used in Equinox?

> I think this is an XY problem. Can you give me an MWE of your desired functionality?

Yes, it more or less is. This is not necessarily needed functionality, but more an idea I've been toying with as a way to define and update a units parameter of leaves or pytrees. Here's an MWE:

from jax import Array
import pytreeclass as pytc

class Foo(pytc.TreeClass):
    leaf : Array
    unit : str = pytc.static_leaf() # Something similar to this

    def __call__(self, new_unit: str):
        new_value = convert(self.leaf, self.unit, new_unit) # some simple unit conversion function
        return self.at[('leaf', 'unit')].set((new_value, new_unit))

The idea here is that since the leaves are already treated as 'static' you avoid the need to mask/unmask in functions.

The other option would be a filtered-transform type operation:

from jax import Array
import pytreeclass as pytc

class Foo(pytc.TreeClass):
    leaf : Array
    unit : str

    def __call__(self, new_unit: str):
        new_value = convert(self.leaf, self.unit, new_unit) # some simple unit conversion function
        return self.at[('leaf', 'unit')].set((new_value, new_unit))

param = 'leaf'
@filter_jit(param)
@filter_grad(param)
def f(pytree):
    return pytree('nm').leaf**2

The idea here is that the filtered transformation automatically masks all leaves except for the leaf parameter, by building the mask programmatically from the param input. This pattern is actually preferable to the first example, since it removes the need for users to think about masking in the first place and relaxes the need to be careful about input array dtypes.

Generally, the goal is to have 'mask-by-path' behavior as this is more in-line with the goals of scientific models.

@ASEM000
Author

ASEM000 commented Aug 25, 2023

> Yeah it seems like overall PyTreeClass operates similarly to Equinox in regards to freezing parameters. The behavior in both is somewhat 'inverse' to our desired functionality. Rather than freezing specified leaves, it makes more sense for us to only take gradients with respect to specific leaves.

It's easier with PyTreeClass:

import pytreeclass as pytc
import jax.numpy as jnp
import jax


class Tree(pytc.TreeClass):
    def __init__(self):
        self.weight = jnp.array([1.0])
        self.bias = jnp.array([1.0])
        self.name = "tree"  # non-jax type


tree = Tree()

# mask all leaves except weight
mask = tree.at[...].set(True).at["weight"].set(False)
masked_tree = pytc.tree_mask(tree, mask)


@jax.jit
@jax.grad
def loss_func(tree, x: jax.Array):
    tree = pytc.tree_unmask(tree)
    return jnp.mean(tree.weight * x**2)


grads = loss_func(masked_tree, jnp.array([2.0]))
# Tree(weight=[4.], bias=#[1.], name=#tree)

new_tree = jax.tree_util.tree_map(lambda x, g: x + g, masked_tree, grads)
new_tree = pytc.tree_unmask(new_tree)
# Tree(weight=[5.], bias=[1.], name=tree)

> Does this mean that Jax will understand these array leaves as 'static' under jit transformations?

Yes: freeze and tree_mask wrap arrays in a wrapper that defines a hash rule, something similar to this.

> What are your thoughts on a filtered-transformation equivalent for jit and grad with a where mask that automatically masks out every leaf not in the where function? I've built an equivalent using Equinox here.

It's harder with the Equinox style because you have to 1) filter pytrees to work with libraries like optax, and then 2) filter over functions. This is a common usability issue in Equinox.
From your code, this means that you have to supply the boolean filter explicitly to eqx.filter before passing it to optax, i.e. you repeat the process twice, once explicitly and once implicitly.

In contrast, using the above code sample, the masked_tree can be supplied to optax with no further operations.

> The idea here is that since the leaves are already treated as 'static' you avoid the need to mask/unmask in functions.

If you want to take on the burden of defining static nodes inside the class definition, then the following pattern works:

import jax
import pytreeclass as pytc

# changing a "frozen" leaf value without tree_mask/tree_unmask on user side

class Foo(pytc.TreeClass):
    def __init__(self, leaf: jax.Array, unit: str):
        self.leaf = leaf
        self.unit = pytc.freeze(unit)  # non-jax type

    def change_unit(self, new_unit: str) -> "Foo":
        return self.at["unit"].set(pytc.freeze(new_unit), is_leaf=pytc.is_frozen)


foo = Foo(jax.numpy.array([1.0]), "m")
# Foo(leaf=f32[1](μ=1.00, σ=0.00, ∈[1.00,1.00]), unit=#m)
new_foo = foo.change_unit("cm")
# Foo(leaf=f32[1](μ=1.00, σ=0.00, ∈[1.00,1.00]), unit=#cm)

Your MWE is a strong case for using PyTreeClass over Equinox, because in Equinox, if you define unit as a static field, it's not possible to change it afterwards through an immutable (functional) update.

Please let me know what you think.

@LouisDesdoigts
Owner

Okay so I've had some time to go through the PyTreeClass docs and... wow you have implemented some seriously awesome functionality! Automatic data type and shape checking, bcmap, default frozen leaves, these all could be very useful in my downstream software. I'm very seriously considering migrating Zodiax to a PyTreeClass backend, so I hope my questions aren't too annoying 😝.

I can envision a better version of Zodiax that circumvents the need to wrap specific Equinox functions, using a single default filter function wrapper that automatically masks and unmasks pytrees before they are passed to the original function. Ideally this could be used to make all pytree inputs 'safe' for JAX functions, with syntax like this:

params = ['param1', 'param2']

@jax.jit
@jax.grad
@zdx.filter(params)
def loss(pytree):
    return pytree()**2

Ideally this would result in gradients only with respect to the params list, with all other parameters treated as static under jit. I had a quick stab at implementing this but couldn't quite get it to work:

import jax
import jax.numpy as np
from functools import wraps 
import pytreeclass as pytc
import equinox as eqx
import zodiax as zdx

# Functional wrapper to build the mask using __getattr__, not super important right now
def set(pytree, parameters, values):
    if values is None:
        values = [None]
        if isinstance(parameters, str):
            parameters = [parameters]
    new_parameters, new_values = zdx.base._format(parameters, values)
    leaves_fn = lambda pytree: zdx.base._get_leaves(pytree, new_parameters)
    return eqx.tree_at(leaves_fn, pytree, new_values, is_leaf = lambda leaf: leaf is None)

# Simple class
class Foo(pytc.TreeClass):
    x : None
    y : None
    
    def __init__(self, x, y):
        self.x = np.asarray(x, float)
        self.y = np.asarray(y, float)
    
    def __call__(self):
        return self.x * self.y


foo = Foo(1, 2)
foo

>>> Foo(x=f32[](μ=1.00, σ=0.00, ∈[1.00,1.00]), y=f32[](μ=2.00, σ=0.00, ∈[2.00,2.00]))

And now some code to automate the masking/unmasking

def filter(params):
    def decorator(func):
        def wrapper(pytree, *args, **kwargs):

            # Generate Mask
            false_mask = pytree.at[...].set(True)
            bool_mask = set(false_mask, params, False)
            masked_tree = pytc.tree_mask(pytree, bool_mask)

            # Generate unmasked func
            def unmask(masked):
                unmasked = pytc.tree_unmask(masked, bool_mask)
                return func(unmasked, *args, **kwargs)

            # Mask and call wrapped function
            return unmask(masked_tree)

        return wrapper
    return decorator

params = 'x'

@jax.jit
@jax.value_and_grad
@filter(params)
def loss_fn(pytree):
    return np.abs(pytree())**2

loss, grads = loss_fn(foo)

print(loss)
print(grads)

>>> 4.0
>>> Foo(x=8.0, y=4.0)

This still takes gradients with respect to all parameters. I assume this is simply due to my construction of the wrapper function? Although it seems like this construction might not work either way, since the argument passed into the JAX transformations will always be the unmasked pytree, with the masking and unmasking happening downstream. I would prefer not to have to write specific wrappers for individual transformations like jit, grad, etc., although a single wrapper that applies to all JAX functions might be a good compromise.

What are your thoughts on a generalized filter function that can wrap any pytree input based on a mask generated by a list of parameters?
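The behaviour above follows from where the masking happens: jax.value_and_grad differentiates with respect to whatever pytree is passed to the function it wraps, and in the filter wrapper above that is still the full, unmasked tree. A minimal sketch of the alternative ordering, masking before the transformation and unmasking inside it, is below. It is plain JAX rather than PyTreeClass: Frozen, _Static and filtered are hypothetical names, the tree is assumed to be a flat dict selected by top-level key, and _Static is one possible hash rule (dtype, shape and bytes) that lets frozen arrays sit in the static part of the tree under jax.jit.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class, tree_map, tree_map_with_path


class _Static:
    """Hashable box so a (possibly array-valued) frozen value can live in aux data."""

    def __init__(self, value):
        self.value = value

    def _key(self):
        arr = np.asarray(self.value)
        return (arr.dtype.str, arr.shape, arr.tobytes())

    def __eq__(self, other):
        return isinstance(other, _Static) and self._key() == other._key()

    def __hash__(self):
        return hash(self._key())


@register_pytree_node_class
class Frozen:
    """Wrapper that flattens to no leaves, so transformations never see its value."""

    def __init__(self, value):
        self.value = value

    def tree_flatten(self):
        return (), _Static(self.value)

    @classmethod
    def tree_unflatten(cls, static, children):
        return cls(static.value)


def is_frozen(x):
    return isinstance(x, Frozen)


def filtered(transform, params):
    """Apply `transform` so that only the top-level keys in `params` stay dynamic."""
    params = (params,) if isinstance(params, str) else tuple(params)

    def decorator(func):
        def unmask_then_call(masked, *args, **kwargs):
            # Runs inside the transformation: restore the original leaves
            tree = tree_map(lambda x: x.value if is_frozen(x) else x, masked, is_leaf=is_frozen)
            return func(tree, *args, **kwargs)

        transformed = transform(unmask_then_call)

        def wrapper(tree, *args, **kwargs):
            # Runs outside the transformation: freeze every leaf not selected by params
            masked = tree_map_with_path(
                lambda path, x: x if path[0].key in params else Frozen(x), tree
            )
            return transformed(masked, *args, **kwargs)

        return wrapper

    return decorator


@filtered(lambda f: jax.jit(jax.value_and_grad(f)), "x")
def loss_fn(tree):
    return jnp.abs(tree["x"] * tree["y"]) ** 2


foo = {"x": jnp.asarray(1.0), "y": jnp.asarray(2.0)}
loss, grads = loss_fn(foo)
# grads["x"] holds the gradient; grads["y"] comes back as the Frozen wrapper
```

Because the transform is passed in as an argument, the same wrapper works for any composition of JAX transformations, and the masking is done once per call rather than once per transformation.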


Another question I have is when the callbacks of pytc.field are invoked. The provided examples only show them being called during the constructor, and my testing suggests this is the case. It would be super handy if these callbacks could be invoked during calls to at[where].set(value), since it's always been a bit of a 'loose end' that leaves can be set to anything of any dtype. It would be great to build classes with strict conditions on leaves; what are your thoughts on this?

@ASEM000
Author

ASEM000 commented Aug 29, 2023

> Okay so I've had some time to go through the PyTreeClass docs and... wow you have implemented some seriously awesome functionality! Automatic data type and shape checking, bcmap, default frozen leaves, these all could be very useful in my downstream software. I'm very seriously considering migrating Zodiax to a PyTreeClass backend, so I hope my questions aren't too annoying 😝.

I am glad you like it :)

> I can envision a better version of Zodiax that circumvents the need to wrap specific Equinox functions, using a single default filter function wrapper that automatically masks and unmasks pytrees before they are passed to the original function. Ideally this could be used to make all pytree inputs 'safe' for JAX functions, with syntax like this:

Could you write an MWE using Zodiax with optax that updates some leaves and ignores the rest? I think my point will become clear from that.

> Another question I have is when the callbacks of pytc.field are invoked. The provided examples only show them being called during the constructor, and my testing suggests this is the case. It would be super handy if these callbacks could be invoked during calls to at[where].set(value), since it's always been a bit of a 'loose end' that leaves can be set to anything of any dtype. It would be great to build classes with strict conditions on leaves; what are your thoughts on this?

field returns a simple descriptor, and the callbacks are invoked in the descriptor's __set__ method. This descriptor can be used with any Python class to apply callbacks on setattr. Since .at[...].set() uses tree_map, I would have to redefine the unflatten rule to instantiate the class (so that the callbacks run). However, this has a major downside: unlike frozen-dataclass-based pytrees such as flax.struct.PyTreeNode or equinox.Module, TreeClass allows leaves to be added/removed functionally after instantiation, and implementing this feature would mean I could no longer add/remove leaves.
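A minimal sketch of the descriptor mechanism described above, in plain Python. Field, positive and Lens are hypothetical names, not pytc.field itself: callbacks run inside __set__, so they fire on every ordinary attribute assignment, but a reconstruction path that writes the instance __dict__ directly (as an unflatten rule that bypasses __init__ might) never triggers them.

```python
class Field:
    """Hypothetical descriptor that runs callbacks on every attribute assignment."""

    def __init__(self, *callbacks):
        self.callbacks = callbacks

    def __set_name__(self, owner, name):
        self.name = name

    def __get__(self, instance, owner=None):
        if instance is None:
            return self
        return instance.__dict__[self.name]

    def __set__(self, instance, value):
        # Each callback may validate and/or convert the value
        for callback in self.callbacks:
            value = callback(value)
        instance.__dict__[self.name] = value


def positive(value):
    if value <= 0:
        raise ValueError(f"expected a positive value, got {value}")
    return value


class Lens:
    focal_length = Field(float, positive)

    def __init__(self, focal_length):
        self.focal_length = focal_length  # routed through Field.__set__


lens = Lens(2)  # focal_length is converted to 2.0

# An unflatten-style rebuild that skips __init__ and writes __dict__ directly
# never calls Field.__set__, so the callbacks do not run.
clone = object.__new__(Lens)
clone.__dict__.update(focal_length=-1.0)
```

This is why running the callbacks on .at[...].set() would require an unflatten rule that goes through setattr (or __init__), with the trade-off described above.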

@LouisDesdoigts
Owner

> Could you write an MWE using Zodiax with optax that updates some leaves and ignores the rest? I think my point will become clear from that.

Yeah so my 'ideal' functionality would be something like this

import jax
import jax.numpy as np
import pytreeclass as pytc
import zodiax as zdx
import optax

# Simple class
class Foo(pytc.TreeClass):
    x : float
    y : float
    
    def __init__(self, x, y):
        self.x = np.asarray(x, float)
        self.y = np.asarray(y, float)
    
    def model(self):
        return self.x * self.y

# Construct object
tree = Foo(1, 2)

# Define parameters to optimise
params = 'x'

# Use zodiax to automate mapping of optimisers to leaves
optim, opt_state = zdx.get_optimiser(tree, params, optax.adam(1))

# Define loss and apply a 'filter' to handle the masking/unmasking
@jax.jit
@jax.value_and_grad
@zdx.filter(params)
def loss_fn(pytree):
    return np.abs(pytree.model())**2

loss, grads = loss_fn(tree)

# apply the update
updates, opt_state = optim.update(grads, opt_state)
updated_tree = zdx.apply_updates(tree, updates)

I can automate the mapping of optax optimisers to leaves, so don't worry about that; the main thing is the zdx.filter(params). The idea is that it would transform the pytree input of the function it decorates, automatically masking and unmasking it so that all leaves not specified by params are 'masked' out of the pytree before it is passed to JAX transformations.

From my testing of this it seems like that may not be possible since the masking and unmasking need to wrap around the calls of jax.jit, jax.value_and_grad. I could accept a syntax similar to equinox (and zodiax) at present:

@zdx.filter_jit(params)
@zdx.filter_value_and_grad(params)
def loss_fn(pytree):
    return np.abs(pytree.model())**2

But the former example is definitely preferable, since the masking/unmasking would only be done once. Do you think something like that is possible?


> field returns a simple descriptor, and the callbacks are invoked in the descriptor's __set__ method. This descriptor can be used with any Python class to apply callbacks on setattr. Since .at[...].set() uses tree_map, I would have to redefine the unflatten rule to instantiate the class (so that the callbacks run). However, this has a major downside: unlike frozen-dataclass-based pytrees such as flax.struct.PyTreeNode or equinox.Module, TreeClass allows leaves to be added/removed functionally after instantiation, and implementing this feature would mean I could no longer add/remove leaves.

Yeah, I definitely like the idea of being able to add and remove leaves; however, I don't think there would be many use cases for it with Zodiax classes. It might even be worth having two base classes in Zodiax, a 'frozen' and an 'unfrozen' one, so that users can choose whichever they prefer.

If I were to implement this 'frozen' version, can you speak to how it would integrate with the rest of PyTreeClass? If everything still works, I would consider making something like this.

@ASEM000
Author

ASEM000 commented Aug 30, 2023

For this use case, AFAIK, you can do it in a simpler way (IMO):

Method 1: using optax.masked

import jax
import jax.numpy as jnp
import optax
import pytreeclass as pytc

@pytc.autoinit
class Tree(pytc.TreeClass):
    a: float = jnp.array(1.0)
    b: float = jnp.array(2.0)
    name: str = "tree"  # non-jax type for demonstration

    def model(self):
        return self.a + self.b

tree = pytc.tree_mask(Tree())

optim = optax.chain(
    # update only "a"
    optax.masked(optax.sgd(2.0), mask=tree.at[...].set(False).at["a"].set(True)),
    # no updates to the rest of the tree
    optax.masked(optax.sgd(0.0), mask=tree.at[...].set(True).at["a"].set(False)),
)

optim_state = optim.init(tree)

def loss_func(tree):
    ypred = tree.model()
    # `y` (the regression target) is not defined in this snippet; define it before running
    return jnp.mean((ypred - y) ** 2)

@jax.jit
def train_step(tree, optim_state):
    grads = jax.grad(loss_func)(tree)
    updates, optim_state = optim.update(grads, optim_state)
    tree = optax.apply_updates(tree, updates)
    return tree, optim_state

def train(tree, optim_state, num_steps):
    for _ in range(num_steps):
        tree, optim_state = train_step(tree, optim_state)
    return tree, optim_state

tree, optim_state = train(tree, optim_state, num_steps=10)
# tree = pytc.tree_unmask(tree)
print(tree)
# Tree(a=148686.31, b=2.0, name=tree)

%timeit train(tree, optim_state, num_steps=1000)
# 9.2 ms ± 134 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Some notes:

  • I did not use tree_unmask inside loss_func, because your function call does not use any non-jax-type leaves (e.g. an activation function).
  • Gradients were taken with respect to b, but no update was applied because b is covered by optax.masked with sgd(0.).
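Here is why Method 1 chains two optax.masked calls. optax.masked applies its inner transform only where the mask is True and passes the remaining updates through unchanged. With only the first masked(sgd(2.0)), b would therefore still receive its raw, unscaled gradient as an update. The pure-Python stand-ins below (`sgd`, `masked`, `chain`, operating on a dict of scalar gradients) only mimic that behaviour; they are a simplified sketch, not optax itself.

```python
from typing import Callable, Dict

Updates = Dict[str, float]
Transform = Callable[[Updates], Updates]


def sgd(learning_rate: float) -> Transform:
    """Stand-in for optax.sgd: scale each gradient by -learning_rate."""
    return lambda updates: {k: -learning_rate * g for k, g in updates.items()}


def masked(inner: Transform, mask: Dict[str, bool]) -> Transform:
    """Stand-in for optax.masked: transform masked-in leaves, pass the rest through."""
    def apply(updates: Updates) -> Updates:
        transformed = inner({k: g for k, g in updates.items() if mask[k]})
        return {k: transformed[k] if mask[k] else g for k, g in updates.items()}
    return apply


def chain(*transforms: Transform) -> Transform:
    """Stand-in for optax.chain: apply transforms left to right."""
    def apply(updates: Updates) -> Updates:
        for t in transforms:
            updates = t(updates)
        return updates
    return apply


grads = {"a": 1.5, "b": 1.5}
only_a = masked(sgd(2.0), mask={"a": True, "b": False})
print(only_a(grads))  # {'a': -3.0, 'b': 1.5}  <- b's raw gradient passes through

optim = chain(only_a, masked(sgd(0.0), mask={"a": False, "b": True}))
print(optim(grads))  # {'a': -3.0, 'b': -0.0}  <- b is effectively frozen
```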

Method 2: freezing leaves with pytc.tree_mask

import jax
import jax.numpy as jnp
import optax
import pytreeclass as pytc

@pytc.autoinit
class Tree(pytc.TreeClass):
    a: float = jnp.array(1.0)
    b: float = jnp.array(2.0)
    name: str = "tree"  # non-jax type for demonstration

    def model(self):
        self = pytc.tree_unmask(self)
        return self.a + self.b

optim = optax.sgd(2.0)
tree = pytc.tree_mask(Tree(), mask=Tree(a=False, b=True, name=True))
optim_state = optim.init(tree)


def loss_func(tree):
    ypred = tree.model()
    # `y` (the regression target) is not defined in this snippet; define it before running
    return jnp.mean((ypred - y) ** 2)

@jax.jit
def train_step(tree, optim_state):
    grads = jax.grad(loss_func)(tree)
    updates, optim_state = optim.update(grads, optim_state)
    tree = optax.apply_updates(tree, updates)
    return tree, optim_state

def train(tree, optim_state, num_steps):
    for _ in range(num_steps):
        tree, optim_state = train_step(tree, optim_state)
    return tree, optim_state


tree, optim_state = train(tree, optim_state, num_steps=10)

print(tree)
# Tree(a=148686.31, b=#2.0, name=#tree)

%timeit train(tree, optim_state, num_steps=1000)
# 5.91 ms ± 121 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Notes:

  • No gradient was taken with respect to b because it was frozen.
  • optax.masked was not needed.
  • tree_unmask was needed because b was frozen; as a variation, I moved tree_unmask into the model method, e.g. for when you don't want it in the loss function.

So, these two examples show the following:

  • If you use optax.masked, then no need for filter(params)
  • If you don't use optax.masked, then you will need filter(params)

From the zodiax code, I see you are using optax, so I would recommend method 1 for your use case. I personally use method 1 for my own use cases (e.g. min-max, etc.).

The idea is that it would transform the pytree input of the function it is decorating, automatically masking and unmasking such that all leaves not specified by params are 'masked' out of the pytree so that it can be passed to jax transformations.

Masking can be used to hide/freeze a leaf. This is useful for either: 1) passing non-jax-type values through jax transformations, or 2) preventing a leaf's value from being updated. optax can achieve 2) but not 1), so I would recommend using optax for 2) (as in method 1) and masking for 1).
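The two uses of masking can be illustrated with a pure-Python sketch. It assumes a flat dict tree and a hypothetical `Frozen` wrapper that `tree_map` treats as opaque. `Frozen`, `tree_map` and `unmask` are illustrative names only; PyTreeClass's `tree_mask`/`tree_unmask` work on arbitrary pytrees.

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass(frozen=True)
class Frozen:
    """Hypothetical wrapper that hides a value from tree_map (i.e. masks it)."""
    value: Any


def tree_map(fn: Callable[[Any], Any], tree: Dict[str, Any]) -> Dict[str, Any]:
    """Apply `fn` to every leaf, leaving Frozen-wrapped values untouched."""
    return {k: v if isinstance(v, Frozen) else fn(v) for k, v in tree.items()}


def unmask(tree: Dict[str, Any]) -> Dict[str, Any]:
    """Strip the Frozen wrappers to recover the original values."""
    return {k: v.value if isinstance(v, Frozen) else v for k, v in tree.items()}


# Use 1): "name" is a string, which a numeric update could not handle.
# Use 2): "b" is numeric, but we do not want it updated.
tree = {"a": 1.0, "b": Frozen(2.0), "name": Frozen("tree")}

stepped = tree_map(lambda x: x - 0.5, tree)  # toy update step on visible leaves
print(unmask(stepped))  # {'a': 0.5, 'b': 2.0, 'name': 'tree'}
```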


If I were to implement this 'frozen' version, can you speak to how it would integrate with the rest of PyTreeClass? If everything still works, I would consider making something like this.

TreeClass is frozen in the same sense as dataclasses, i.e. attributes cannot be changed mutably. My point was that in frozen-dataclass-based pytrees, leaves are class variables, i.e. they cannot be changed at the instance level.

But you can change these attributes immutably in TreeClass because leaves are instance attributes, like the following:

import pytreeclass as pytc
@pytc.autoinit
class Tree(pytc.TreeClass):
    a: float = 1.0
    def add_one(self):
        self.a += 1.0

tree = Tree()
tree.add_one()  # AttributeError: Cannot set attribute value=2.0 to `key='a'` ...
_, new_tree = tree.at["add_one"]()
print(new_tree)
# Tree(a=2.0)

Note that new_tree and tree have different leaf values but the same methods; more importantly, no in-place mutation happened.
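One way to get that behaviour in plain Python is sketched below. It assumes the functional call runs the method on a shallow copy that is temporarily allowed to mutate. `FrozenBase` and `call_functionally` are hypothetical names; this sketches the idea and is not necessarily how PyTreeClass implements `.at[...]()`.

```python
import copy
from typing import Any, Tuple


class FrozenBase:
    """Hypothetical base: instances reject attribute assignment unless unlocked."""
    _unlocked = False

    def __setattr__(self, name, value):
        if not self._unlocked:
            raise AttributeError(f"Cannot set attribute {name!r} on a frozen instance")
        object.__setattr__(self, name, value)


def call_functionally(obj, method_name: str, *args, **kwargs) -> Tuple[Any, Any]:
    """Run a (possibly mutating) method on a copy; return (result, new_obj)."""
    new_obj = copy.copy(obj)
    object.__setattr__(new_obj, "_unlocked", True)  # allow mutation of the copy only
    try:
        result = getattr(new_obj, method_name)(*args, **kwargs)
    finally:
        object.__delattr__(new_obj, "_unlocked")  # fall back to the class-level False
    return result, new_obj


class Tree(FrozenBase):
    def __init__(self, a: float = 1.0):
        object.__setattr__(self, "a", a)

    def add_one(self):
        self.a += 1.0

    def __repr__(self):
        return f"Tree(a={self.a})"


tree = Tree()
_, new_tree = call_functionally(tree, "add_one")
print(tree, new_tree)  # Tree(a=1.0) Tree(a=2.0)
```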

@ASEM000
Author

ASEM000 commented Sep 6, 2023

Check the new release; the new features are inspired by our discussion here.

@LouisDesdoigts
Owner

Wow, these are some awesome new features, you're a legend! I've done some testing using PyTreeClass on the back-end, but I've been busy with a refactor of a downstream library for some time. I probably won't have time to put towards this in the near future, but there are definitely features I will integrate soon!
