jax.lax.scan doesn't work with namedtuples as input argument #3826
Comments
Updated description. Let me know if more clarity is needed.
Thanks for the question! Can you provide a minimal reproducer of the issue?
JAX doesn't support working with object arrays, and an array of namedtuples would be an object array (i.e. ndarray with dtype object). Also, I'm guessing that things are working as intended, except we should raise a much clearer error message here.
The last two sentences answer my question. It appears this is not a bug. The model I am optimizing has 13 parameters, 2 of which are free. I originally wrote the code in Ceres, making each residual a class instance, which is a common pattern in Ceres. However, this pattern is essentially an array of structures (AoS), which doesn't seem to be well supported by JAX. In all of the examples I have seen, the data has to be a structure of arrays (SoA) because JAX operates on arrays.
Yes, that's right, JAX really only works with a structure-of-arrays approach.
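For readers with the same array-of-structures habit, here is a minimal sketch of the structure-of-arrays layout with jax.lax.scan. The Residual fields (observed, weight), the toy cost in step, and the sample values are all hypothetical and not taken from the original model; the point is only that the namedtuple holds stacked arrays rather than being repeated in a list.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Residual(NamedTuple):
    # Hypothetical fields, chosen only for illustration.
    observed: jnp.ndarray
    weight: jnp.ndarray


# Array of structures: a Python list of namedtuples (not scannable as-is).
aos = [Residual(1.0, 1.0), Residual(2.0, 0.5), Residual(3.0, 2.0)]

# Structure of arrays: one Residual whose fields each have shape (3,).
soa = jax.tree_util.tree_map(lambda *xs: jnp.array(xs), *aos)


def step(total, r):
    # `r` is a Residual holding the i-th slice of every field.
    sq = (r.weight * r.observed) ** 2  # toy weighted squared term
    return total + sq, sq


# scan slices every leaf of `soa` along its leading axis.
total, per_item = jax.lax.scan(step, jnp.zeros(()), soa)
```

With this layout, scan iterates three times, once per original residual, and the body still gets to use named fields.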
where residuals is an array of namedtuples. This gives a "builtins.IndexError: tuple index out of range" error in lax_control_flow.py:1173 (lengths = [x.shape[0] for x in xs_flat]) because residuals is flattened before being iterated over in the scan. For example, given a namedtuple, Residual, which has 3 items, and x, which is [Residual(p1, p2), Residual(p3,4)], the flatten will give a list that is 6 elements long instead of 2.
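The flattening described here can be reproduced in isolation with JAX's pytree utilities. This is a small sketch, assuming a three-field Residual with made-up field names (a, b, c) and made-up values; it does not call scan, and the exact error raised by scan may differ between JAX versions.

```python
from collections import namedtuple

import jax
import jax.numpy as jnp

# Hypothetical three-field namedtuple standing in for the real Residual.
Residual = namedtuple("Residual", ["a", "b", "c"])

xs = [Residual(1.0, 2.0, 3.0), Residual(4.0, 5.0, 6.0)]

# Lists and namedtuples are both pytree containers, so flattening
# descends through them and returns the scalar leaves.
leaves = jax.tree_util.tree_leaves(xs)
num_leaves = len(leaves)

# Each leaf is a scalar with shape (), so indexing shape[0] has nothing to read.
leaf_shape = jnp.shape(leaves[0])
```

The list of two namedtuples becomes six scalar leaves, and a scalar has no leading axis for scan to iterate over, which is consistent with the "tuple index out of range" report.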
Documentation (https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.scan.html) shows pseudocode doing "for x in xs", but xs is flattened before the iteration begins.