jax.lax.scan doesnt work with namedtuples as input argument #3826

Closed
isaacgerg opened this issue Jul 23, 2020 · 4 comments
Comments

@isaacgerg

isaacgerg commented Jul 23, 2020

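    import jax
    import jax.numpy as jnp
    def fun(carry, x):
        # residual_function, param1 and param2 come from the reporter's own code (not shown)
        y = jnp.abs(residual_function(x, param1, param2))
        return carry, y
    # residuals is a Python list of namedtuples (an array-of-structures layout)
    error = jax.lax.scan(fun, 0, residuals)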

where residuals is an array of namedtuples. This raises "builtins.IndexError: tuple index out of range" in lax_control_flow.py:1173 (lengths = [x.shape[0] for x in xs_flat]), because residuals is flattened before it is "iterated" over in the scan. The flattened leaves are the individual fields, which here are scalars whose shape is the empty tuple (), so x.shape[0] is out of range. For example, given a namedtuple Residual with 3 fields and residuals = [Residual(p1, p2, p3), Residual(p4, p5, p6)], flattening gives a list that is 6 elements long instead of 2.
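A minimal sketch of the flattening described above, assuming a hypothetical three-field namedtuple `Residual` filled with made-up float values. `jax.tree_util.tree_flatten` treats both the list and each namedtuple as pytree nodes, so two residuals yield six scalar leaves rather than two "rows":

```python
from collections import namedtuple

import jax

# Hypothetical 3-field namedtuple mirroring the Residual in the report.
Residual = namedtuple("Residual", ["a", "b", "c"])

# Array-of-structures input: a Python list of namedtuples.
residuals = [Residual(1.0, 2.0, 3.0), Residual(4.0, 5.0, 6.0)]

# Flattening descends into the list AND into every namedtuple field.
leaves, treedef = jax.tree_util.tree_flatten(residuals)
print(len(leaves))  # 6
print(leaves)       # [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
```

Each leaf here is a plain scalar with shape `()`, which is consistent with `x.shape[0]` failing with `IndexError: tuple index out of range`.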

The documentation (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) shows pseudocode doing "for x in xs", but xs is flattened before the iteration begins.

@isaacgerg isaacgerg changed the title jax.lax.scan doesnt work with named tuples Delete this. Jul 23, 2020
@isaacgerg isaacgerg changed the title Delete this. jax.lax.scan doesnt work with named tuples Jul 23, 2020
@isaacgerg isaacgerg reopened this Jul 23, 2020
@isaacgerg
Author

Updated description. Let me know if more clarity is needed.

@isaacgerg isaacgerg changed the title jax.lax.scan doesnt work with named tuples jax.lax.scan doesnt work with namedtuples as input argument Jul 23, 2020
@mattjj
Collaborator

mattjj commented Jul 23, 2020

Thanks for the question! Can you provide a minimal reproducer of the issue?

where residuals is an array of namedtuple

JAX doesn't support working with object arrays, and an array of namedtuples would be an object array (i.e. ndarray with dtype object).

Also, lax.scan scans over leading axes of arrays, not of lists. That is, in the "for x in xs" pseudocode, it's important that xs is an array!
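A short sketch of what "scans over leading axes of arrays" means in practice. The step functions and values below are illustrative only. The first call scans a single array; the second passes a Python list of two arrays, which scan treats as a pytree (not as a sequence to loop over) and slices along the leading axis of each array in lockstep:

```python
import jax
import jax.numpy as jnp

def step(carry, x):
    # x is one slice along the leading axis of xs (a scalar here).
    carry = carry + x
    return carry, carry

xs = jnp.arange(4.0)  # shape (4,), so scan runs 4 steps
total, running = jax.lax.scan(step, jnp.zeros(()), xs)
# total -> 6.0, running -> [0., 1., 3., 6.]

def pair_step(carry, x):
    # x keeps the list structure of xs: [a_i, b_i], both scalars.
    a, b = x
    return carry + a * b, a * b

# Both arrays must share the same leading length (4 here).
dot, prods = jax.lax.scan(pair_step, jnp.zeros(()), [jnp.arange(4.0), jnp.ones(4)])
# dot -> 6.0, prods -> [0., 1., 2., 3.]
```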

I'm guessing that things are working as intended, except we should raise a much clearer error message here.

@isaacgerg
Author

The last two sentences answer my question. It appears this is not a bug.

The model I am optimizing has 13 parameters, 2 of which are free. I originally wrote the code in Ceres, making each residual a class instance, which is a common pattern in Ceres. However, this pattern is essentially an array of structures (AoS), which doesn't seem to be well supported by JAX. In all of the examples I have seen, the data has to be a structure of arrays (SoA) because JAX operates on arrays.
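One common way to convert an AoS list into the SoA layout is sketched below, assuming a hypothetical `Residual` namedtuple whose field names and values are made up for illustration. `jax.tree_util.tree_map` walks all the namedtuples in parallel and stacks matching fields, producing a single `Residual` whose fields are arrays with a leading axis of length N:

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

# Hypothetical residual record; field names are illustrative only.
Residual = namedtuple("Residual", ["t", "y", "w"])

# Array-of-structures: one Python object per residual (Ceres-style).
aos = [Residual(0.0, 1.0, 1.0), Residual(1.0, 2.5, 0.5), Residual(2.0, 3.75, 1.0)]

# Structure-of-arrays: tree_map calls the lambda once per field, passing
# that field from every record, and jnp.stack builds a length-N array.
soa = jax.tree_util.tree_map(lambda *fields: jnp.stack(fields), *aos)

print(type(soa).__name__)  # Residual
print(soa.t.shape)         # (3,)
```

Because the result is still a `Residual`, downstream code can keep using field names while JAX sees plain arrays.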

@hawkinsp
Collaborator

Yes, that's right, JAX really only works with a structures-of-arrays approach.
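A minimal sketch of the reporter's scan rewritten for the SoA layout. `Residual`, its fields, and `residual_function` (a weighted line-model residual with made-up slope and offset) are hypothetical stand-ins for the reporter's code. A namedtuple of arrays is a valid `xs`, and at each step `fun` receives a `Residual` of scalars, one slice of every field:

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

Residual = namedtuple("Residual", ["t", "y", "w"])  # hypothetical fields

# SoA input: each field holds all N residuals along its leading axis.
residuals = Residual(
    t=jnp.array([0.0, 1.0, 2.0]),
    y=jnp.array([1.0, 2.5, 4.0]),
    w=jnp.array([1.0, 0.5, 1.0]),
)

def residual_function(r, slope, offset):
    # Hypothetical stand-in: weighted residual of the model y = slope * t + offset.
    return r.w * (r.y - (slope * r.t + offset))

def fun(carry, r):
    # r is a Residual of scalars: the i-th element of every field.
    return carry, jnp.abs(residual_function(r, 1.0, 1.0))

_, errors = jax.lax.scan(fun, 0, residuals)  # errors -> [0., 0.25, 1.]

# The carry is unused, so vmap over the same SoA input gives the same values.
errors_v = jax.vmap(lambda r: jnp.abs(residual_function(r, 1.0, 1.0)))(residuals)
```

Since no state flows between steps in this example, `jax.vmap` is arguably the more natural choice; `lax.scan` is only needed when the carry actually accumulates something.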
