
Proposal: add APIs for getting and setting elements via a list of indices (i.e., take, put, etc) #177

Open
kgryte opened this issue May 6, 2021 · 35 comments
Labels
API extension Adds new functions or objects to the API. topic: Dynamic Shapes Data-dependent shapes.

Comments

@kgryte
Contributor

kgryte commented May 6, 2021

Proposal

Add APIs for getting and setting elements via a list of indices.

Motivation

Currently, the array API specification does not provide a direct means of extracting and setting a list of elements along an axis. Such operations are relatively common in NumPy usage either via "fancy indexing" or via explicit take and put APIs.

Two main arguments come to mind for supporting at least basic take and put APIs:

  1. Indexing does not currently support providing a list of indices to index into an array. The principal reason for not supporting fancy indexing stems from dynamic shapes and compatibility with accelerator libraries. However, use of fancy indexing is relatively common in NumPy and similar libraries where dynamically extracting rows/cols/values is possible and can be readily implemented.

  2. Currently, the output of a subset of APIs included in the standard cannot be readily consumed without manual workarounds if a specification-conforming library implemented only the APIs in the standard. For example,

    • argsort returns an array of indices. In NumPy, the output of this function can be consumed by put_along_axis and take_along_axis.
    • unique can return an array of indices if return_index is True.

Background

The following table summarizes library implementations of such APIs:

| op | NumPy | CuPy | Dask | MXNet | Torch | TensorFlow |
| --- | --- | --- | --- | --- | --- | --- |
| extracting elements along axis | take | take | take | take | take/gather | gather/numpy.take |
| setting elements along axis | put | put | -- | -- | scatter | scatter_nd/tensor_scatter_nd_update |
| extracting elements over matching 1d slices | take_along_axis | take_along_axis | -- | -- | -- | gather_nd/numpy.take_along_axis |
| setting elements over matching 1d slices | put_along_axis | -- | -- | -- | -- | -- |

While most libraries implement some form of take, fewer implement other complementary APIs.
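To make the table's row labels concrete, here is a minimal NumPy sketch of the two basic operations being compared (NumPy-specific behavior; other libraries differ in naming and in in-place semantics):

```python
import numpy as np

a = np.array([[1, 2],
              [3, 4],
              [5, 6]])

# take with an axis argument selects whole rows (or columns) by index
rows = np.take(a, np.array([0, 2]), axis=0)
# rows is [[1, 2], [5, 6]]

# put, by contrast, operates in place on the *flattened* array
np.put(a, [0], 99)
# a[0, 0] is now 99
```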

@rgommers rgommers added the API extension Adds new functions or objects to the API. label May 6, 2021
@rgommers
Member

rgommers commented May 6, 2021

Thanks @kgryte. A few initial thoughts:

  • There is also overlap between take/put and various scatter/gather functions in TensorFlow, PyTorch and MXNet. There's a whole host of those functions.
  • Is there really an issue with shape determinism? I'm probably missing something here, but isn't the output size along the given dimension equal to indices.size? And put is an inplace operation which doesn't change the shape.
  • If we do want to add these, we may consider putting them in the second version of the API. Just thinking that we should at some point stop making the API a permanently moving target.

@kgryte
Contributor Author

kgryte commented May 6, 2021

@rgommers Thanks for the comments.

  1. Correct. I've updated the table with Torch and TF scatter and gather methods.
  2. Correct me if I am wrong, but indices.size need not be fixed and could be data-dependent. For example, if we extract the indices of unique elements from an array, the number of indices cannot necessarily be known AOT.
  3. Not opposed to delaying until V2 (2022).

@asmeurer
Member

asmeurer commented May 6, 2021

A natural question: if take is supported, is there any reason the equivalent indexing shouldn't also be supported? Granted, take represents only a specific subset of general (NumPy) integer array indexing, where indexing is done on a single axis.

@kgryte
Contributor Author

kgryte commented May 6, 2021

@asmeurer I think that take would be an optional API; whereas indexing semantics should be universal.

@rgommers
Member

rgommers commented May 7, 2021

Correct me if I am wrong, but indices.size need not be fixed and could be data-dependent. For example, if we extract the indices of unique elements from an array, the number of indices cannot necessarily be known AOT.

If the size of indices is variable, it's the function that produces indices that is data-dependent; take itself is not. Compare with boolean indexing or nonzero, where the output size is in the range [0, x_input.size]; for take it's always indices.size.
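A minimal NumPy illustration of this distinction (NumPy used here only for concreteness):

```python
import numpy as np

x = np.arange(5)
indices = np.array([1, 3])

# take: the output shape is fully determined by indices.shape,
# regardless of the values stored in x
taken = np.take(x, indices)

# boolean indexing: the output size is data-dependent and can be
# anything in [0, x.size]
masked = x[x > 2]
```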

@kgryte
Contributor Author

kgryte commented May 11, 2021

@rgommers Correct; however, I could still imagine that data flows involving a take operation may be problematic for AOT computational graphs. While the output size is indices.size, an array library may not be able to statically allocate memory for the output of the take operation. That said, accelerator libraries do manage to support similar APIs (e.g., scatter/gather), so probably no need to belabor this further.

@kgryte
Contributor Author

kgryte commented May 11, 2021

@asmeurer Re: integer array indexing. As mentioned during the previous call (03/06/2021), similar to boolean array indexing, we could support a limited form of integer array indexing, where the integer array index is the sole index. Meaning, the spec would not condone mixing boolean with integer indices or require broadcasting semantics among the various indices.

@kgryte
Contributor Author

kgryte commented May 11, 2021

Cross-linking to a discussion regarding issues concerning out-of-bounds access in take APIs for accelerator libraries.

@kgryte kgryte added this to the v2022 milestone Oct 4, 2021
@thomasjpfan

thomasjpfan commented Dec 7, 2021

In the ML use case, it is common to want to sample with replacement or shuffle a dataset. This is commonly done by sampling an integer array and using it to subset the dataset:

import numpy.array_api as xp

X = xp.asarray([[1, 2, 3, 4], [2, 3, 4, 5],
                [4, 5, 6, 10], [5, 6, 8, 20]], dtype=xp.float64)

sample_indices = xp.asarray([0, 0, 1, 3])

# Does not work
# X[sample_indices, :]

For libraries that need selection with integer arrays, a work around is to implement take:

def take(X, indices, *, axis):
    # Simple implementation that only works for axis in {0, 1}
    if axis == 0:
        selected = [X[i] for i in indices]
    else:  # axis == 1
        selected = [X[:, i] for i in indices]
    return xp.stack(selected, axis=axis)

take(X, sample_indices, axis=0)

Note that sampling with replacement can not be done with a boolean mask, because some rows may be selected twice.
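For instance, in NumPy (used here purely for illustration):

```python
import numpy as np

X = np.array([[1, 2], [3, 4], [5, 6]])

# integer indices may repeat, so row 0 appears twice in the sample
resampled = X[np.array([0, 0, 2])]
# -> [[1, 2], [1, 2], [5, 6]]

# a boolean mask can include each row at most once, so it cannot
# express sampling with replacement
mask = np.array([True, False, True])
subset = X[mask]
# -> [[1, 2], [5, 6]]
```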

@leofang
Contributor

leofang commented Feb 1, 2022

Hi @kmaehashi @asi1024 @emcastillo FYI. In a recent array API call we discussed the proposed take/put APIs, and there were questions regarding how CuPy currently implements these functions, as there could be data/value dependency and people were wondering if we just have to pay the synchronization cost to ensure the behavior is correct. Could you help address this? Thanks! (And sorry I dropped the ball here...)

@shoyer
Contributor

shoyer commented Mar 10, 2022

@asmeurer Re: integer array indexing. As mentioned during the previous call (03/06/2021), similar to boolean array indexing, we could support a limited form of integer array indexing, where the integer array index is the sole index. Meaning, the spec would not condone mixing boolean with integer indices or require broadcasting semantics among the various indices.

+1 I think "array only" integer indexing would be quite well defined, and would not be problematic for accelerators. The main challenge with NumPy's implementation of "advanced indexing" is handling mixed integer/slice/boolean cases.

@rgommers
Member

Here is a summary of today's discussion:

  • Implementing take is fine, there's no problem for accelerators and all libraries listed above already have this API. Given that they all have it, there's no problem adding take to the standard right now.
  • The __getitem__ part of indexing is equivalent to take. However, as @asmeurer pointed out, it would be odd to add support for integer array indexing in __getitem__ but not in __setitem__. Hence we need to look at the latter.
  • put and __setitem__ are also equivalent - and more problematic, for multiple reasons:
    • as the table in the issue description shows, put isn't widely supported across libraries, and not with the same name either.
    • put is explicitly an in-place function in NumPy et al., which is a problem for JAX/TensorFlow. Having a better handle on the topic of mutability looks like a hard requirement before even considering an in-place function like put.
    • @oleksandr-pavlyk suggested adding a new out-of-place version of put to the standard. However, that's a new function that libraries don't yet have (actually some do, under names like index_put, but it's a mixed bag). And it's not clear that this would be preferred in the long term; an in-place put that is guaranteed to raise when it crosses paths with a view may be better.

Given all that, the proposal is to only add take now, and revisit integer array indexing and put in the future.

@asmeurer
Member

Something that I think was missed in today's discussion is that take and put aren't exactly the same as integer array indexing. Integer array indices operate on the axes of the array. take and put (at least in NumPy) operate on the flattened array.

>>> a = np.arange(9).reshape((3, 3)) + 10
>>> a[np.array([0, 2]), np.array([1, 2])]
array([11, 18])
>>> np.ravel_multi_index((np.array([0, 2]), np.array([1, 2])), (3, 3))
array([1, 8])
>>> np.take(a, np.ravel_multi_index((np.array([0, 2]), np.array([1, 2])), (3, 3)))
array([11, 18])

np.take also has an axis parameter but that's only equivalent to a single integer array index.

I'm not sure if there's an easy way within the array API to go from one to the other.

And I hope the "integer array as the sole index" idea above was really meant to be "integer arrays as the sole indices". Having just a single integer array index means you can only index the first dimension of the array. This should also include integer scalars, as those are equivalent to 0-D arrays, unless we want to omit the "all integer array indices are broadcast together" rule.
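For reference, the broadcasting rule in question, illustrated with NumPy:

```python
import numpy as np

a = np.arange(12).reshape(3, 4)

rows = np.array([[0], [2]])   # shape (2, 1)
cols = np.array([1, 3])       # shape (2,)

# rows and cols broadcast together to a common shape (2, 2),
# which is also the shape of the result
out = a[rows, cols]
# -> [[1, 3], [9, 11]]
```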

I agree that NumPy's rules for mixing arrays with slices should not be included, especially the crazy rule about how it handles slices between integer array indices, which was a design mistake in NumPy (slices around integer array indices aren't so bad, and can be useful, but they also add complexity to the indexing rules, so I can see wanting to omit them).
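The rule in question can be demonstrated in NumPy (a well-known quirk of its advanced indexing):

```python
import numpy as np

a = np.arange(24).reshape(2, 3, 4)
idx = np.array([0, 1])

# integer arrays adjacent: the broadcast dimension stays in place
shape_adjacent = a[idx, idx, :].shape   # (2, 4)

# a slice *between* the integer arrays: the broadcast dimension is
# moved to the front of the result, which surprises most users
shape_separated = a[idx, :, idx].shape  # (2, 3)
```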

@shoyer
Contributor

shoyer commented Mar 24, 2022

It's definitely possible (but not necessarily easy) to rewrite every call to np.take in terms of __getitem__ with integer arrays. For a library like Xarray, support for all integer indexing (especially with broadcasting) would be sufficient. So from my perspective, support for all integer indexing in __getitem__ and possibly also __setitem__ would be the most useful functionality.

I would not be opposed to adding take if there is interest. It certainly is easier to construct calls to take, and knowing ahead of time that indexing is only going along a certain axis can sometimes allow for significant simplifications to indexing code. There are two alternatives we could consider for filling this same niche (easy integer based indexing along one dimension):

  1. Support for mixed array/slice indexing, like NumPy. But like Aaron says, this is too confusing for the API standard.
  2. We could include oindex, but this proposal never got entirely off the ground (beyond implementations in Xarray/Dask/Zarr).

If we do choose to include take in the standard, the axis argument should be required. Slicing along flattened arrays is not very useful.

@asmeurer
Member

It's definitely possible (but not necessarily easy) to rewrite every call to np.take in terms of __getitem__ with integer arrays. For a library like Xarray, support for all integer indexing (especially with broadcasting) would be sufficient. So from my perspective, support for all integer indexing in __getitem__ and possibly also __setitem__ would be the most useful functionality.

The suggestion here is to support take but defer support for indexing. So users of the array API would need to rewrite usages of __getitem__ to take, not the other way around.

Slicing along flattened arrays is not very useful.

I've never really used take myself, so I don't have the best context here, but isn't the flattened behavior there to match put, which doesn't have axis?

@shoyer
Contributor

shoyer commented Mar 25, 2022

What is the concern with supporting integer array indexing in __setitem__? Just the fact that it may not be implemented in otherwise compliant array libraries?

@rgommers
Member

What is the concern with supporting integer array indexing in __setitem__? Just the fact that it may not be implemented in otherwise compliant array libraries?

That, and also that it's non-deterministic when indices are not unique, as noted in the PyTorch and TF docs on scatter/scatter_nd.
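A NumPy illustration of the duplicate-index issue; note that NumPy itself happens to resolve duplicates deterministically (writes are applied in order, so the last value wins), which is exactly the guarantee that PyTorch's and TensorFlow's scatter functions do not make:

```python
import numpy as np

x = np.zeros(3)

# duplicate indices: NumPy applies writes in order, so the last
# value wins; at the spec level this would be undefined behavior
x[np.array([0, 0])] = np.array([1.0, 2.0])
# x[0] == 2.0 in NumPy

# accumulation, by contrast, is well defined via np.add.at:
# each duplicate contributes once
np.add.at(x, np.array([0, 0]), 1.0)
# x[0] == 4.0
```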

@shoyer
Contributor

shoyer commented Mar 30, 2022

That, and also that it's non-deterministic when indices are not unique, as noted in the PyTorch and TF docs on scatter/scatter_nd.

I think we could probably safely leave this as undefined behavior?

@rgommers
Member

rgommers commented Apr 7, 2022

I think we could probably safely leave this as undefined behavior?

Yes, fair enough.

Let me add another concern though, probably the main one (copied from higher up, with a minor edit: put --> __setitem__): Having a better handle on the topic of mutability looks like a hard requirement before even considering an in-place function like __setitem__.

I think my preferred order of doing things here would be:

  1. Add take with 1-D integer array indices now (see Add take specification for returning elements of an array along a specified axis #416)
  2. Tighten up mutability specification
  3. Add __getitem__ and __setitem__ (with n-D integer inputs, assuming that behavior aligns across libraries).

@arogozhnikov

Very interested in having at least a basic version of take incorporated into the standard.

Context: experimental version of more verbose indexing, see arogozhnikov/einops#194 for details

@rgommers
Member

rgommers commented Jul 7, 2022

Thanks for the ping on this issue @arogozhnikov - and nice to see the experimental work on indexing in einops. I'd like to see gh-416 finished and merged in the coming days to indeed add take support with 1-D indices.

@rgommers
Member

Support for take has been merged, see gh-416.

@rgommers rgommers removed this from the v2022 milestone Dec 14, 2022
@lezcano
Contributor

lezcano commented Feb 15, 2023

nit. we also have put_ in PyTorch (but not put...)

@rgommers rgommers added the topic: Dynamic Shapes Data-dependent shapes. label Mar 9, 2023
@honno
Member

honno commented Apr 18, 2023

I think if we were to introduce something like xp.put(x, indices, value) to the spec we can seemingly all agree on

  1. Only specifying a single array as the indices argument like we do with xp.take(), leaving other kinds of indices out-of-scope.

    e.g. for xp.put(x, indices, xp.asarray([42, 7])) where x=xp.arange(5), indices=xp.asarray([1, 4]) would be supported, but the following equivalent arguments would be out-of-scope

    indices=(xp.asarray(1), xp.asarray(4))
    indices=(1, 4)
    indices=(1, xp.asarray(4))
    indices=[1, 4]

    Array only indexing makes adoption easier and doesn't cause problems for accelerators.

  2. Keeping in line with xp.take(), we should specify support for 1-dimensional indices only.

    • Implicitly I'd be mandating that elements in indices index into the flattened input array, rather than following any fancy broadcasting behaviours/etc.

      e.g., on the contrary, t.index_put_() uses the shape of indices to specify multiple elements of the input array.

      >>> t = torch.as_tensor([[True, True]])
      >>> t.index_put_((torch.as_tensor([0]),), torch.as_tensor(False))
      tensor([[False, False]])
  3. Only specifying support for unique indices, e.g. indices=xp.asarray([0, 0]) would be out-of-scope. Consistent duplicate indices behaviour seems too niche and finicky to specify.

    • Interestingly PyTorch has the accumulate keyword for its t.put_()/t.index_put_() methods, where accumulate=False (default) leaves such behaviour unspecified, and accumulate=True puts the sum of the respective values.

The question areas then would be

  1. Should we support broadcasting value to the shape of indices?

    • np.put() broadcasts(?) the value to the indices, e.g.

      >>> a = np.asarray([True, True])
      >>> np.put(a, np.asarray([0, 1]), np.asarray(False))
      >>> a
      array([False, False])
    • On the contrary, torch.put_() requires the indices (index) to share the same shape as the value (source).

    As broadcasting is convenient and very common throughout the spec, IMO I'd specify that value can be broadcast to indices.

    Regardless I think we'd disallow broadcast-incompatible shapes, and value.size > indices.size scenarios.

  2. Should xp.put() return an array? What are the expectations for in-place and out-of-place behaviour?

    • The spec currently always(?) returns arrays from its functions, which seems a nice convention to maintain.
    • Notably NumPy for its top-level np.put() currently only acts on the array in-place and does not return the modified array, whereas PyTorch has its put-like functions/methods return the modified array (some functions/methods also acting in-place).

    If we mandate xp.put() is to return the modified input array, we could leave in-place behaviour out-of-scope, or slap a copy keyword like we do for xp.asarray() and xp.reshape().

    Worth noting that the name put() also suggests in-place behaviour at this point.
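As a sketch of what such an out-of-place variant might look like, here is a hypothetical put_functional built on NumPy; the name and signature are illustrative only, not part of any spec or library:

```python
import numpy as np

def put_functional(x, indices, values):
    """Hypothetical out-of-place put: returns a modified copy and
    leaves the input array untouched (the JAX-friendly style).
    Uses NumPy's flat-index put semantics internally."""
    out = x.copy()
    np.put(out, indices, values)
    return out

x = np.arange(5)
y = put_functional(x, np.array([1, 4]), np.array([10, 20]))
# y is [0, 10, 2, 3, 20]; x is unchanged
```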

@lezcano
Contributor

lezcano commented Apr 18, 2023

Some thoughts on @honno's points

  1. I think that's because all those objects are array_likes in NumPy. The same happens for any function that accepts an array in NumPy.
  2. SGTM but perhaps extending it to "indices of dimension at most 1". I believe PyTorch accepts wlog any contiguous array of any size, but I've never seen it used with anything but arrays of dim 0 or 1
  3. Checking for repeated values is indeed too costly. I think it should be left unspecified.
  4. Either SGTM
  5. In PyTorch we return an array for in-place ops, following the C++ convention. This is sometimes used to be able to chain in-place ops. I think it'd be good to mandate this all throughout the API.

@arogozhnikov

  1. you discussed broadcasting values to indices. What about broadcasting indices to values? (case of specific interest to me)
>>> a = np.asarray([[True, True]])
>>> np.put(a, np.asarray([0]), np.asarray([[False, True]]))
>>> a
array([[False,  True]])

@lezcano
Contributor

lezcano commented Apr 18, 2023

Broadcasting indices should give, by definition, repeated indices, which should invoke UB (which value is written to the index 0 in your example?).

@arogozhnikov

No, see my example above.
The values consist of one row, and the index specifies that the first row of values should be assigned to the first row of the result.
I am not sure this is strictly a case of broadcasting, but it's a common thing to do.

Compare with:

matrix_n_by_n[[1, 2, 6]] = matrix_3_by_n

@asmeurer
Member

We should clarify in the spec that behavior on out-of-bounds indices is unspecified. The take spec currently doesn't say anything about this. (I'm assuming this is the behavior we want, since we already say this for basic integer indexing.)

@rgommers
Member

I had a look at implementations in libraries, some updates on what's in the issue description:

  • PyTorch, in addition to scatter, has Tensor.put_, so a method, and the trailing underscore indicating in-place behavior. A tensor is also returned.
  • Dask still doesn't have put or any of the other similar functions like putmask or put_along_axis. There doesn't seem to be a blocker though, it's only that no one has done the work yet (e.g., see Add NumPy's new put_along_axis dask/dask#3664 with a put_along_axis feature request).
  • JAX actually has it in its namespace (jax.numpy.put), but the implementation is:
def put(*args, **kwargs):
  raise NotImplementedError(
    "jax.numpy.put is not implemented because JAX arrays cannot be modified in-place. "
    "For functional approaches to updating array values, see jax.numpy.ndarray.at: "
    "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.")

for similar reasons as it avoids other in-place APIs (xref design_topics/copies_views_and_mutation).

So I think we should consider the addition of put feasible in principle but blocked right now. The JAX issue is most difficult to resolve (can be done, but a lot of work still to deal with read-only views or similar), but the lack of API uniformity makes this a hard sell in general.

@lezcano
Contributor

lezcano commented Apr 19, 2023

@arogozhnikov I think your example doesn't do what you think it does. Consider

>>> x = np.asarray([[0, 1]])
>>> np.put(x, [0], [[2, 3]])
>>> x
array([[2, 1]])

In this case, it's not that it's being broadcasted, but that np.put just considers the first ind.size elements of v. See https://github.com/numpy/numpy/blob/6073588dd73809a60819d71b9527194195f73f08/numpy/core/src/multiarray/item_selection.c#L439

In general, you are talking about "rows", but put just sees the array as a flat chunk of memory, so there is no concept of rows and columns for this function.

@arogozhnikov

Indeed, for some reason I thought it was somewhat a shortcut for x[ind] = val, but the docs say that it operates on the flat array. My bad!

@kgryte
Contributor Author

kgryte commented Jun 29, 2023

So I think we should consider the addition of put feasible in principle but blocked right now. The JAX issue is most difficult to resolve (can be done, but a lot of work still to deal with read-only views or similar), but the lack of API uniformity makes this a hard sell in general.

Given the above, I will go ahead and close this issue, as we are unlikely to make progress on put in the near term. This issue can be reopened and revisited once we have a better handle on a path forward.

@mdhaber
Contributor

mdhaber commented May 3, 2024

Can this issue be reopened for the take_along_axis portion? As noted in #416 (comment), the functionality is different from take, and although it can be implemented in terms of take (postscript), I haven't found a trivial way. It also looks like there is broad support now - in addition to the implementations mentioned in the top post, there are jax.numpy.take_along_axis and torch.take_along_dim.


In case it is relevant, here is an array-API compatible version of take_along_axis I've been using.

import numpy as np
import array_api_strict
from array_api_compat import array_namespace

def xp_swapaxes(a, axis1, axis2, *, xp=None):
    xp = array_namespace(a) if xp is None else xp
    axes = list(range(a.ndim))
    axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
    a = xp.permute_dims(a, axes)
    return a

def xp_take_along_axis(arr, indices, axis, *, xp=None):
    xp = array_namespace(arr) if xp is None else xp
    arr = xp_swapaxes(arr, axis, -1, xp=xp)
    indices = xp_swapaxes(indices, axis, -1, xp=xp)

    m = arr.shape[-1]
    n = indices.shape[-1]

    shape = list(arr.shape)
    shape.pop(-1)
    shape = shape + [n,]

    arr = xp.reshape(arr, (-1,))
    indices = xp.reshape(indices, (-1, n))

    offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis]
    indices = xp.reshape(offset + indices, (-1,))

    out = xp.take(arr, indices)
    out = xp.reshape(out, shape)
    return xp_swapaxes(out, axis, -1, xp=xp)

rng = np.random.default_rng()
x = rng.random(size=(1000, 1000))

xp = array_api_strict
x = xp.asarray(x)
j = xp.argsort(x, axis=-1)
res = xp_take_along_axis(x, j, axis=-1)
ref = xp.sort(x, axis=-1)
assert xp.all(res == ref)

@shoyer
Contributor

shoyer commented May 3, 2024

I would really like to see full integer coordinate-based indexing supported: #669
