namedtuple support in arguments to transformed functions #446
Comments
There's actually a convenient way to add support for custom container types throughout JAX, not just in loop carries but also for arguments to transformed functions like grad and jit. You can register a custom type as a "pytree" (tree-like Python container) like this:

from collections import namedtuple
from jax.tree_util import register_pytree_node
from jax import grad, jit
import jax.numpy as np

Point = namedtuple("Point", ["x", "y"])

register_pytree_node(
    Point,
    lambda xs: (tuple(xs), None),  # tell JAX how to unpack to an iterable
    lambda _, xs: Point(*xs)       # tell JAX how to pack back into a Point
)

def f(pt):
    return np.sqrt(pt.x**2 + pt.y**2)

pt = Point(1., 2.)
print(f(pt))        # 2.236068
print(grad(f)(pt))  # Point(x=..., y=...)

g = jit(f)
print(g(pt))        # 2.236068

So that's an easy and general way to get your code working now. It also means you can have your namedtuple classes contain nested tuples/lists/dicts, or have them nested in other tuples/lists/dicts. (By the way, the extra data that can be returned by the to-iterable function and consumed by the to-pytree function is for things like dict keys. In the above example, we're just returning None when mapping to an iterable and then ignoring it when reconstructing.)

However, we should consider making JAX work with all namedtuple classes by default, without having to register them. Any thoughts on that, or objections to it?
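To illustrate the extra (auxiliary) data mentioned above, here is a minimal sketch of a container that carries a non-array field alongside its array contents. The `Labeled` class, its `label` field, and the `loss` function are hypothetical names made up for this example; only `register_pytree_node`, `tree_flatten`, and `jax.grad` are JAX APIs. The label is returned as auxiliary data when flattening and handed back when rebuilding, so JAX never tries to differentiate or trace it.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node, tree_flatten

class Labeled:
    """Hypothetical container: one array-like value plus a static string label."""
    def __init__(self, value, label):
        self.value = value
        self.label = label

register_pytree_node(
    Labeled,
    # Flatten: the children are the differentiable parts, the label is aux data.
    lambda obj: ((obj.value,), obj.label),
    # Unflatten: receives the aux data first, then the children.
    lambda label, children: Labeled(children[0], label),
)

def loss(params):
    # The registered type is nested inside a plain dict here.
    return jnp.sum(params["w"].value ** 2)

params = {"w": Labeled(jnp.array([1.0, 2.0]), "weights")}
grads = jax.grad(loss)(params)

# Only the array is a leaf; the label travels with the tree structure.
leaves, treedef = tree_flatten(params)
```

The gradient comes back with the same structure as the input, so `grads["w"]` is again a `Labeled` whose label was restored from the auxiliary data.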
I revised the issue title because we'd handle the issue in api.py and xla.abstractify would never need to see these types (just like it never sees tuples/lists/dicts).
Ha, that's awesome! Regarding namedtuple support: Given that namedtuples are real subclasses of tuple, I think supporting all namedtuples out of the box would be the most intuitive solution.
+1 to having JAX work with all namedtuple classes
+1 Our existing codebase relies heavily on namedtuple and it would be great to support it in JAX.
#736 made namedtuple classes transparent by default. Let us know if you have any issues with it!
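As a quick check of that behavior, the sketch below passes an unregistered namedtuple straight into `jit` and `grad`. It assumes a JAX version that includes the change described above; `Point` and `norm` are illustrative names only.

```python
from collections import namedtuple
import jax
import jax.numpy as jnp

Point = namedtuple("Point", ["x", "y"])  # no register_pytree_node call

def norm(pt):
    return jnp.sqrt(pt.x ** 2 + pt.y ** 2)

pt = Point(3.0, 4.0)
value = jax.jit(norm)(pt)   # the namedtuple is flattened like a tuple
grads = jax.grad(norm)(pt)  # gradients come back as a Point
```

The gradient has the same namedtuple type as the input, so fields can still be read by name (`grads.x`, `grads.y`).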
It would be great if xla.abstractify would also accept namedtuples. Loop states can consist of quite a lot of values, and organizing them in a namedtuple rather than a tuple would make things nicer.
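For reference, the kind of loop state this request has in mind might look like the sketch below. It assumes a JAX version in which namedtuples are treated as pytrees, as discussed in the comments; `LoopState`, `cond`, and `body` are hypothetical names for this example.

```python
from collections import namedtuple
import jax.numpy as jnp
from jax import lax

LoopState = namedtuple("LoopState", ["step", "total"])

def cond(state):
    # Keep looping while fewer than five steps have run.
    return state.step < 5

def body(state):
    # Fields are accessed by name instead of by tuple position.
    return LoopState(step=state.step + 1, total=state.total + state.step)

init = LoopState(step=jnp.array(0), total=jnp.array(0))
final = lax.while_loop(cond, body, init)
```

The carry returned by `lax.while_loop` has the same structure as `init`, so `final.step` and `final.total` are available by name.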