Proposal: add APIs for getting and setting elements via a list of indices (i.e., take, put, etc.) #177
Comments
Thanks @kgryte. A few initial thoughts: …
@rgommers Thanks for the comments.
A natural question is if …
@asmeurer I think that …
If the size of …
@rgommers Correct; however, I could still imagine that data flows involving a …
@asmeurer Re: integer array indexing. As mentioned during the previous call (03/06/2021), similar to boolean array indexing, we could support a limited form of integer array indexing, where the integer array index is the sole index. That is, the spec would neither condone mixing boolean and integer indices nor require broadcasting semantics among the various indices.
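Under that limited rule, `x[idx]` with an integer array `idx` as the only index selects along the first axis, so it coincides with taking along axis 0, and the result has shape `idx.shape + x.shape[1:]`. A small check of that equivalence, using NumPy purely as a reference implementation (the arrays are illustrative):

```python
import numpy as np

x = np.arange(12).reshape(3, 4)
# A 2-D integer index array; repeated indices are allowed.
idx = np.array([[0, 2], [2, 1]])

# Integer array as the sole index: selection happens along axis 0 ...
via_indexing = x[idx]
# ... which is the same as taking along axis 0.
via_take = np.take(x, idx, axis=0)

print(via_indexing.shape)                      # (2, 2, 4)
print(np.array_equal(via_indexing, via_take))  # True
```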
Cross-linking to a discussion regarding issues concerning out-of-bounds access in …
In the ML use case, it is common to want to sample with replacement or shuffle a dataset. This is commonly done by sampling an integer array and using it to subset the dataset:

import numpy.array_api as xp  # experimental module, available in NumPy 1.22-1.26

X = xp.asarray([[1, 2, 3, 4], [2, 3, 4, 5],
                [4, 5, 6, 10], [5, 6, 8, 20]], dtype=xp.float64)
sample_indices = xp.asarray([0, 0, 1, 3])

# Does not work
# X[sample_indices, :]

For libraries that need selection with integer arrays, a workaround is to implement:

def take(X, indices, *, axis):
    # Simple implementation that only works for axis in {0, 1}
    n = indices.shape[0]
    if axis == 0:
        # Select each requested row (0-D index arrays converted to Python ints)
        selected = [X[int(indices[k]), :] for k in range(n)]
    else:  # axis == 1
        # Select each requested column
        selected = [X[:, int(indices[k])] for k in range(n)]
    return xp.stack(selected, axis=axis)

take(X, sample_indices, axis=0)

Note that sampling with replacement cannot be done with a boolean mask, because some rows may be selected twice.
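Once a `take` function is available (NumPy's `np.take` is used below as a stand-in), both sampling with replacement and shuffling reduce to drawing an integer index array and passing it to `take`, with no Python-level loop. The generator calls (`integers`, `permutation`) are NumPy-specific, and the sketch assumes the dataset rows lie along axis 0:

```python
import numpy as np

X = np.asarray([[1, 2, 3, 4], [2, 3, 4, 5],
                [4, 5, 6, 10], [5, 6, 8, 20]], dtype=np.float64)
rng = np.random.default_rng(seed=0)
n_rows = X.shape[0]

# Bootstrap sample: rows drawn with replacement, so repeats are expected.
boot_idx = rng.integers(0, n_rows, size=n_rows)
X_boot = np.take(X, boot_idx, axis=0)

# Shuffle: a permutation of the row indices, each row exactly once.
perm_idx = rng.permutation(n_rows)
X_shuffled = np.take(X, perm_idx, axis=0)

# The fixed index array from the example above, including the repeated row 0.
X_sample = np.take(X, np.asarray([0, 0, 1, 3]), axis=0)
```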
Hi @kmaehashi @asi1024 @emcastillo FYI. In a recent array API call we discussed the proposed take/put APIs, and there were questions about how CuPy currently implements these functions: there could be data/value dependencies, and people were wondering whether we just have to pay the synchronization cost to ensure the behavior is correct. Could you help address this? Thanks! (And sorry I dropped the ball here...)
+1. I think "array only" integer indexing would be quite well defined and would not be problematic for accelerators. The main challenge with NumPy's implementation of "advanced indexing" is handling mixed integer/slice/boolean cases.
Here is a summary of today's discussion: …
Given all that, the proposal is to only add …
Something that I think was missed in today's discussion is that …

>>> a = np.arange(9).reshape((3, 3)) + 10
>>> a[np.array([0, 2]), np.array([1, 2])]
array([11, 18])
>>> np.ravel_multi_index((np.array([0, 2]), np.array([1, 2])), (3, 3))
array([1, 8])
>>> np.take(a, np.ravel_multi_index((np.array([0, 2]), np.array([1, 2])), (3, 3)))
array([11, 18])

I'm not sure if there's an easy way within the array API to go from one to the other. And I hope the "integer array as the sole index" idea above was really meant to be "integer arrays as the sole indices". Having just a single integer array index means you can only index the first dimension of the array. This should also include integer scalars, as those are equivalent to 0-D arrays, unless we want to omit the "all integer array indices are broadcast together" rule. I agree that NumPy's rules for mixing arrays with slices should not be included, especially the crazy rule about how it handles slices between integer array indices, which is a design mistake in NumPy (slices around integer array indices aren't so bad and can be useful, but they also add complexity to the indexing rules, so I can see wanting to omit them).
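One way to go from the multi-array form to `take` using only broadcasting, arithmetic, `reshape`, and `take` is to compute the row-major (C-order) flat index by hand, which is what `ravel_multi_index` returns for in-bounds, non-negative indices. The helper `index_with_int_arrays` is hypothetical; it assumes one integer index array per dimension, all broadcast together, and performs no bounds checking or negative-index wrapping. It runs on NumPy here, but each call it makes has an array API counterpart:

```python
import numpy as np

def index_with_int_arrays(x, *indices):
    """Hypothetical helper: emulate x[i0, i1, ...] for integer arrays only."""
    if x.ndim == 0 or len(indices) != x.ndim:
        raise IndexError("expected one integer index array per dimension")
    # All index arrays are broadcast to a common shape, as in NumPy.
    indices = np.broadcast_arrays(*indices)
    # Row-major flat index: ((i0 * n1 + i1) * n2 + i2) ...
    flat = indices[0]
    for size, idx in zip(x.shape[1:], indices[1:]):
        flat = flat * size + idx
    # Gather from the flattened array; the result has the broadcast shape.
    return np.take(np.reshape(x, (-1,)), flat)

a = np.arange(9).reshape((3, 3)) + 10
print(index_with_int_arrays(a, np.array([0, 2]), np.array([1, 2])))  # [11 18]
```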
It's definitely possible (but not necessarily easy) to rewrite every call to …
I would not be opposed to adding …
If we do choose to include …
The suggestion here is to support …
I've never really used …
What is the concern with supporting integer array indexing in …
That, and also that it's non-deterministic when indices are not unique, as noted in the PyTorch and TF docs on …
I think we could probably safely leave this as undefined behavior?
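To make the non-determinism concern concrete: with repeated indices, a scatter-style update has no single obvious result, and NumPy alone already gives two different answers depending on the API used. NumPy serves only as an illustration here:

```python
import numpy as np

idx = np.array([0, 0, 2])

# Buffered: x[idx] += 1 reads, adds, then writes back, so index 0 gets +1 once.
buffered = np.zeros(3)
buffered[idx] += 1

# Unbuffered: np.add.at applies the update for every occurrence of an index.
unbuffered = np.zeros(3)
np.add.at(unbuffered, idx, 1)

print(buffered)    # [1. 0. 1.]
print(unbuffered)  # [2. 0. 1.]
```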
Yes, fair enough. Let me add another concern, though, probably the main one (copied from higher up, with a minor edit): I think my preferred order of doing things here would be: …
Very interested in having at least a basic version of …
Context: experimental version of more verbose indexing; see arogozhnikov/einops#194 for details.
Thanks for the ping on this issue @arogozhnikov, and nice to see the experimental work on indexing in …
Support for …
Nit: we also have …
I think if we were to introduce something like …
The question areas then would be …
Some thoughts on @honno's points: …
Broadcasting indices should, by definition, give repeated indices, which should invoke undefined behavior: which value is written to index 0 in your example?
No, see my example above. Compare with: …
We should clarify in the spec that behavior on out-of-bounds indices is unspecified. The …
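For comparison, NumPy's `np.take` makes the out-of-bounds policy explicit through its `mode` argument: `'raise'` (the default) raises `IndexError`, `'wrap'` wraps indices around, and `'clip'` clamps them to the valid range (which also disables negative indexing). The snippet only illustrates NumPy's options; it is not a suggestion for the spec's signature:

```python
import numpy as np

a = np.array([10, 20, 30])
idx = np.array([1, 4, -1])

wrapped = np.take(a, idx, mode="wrap")  # 4 -> 4 % 3 == 1, -1 -> 2
clipped = np.take(a, idx, mode="clip")  # 4 -> 2, -1 -> 0

try:
    np.take(a, idx)  # default mode="raise": index 4 is out of bounds
except IndexError as exc:
    print("raise:", exc)

print(wrapped)  # [20 20 30]
print(clipped)  # [20 30 10]
```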
I had a look at implementations in libraries; here are some updates on what's in the issue description: …

def put(*args, **kwargs):
    raise NotImplementedError(
        "jax.numpy.put is not implemented because JAX arrays cannot be modified in-place. "
        "For functional approaches to updating array values, see jax.numpy.ndarray.at: "
        "https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html.")

… for similar reasons as it avoids other in-place APIs (xref design_topics/copies_views_and_mutation). So I think we should consider the addition of …
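A library that cannot mutate arrays in place can still offer `put`-like behavior by returning a new array. The following hypothetical `put_functional` is a sketch, not any library's implementation: it handles a 1-D `x`, assumes unique, in-bounds, non-negative indices, and, apart from the empty-index shortcut, relies only on operations that have array API counterparts (`arange`, `any`, `argmax`, `take`, `where`), at the cost of an `(n, k)` boolean comparison matrix. It is run with NumPy here:

```python
import numpy as np

def put_functional(x, indices, values):
    """Hypothetical out-of-place put for 1-D x with unique, in-bounds indices."""
    if indices.shape[0] == 0:
        return np.copy(x)  # nothing to write; return an unmodified copy
    n = x.shape[0]
    # hits[i, j] is True when position i of x is targeted by indices[j].
    hits = np.arange(n)[:, None] == indices[None, :]
    # Which positions of x receive a new value at all.
    mask = np.any(hits, axis=1)
    # For targeted positions, the column j of the matching index (0 otherwise).
    src = np.argmax(np.asarray(hits, dtype=np.int8), axis=1)
    # values[src] where targeted, x elsewhere; x itself is never written to.
    return np.where(mask, np.take(values, src), x)

x = np.array([0, 1, 2, 3, 4])
y = put_functional(x, np.array([3, 1]), np.array([30, 10]))
print(y.tolist())  # [0, 10, 2, 30, 4]
print(x.tolist())  # [0, 1, 2, 3, 4] (unchanged)
```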
@arogozhnikov I think your example doesn't do what you think it does. Consider

>>> x = np.asarray([[0, 1]])
>>> np.put(x, [0], [[2, 3]])
>>> x
array([[2, 1]])

In this case, it's not that it's being broadcast, but that …
In general, you are talking about "rows", but …
Indeed, for some reason I thought it was a shortcut for …
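To illustrate the distinction drawn above: `np.put` always addresses the flattened array, so an index names an element position in C order rather than a row, while `np.put_along_axis` applies indices along a chosen axis:

```python
import numpy as np

a = np.zeros((2, 3), dtype=int)
# Flat index 4 is row 1, column 1 in C order; it does not mean "row 4".
np.put(a, [4], [7])
print(a)
# [[0 0 0]
#  [0 7 0]]

b = np.zeros((2, 3), dtype=int)
# put_along_axis: one column index per row, applied along axis 1.
np.put_along_axis(b, np.array([[2], [0]]), 5, axis=1)
print(b)
# [[0 0 5]
#  [5 0 0]]
```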
Given the above, I will go ahead and close this issue, as we are unlikely to make progress on …
Can this issue be reopened for the …
In case it is relevant, here is an array-API compatible version of …

import numpy as np
import array_api_strict
from array_api_compat import array_namespace

def xp_swapaxes(a, axis1, axis2, *, xp=None):
    # Swap two axes using only permute_dims
    xp = array_namespace(a) if xp is None else xp
    axes = list(range(a.ndim))
    axes[axis1], axes[axis2] = axes[axis2], axes[axis1]
    a = xp.permute_dims(a, axes)
    return a

def xp_take_along_axis(arr, indices, axis, *, xp=None):
    # Assumes non-negative indices and that indices matches arr's shape except along `axis`
    xp = array_namespace(arr) if xp is None else xp
    # Move the target axis last so each "row" is contiguous after flattening
    arr = xp_swapaxes(arr, axis, -1, xp=xp)
    indices = xp_swapaxes(indices, axis, -1, xp=xp)
    m = arr.shape[-1]
    n = indices.shape[-1]
    shape = list(arr.shape)
    shape.pop(-1)
    shape = shape + [n,]
    # Flatten, then shift each row's indices by row_number * m
    arr = xp.reshape(arr, (-1,))
    indices = xp.reshape(indices, (-1, n))
    offset = (xp.arange(indices.shape[0]) * m)[:, xp.newaxis]
    indices = xp.reshape(offset + indices, (-1,))
    # Gather with a 1-D take, restore the shape, and move the axis back
    out = xp.take(arr, indices)
    out = xp.reshape(out, tuple(shape))
    return xp_swapaxes(out, axis, -1, xp=xp)

rng = np.random.default_rng()
x = rng.random(size=(1000, 1000))
xp = array_api_strict
x = xp.asarray(x)
j = xp.argsort(x, axis=-1)
res = xp_take_along_axis(x, j, axis=-1)
ref = xp.sort(x, axis=-1)
assert xp.all(res == ref)
I would really like to see full integer coordinate-based indexing supported: #669
Proposal
Add APIs for getting and setting elements via a list of indices.
Motivation
Currently, the array API specification does not provide a direct means of extracting and setting a list of elements along an axis. Such operations are relatively common in NumPy usage, either via "fancy indexing" or via explicit take and put APIs.
Two main arguments come to mind for supporting at least basic take and put APIs:
1. Indexing does not currently support providing a list of indices to index into an array. The principal reason for not supporting fancy indexing stems from dynamic shapes and compatibility with accelerator libraries. However, fancy indexing is relatively common in NumPy and similar libraries, where dynamically extracting rows/cols/values is possible and can be readily implemented.
2. The output of a subset of APIs already included in the standard cannot be readily consumed without manual workarounds if a specification-conforming library implemented only the APIs in the standard. For example, argsort returns an array of indices; in NumPy, the output of this function can be consumed by put_along_axis and take_along_axis. Similarly, unique can return an array of indices if return_index is True.
Background
The following table summarizes library implementations of such APIs:
take | take | take | take | take/gather | gather/numpy.take
put | put | scatter | scatter_nd/tensor_scatter_nd_update
take_along_axis | take_along_axis | gather_nd/numpy.take_along_axis
put_along_axis
While most libraries implement some form of take, fewer implement other complementary APIs.
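The Motivation notes that outputs such as those of `argsort` and `unique(..., return_index=True)` are naturally consumed by take-style APIs. In NumPy, which provides those APIs, the pairing looks like this (the arrays are illustrative):

```python
import numpy as np

x = np.array([[3, 1, 2],
              [9, 7, 8]])

# argsort returns indices; take_along_axis turns them back into values.
order = np.argsort(x, axis=1)
sorted_rows = np.take_along_axis(x, order, axis=1)
print(sorted_rows)
# [[1 2 3]
#  [7 8 9]]

# unique(..., return_index=True) returns first-occurrence indices into the input.
v = np.array([4, 1, 4, 2, 1])
values, first_idx = np.unique(v, return_index=True)
print(values, first_idx)      # [1 2 4] [1 3 0]
print(np.take(v, first_idx))  # [1 2 4]
```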