Add support for computing the cumulative sum to the standard #597
Comments
There's also |
Based on previous experience with complaints about bad naming (see, e.g., scipy/scipy#12924 and https://www.reddit.com/r/programminghorror/comments/j6sd61/i_was_just_looking_at_the_documentation_for/) I would very much prefer not to enshrine This is pretty niche functionality and arguably not "core" enough to implementing an array library for it to be in this standard at all, so I'd vote for leaving it out completely. In a compat layer it could perhaps be named |
Niche it may be, I still think
p = jnp.zeros(10, int).at[jnp.array([1, 4, 8])].set(1)
# [0 1 0 0 1 0 0 0 1 0]
s = p.cumsum()
# [0 1 1 1 2 2 2 2 3 3]
This is especially true when one works with a statically-shaped system like JAX, where these tricks are more or less required. Its usages in the implementation of |
@soraros it seems like all that usage of
# For an array `xs` and a boolean index `conds` of the same shape
# example JAX expression for `jax_filter`:
>>> xs = jnp.arange(5)
>>> conds = jnp.array([True, False, True, False, True])
>>> cumsum = jnp.cumsum(conds)
>>> cumsum
Array([1, 1, 2, 2, 3], dtype=int32)
>>> jnp.zeros_like(xs).at[cumsum - 1].add(jnp.where(conds, xs, 0))
Array([0, 2, 4, 0, 0], dtype=int32)
# NumPy:
>>> xs = np.arange(5)
>>> conds = np.array([True, False, True, False, True])
>>> xs[conds]
array([0, 2, 4])
>>> np.zeros_like(xs) + np.where(conds, xs, 0)  # note: not identical to JAX's `.at`; the 0-padding is not at the end
array([0, 0, 2, 0, 4])
That kind of JAX code looks very bad, and as a motivator for writing portable code based on a standard I don't think it's a positive. JAX could add support for There's a known blocker (brought up by the JAX team before) for adding more in-place operator support beyond what they have now, which is that this standard should be able to better guarantee that |
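For comparison, the JAX scatter-add shown above can be mirrored in NumPy with `np.add.at`, which gives the same front-packed, zero-padded result. This is only a minimal sketch; the function name `filter_static_shape` is hypothetical and does not come from the thread.

```python
import numpy as np

def filter_static_shape(xs, conds):
    """Pack the selected elements of `xs` at the front of a same-shaped array."""
    # Target slot for each element: running count of True values, minus one.
    slots = np.cumsum(conds) - 1
    out = np.zeros_like(xs)
    # Unbuffered scatter-add; unselected elements contribute 0 to their slot.
    np.add.at(out, slots, np.where(conds, xs, 0))
    return out

xs = np.arange(5)
conds = np.array([True, False, True, False, True])
packed = filter_static_shape(xs, conds)  # same values as the JAX result above
```

The output keeps the input shape, which is the property the static-shape workaround relies on.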
@rgommers Yes and no. It is a cumbersome workaround, but I also think the problem is more fundamental than that. JAX is essentially a front-end for XLA, and the primitives provided by XLA (for now) require static shapes. So the line that actually goes wrong is
>>> xs[ix_bool]
array([0, 2, 4])
Note that this code does work in JAX, though it is not jittable, because we don't know its output shape. Let's pretend So what we really work around is the static shape requirement (recall the need of a Now, for something a bit off-topic. I think the JAX style functional syntax
a = zeros(m)  # initialising a
a[I] += arange(n)  # semantically, still initialising a
# VS
# being concise here is not the important point
# this line becomes a "semantic block" for initialisation
a = zeros(m).at[I].add(arange(n))  # initialising a
# I think these are fairly cumbersome to represent in `numpy`, as we don't have kwargs for __getitem__
b = a.at[I].add(val, unique_indices=True)  # important info for accelerators
c = b.at[J].get(mode='fill', fill_value=nan)  # sure, we have `take`, but this is uniform and cool |
@soraros thanks for all this detail, it's very interesting actually. I think there's something to be said indeed for |
@rgommers Glad you find the exchange interesting! |
How do you implement cumulative sum using only array API functions (and without using a Python loop)? |
One inefficient possibility:
|
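One loop-free but inefficient way to do this, offered as a sketch and not necessarily the approach this reply had in mind: multiply by a lower-triangular matrix of ones. It costs O(n²) time and memory for a 1-D input of length n. NumPy stands in here for an array API namespace, and `cumsum_via_matmul` is a hypothetical name.

```python
import numpy as np

def cumsum_via_matmul(x, xp=np):
    """Cumulative sum of a 1-D array using only `ones`, `tril` and `matmul`."""
    n = x.shape[0]
    # Row i of `lower` has ones in columns 0..i, so (lower @ x)[i] == sum(x[:i + 1]).
    lower = xp.tril(xp.ones((n, n), dtype=x.dtype))
    return xp.matmul(lower, x)

result = cumsum_via_matmul(np.asarray([1, 2, 3, 4]))
```

The quadratic cost is what makes this a poor substitute for a dedicated cumulative sum function.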
Done now in gh-609 - sorry for the delay! I spent a lot of time refreshing my memory and trying to put together something more coherent. But it's tricky; the need for new API in JAX to avoid |
I think cumulative sum is a rather fundamental array operation and we should include it in the standard. None of the typical reasons for omitting a function from the API standard apply here:
To give a few other examples of how I've used it:
As a reference point on popularity, I see about twice as many uses of I'll also second Ralf's point on calling it |
Before locking in the specification, it would be worthwhile taking a look at numpy/numpy#6044, and numpy/numpy#14542 that is referenced from numpy/numpy#6044. My comment in numpy/numpy#14542 provides some evidence for the usefulness of allowing the result to include a prepended 0. (More generally, for other cumulative operations such as |
I agree, starting with zero (and excluding the last value) is quite useful, and I would definitely support adding an optional argument. Is it clear that this is a better default behavior? My inclination is that it would not be worth the trouble to break existing code. |
I agree, starting with zero is not a better default behavior, if only because of the long history of the existing behavior in numpy. The links to those previous discussions are more for increasing awareness of the interest and usefulness of this option. That is, don't lock in an API that can't be enhanced with an option later. I don't think that is a problem here, but I don't know how hard it will be to add options to a function in the array API specification later. Also, maybe the reminder will inspire someone to push forward with the enhancement in numpy 🤞 or in other libraries. |
Good points on the initial value feature. That doesn't look all that hard to push forward in numpy, there doesn't seem to be a blocker other than no one having done the work. |
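To make the variants discussed above concrete, here is a small NumPy sketch of the current (inclusive) result, the result with a prepended 0, and the "start with zero, exclude the last value" form. The variable names are illustrative only; no existing `cumsum` option is implied.

```python
import numpy as np

x = np.array([3, 1, 4, 1, 5])

inclusive = np.cumsum(x)                      # current behaviour, length n
with_zero = np.concatenate(([0], inclusive))  # prepended 0, length n + 1
exclusive = with_zero[:-1]                    # starts at 0, drops the final total

# With the leading zero, the sum of x[i:j] is a plain difference.
i, j = 1, 4
segment_sum = with_zero[j] - with_zero[i]
```

The difference form is one reason the prepended-zero result is convenient: no special case is needed when `i` is 0.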
Just opinion:
|
Looks like there's enough thumbs-ups for cumulative sum at least, and we should add it now. @kgryte just volunteered to open a PR for it. |
On the PR there is currently |
On the NumPy issue numpy/numpy#6044 |
I've updated the PR to use |
This RFC proposes adding a new API to the array API specification for computing the cumulative sum.
Overview
Based on array comparison data, the API is available in all the libraries in the PyData ecosystem.
Prior art
Proposal:
The dtype kwarg is for consistency with sum et al.
cc @oleksandr-pavlyk
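As an illustration of that consistency, the sketch below uses NumPy's existing `cumsum` and `sum`, both of which accept `axis` and `dtype`. It is not the proposed signature, which may differ in name and defaults.

```python
import numpy as np

x = np.array([[1, 2, 3],
              [4, 5, 6]], dtype=np.int32)

# Accumulate along each row in float64 instead of the input dtype.
running = np.cumsum(x, axis=1, dtype=np.float64)
# The matching reduction takes the same keyword.
total = np.sum(x, axis=1, dtype=np.float64)
```

The last element of each running sum equals the corresponding total, and both results carry the requested dtype.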