
namedtuple support in arguments to transformed functions #446

Closed
jonasrauber opened this issue Feb 25, 2019 · 6 comments · Fixed by #736
Labels
enhancement New feature or request

Comments

@jonasrauber
Contributor

It would be great if xla.abstractify also accepted namedtuples. Loop states can consist of quite a lot of values, and organizing them in a namedtuple rather than a plain tuple would make things nicer.

@mattjj mattjj added the enhancement New feature or request label Feb 25, 2019
@mattjj
Collaborator

mattjj commented Feb 25, 2019

There's actually a convenient way to add support for custom container types throughout JAX, not just in loop carries but also for grad, jit, vmap, etc., all at once. Of course, it's not documented at all... :)

You can register a custom type as a "pytree" (tree-like Python container) like this:

```python
from collections import namedtuple
from jax.tree_util import register_pytree_node
from jax import grad, jit
import jax.numpy as np

Point = namedtuple("Point", ["x", "y"])

register_pytree_node(
    Point,
    lambda xs: (tuple(xs), None),  # tell JAX how to unpack to an iterable
    lambda _, xs: Point(*xs)       # tell JAX how to pack back into a Point
)


def f(pt):
  return np.sqrt(pt.x**2 + pt.y**2)

pt = Point(1., 2.)

print(f(pt))        # 2.236068
print(grad(f)(pt))  # Point(x=..., y=...)

g = jit(f)
print(g(pt))  # 2.236068
```

So that's an easy and general way to get your code working now. It also means you can have your namedtuple classes contain nested tuples/lists/dicts, or have them nested in other tuples/lists/dicts.
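As a minimal sketch of the loop-carry case from the original request, the block below uses the same registration pattern as the comment above and threads a `Point` nested inside a dict (alongside a plain tuple) through `lax.fori_loop`. The `body` function, the `"pt"`/`"stats"` keys, and the update rule are hypothetical, chosen only to show that the carry keeps its structure.

```python
from collections import namedtuple

import jax.numpy as jnp
from jax import lax
from jax.tree_util import register_pytree_node

Point = namedtuple("Point", ["x", "y"])

# Same registration as above: flatten to a tuple of children, no aux data.
register_pytree_node(
    Point,
    lambda p: (tuple(p), None),
    lambda _, xs: Point(*xs),
)

def body(i, state):
    # Hypothetical loop body. The carry is a dict holding a Point
    # and a plain (count, total) tuple; any such nesting is allowed.
    pt = state["pt"]
    count, total = state["stats"]
    new_pt = Point(pt.x + 1.0, pt.y * 2.0)
    return {"pt": new_pt, "stats": (count + 1, total + new_pt.x)}

init = {
    "pt": Point(jnp.float32(0.0), jnp.float32(1.0)),
    "stats": (jnp.int32(0), jnp.float32(0.0)),
}
final = lax.fori_loop(0, 3, body, init)
print(final["pt"])     # still a Point: x=3.0, y=8.0
print(final["stats"])  # count=3, total=1+2+3=6.0
```

The carry must have the same structure and dtypes on every iteration, which is why the initial values are created with explicit `float32`/`int32` dtypes.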

(By the way, the extra data that the to-iterable function can return, and that the to-pytree function receives back, is for things like dict keys. In the above example, we just return None when mapping to an iterable and then ignore it when reconstructing.)
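To illustrate the dict-keys use of that extra ("aux") data, here is a minimal sketch with a hypothetical `Record` container (not part of JAX) that stores named fields in a dict. The flatten function returns the field values as children and the sorted key names as aux data; the unflatten function uses those keys to rebuild the container.

```python
import jax.numpy as jnp
from jax import grad
from jax.tree_util import register_pytree_node, tree_flatten, tree_unflatten

class Record:
    """Hypothetical dict-backed container, used only to illustrate aux data."""
    def __init__(self, **fields):
        self.fields = dict(fields)

    def __repr__(self):
        return f"Record({self.fields})"

def record_flatten(rec):
    keys = tuple(sorted(rec.fields))             # aux data: the field names
    children = tuple(rec.fields[k] for k in keys)
    return children, keys

def record_unflatten(keys, children):
    # `keys` is exactly the aux data returned by record_flatten.
    return Record(**dict(zip(keys, children)))

register_pytree_node(Record, record_flatten, record_unflatten)

leaves, treedef = tree_flatten(Record(b=2.0, a=1.0))
print(leaves)    # [1.0, 2.0], ordered by sorted key

rebuilt = tree_unflatten(treedef, [10.0, 20.0])
print(rebuilt)   # Record({'a': 10.0, 'b': 20.0})

def loss(rec):
    return rec.fields["a"] * rec.fields["b"]

g = grad(loss)(Record(a=3.0, b=4.0))  # a Record with a=4.0, b=3.0
```

Sorting the keys keeps the leaf order deterministic regardless of insertion order, and the key tuple is hashable, which matters because the aux data becomes part of the tree structure that `jit` uses as a cache key.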

However, we should consider making JAX work with all namedtuple classes by default, without having to register them. Any thoughts on that, or objections to it?

@mattjj mattjj self-assigned this Feb 25, 2019
@mattjj mattjj changed the title from "namedtuple support in xla.abstractify" to "namedtuple support in arguments to transformed functions" Feb 25, 2019
@mattjj
Collaborator

mattjj commented Feb 25, 2019

I revised the issue title because we'd handle the issue in api.py and xla.abstractify would never need to see these types (just like it never sees tuples/lists/dicts).

@jonasrauber
Contributor Author

Ha, that's awesome! Regarding namedtuple support: given that namedtuples are real subclasses of tuple, I think supporting all namedtuples out of the box would be the most intuitive solution.

@rsepassi
Contributor

rsepassi commented Mar 8, 2019

+1 to having JAX work with all namedtuple classes

@zhongwen
Contributor

+1. Our existing codebase relies heavily on namedtuples, and it would be great to support them in JAX.

@mattjj
Collaborator

mattjj commented May 20, 2019

#736 made namedtuple classes transparent by default. Let us know if you have any issues with it!
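A minimal sketch of what "transparent by default" means in practice, assuming a JAX release that includes #736: a namedtuple class is passed to `jit`, `grad`, and `vmap` with no `register_pytree_node` call. The `State` class and the `update`/`energy` functions are hypothetical examples.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

# No register_pytree_node call: assumes a JAX release that includes #736,
# where namedtuple classes are treated as pytrees automatically.
State = namedtuple("State", ["position", "velocity"])

@jax.jit
def update(s, dt):
    # Returns a new State; the namedtuple type survives jit.
    return State(s.position + dt * s.velocity, s.velocity)

s = State(jnp.array([0.0, 1.0]), jnp.array([1.0, 1.0]))
s2 = update(s, 0.5)
print(s2)  # a State with position [0.5, 1.5], velocity unchanged

# Flattening sees each field as a separate leaf.
print(len(jax.tree_util.tree_leaves(s)))  # 2

# grad packs the gradients into the same namedtuple type.
def energy(s):
    return 0.5 * jnp.sum(s.velocity ** 2)

g = jax.grad(energy)(s)  # g.position is zeros, g.velocity equals s.velocity

# vmap maps over the leading axis of every field at once.
batch = State(jnp.zeros((4, 2)), jnp.ones((4, 2)))
out = jax.vmap(update, in_axes=(0, None))(batch, 0.5)
print(out.position.shape)  # (4, 2)
```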
