
This issue was moved to a discussion.



A "lazy" / "meta" implementation of the array api? #728

Closed
adonath opened this issue Jan 11, 2024 · 6 comments

Comments

@adonath

adonath commented Jan 11, 2024

In addition to the already available implementations of the array API, I think it could be interesting to have a lazy / meta implementation of the standard. What I mean is a small, standalone library with minimal dependencies, compatible with the array API, that infers the shape and dtype of resulting arrays without ever initializing any data or executing any FLOPs.

PyTorch already has something like this with the "meta" device. For example:

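```python
import torch

# Tensors on the "meta" device carry shape and dtype but allocate no data.
data = torch.ones((1000, 1000, 100), device="meta")
kernel = torch.ones((100, 10), device="meta")

result = torch.matmul(data, kernel)
print(result.shape)  # torch.Size([1000, 1000, 10])
print(result.dtype)  # torch.float32 (the default dtype)
```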

However, this lacks device handling, for example, as the device is constrained to "meta". I presume Dask has something very similar. JAX must also have something very similar for jitted computations; however, I think it is only exposed to users through a different API, jax.eval_shape(), and not via a "meta" array object.

Analogous to the torch example, one would use a hypothetical library lazy_array_api:

```python
import lazy_array_api as xp  # hypothetical library

data = xp.ones((1000, 1000, 100), device="cpu")
kernel = xp.ones((100, 10), device="cpu")

result = xp.matmul(data, kernel)
print(result.shape)
print(result.dtype)
```

The use case I have in mind is mostly debugging, validation and testing of computationally intensive algorithms ("dry runs"). For now I just wanted to share the idea and bring it up for discussion.
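To make the idea concrete, here is a minimal pure-Python sketch of a data-free array that only tracks shape, dtype and device. MetaArray, ones and matmul are hypothetical names, dtypes are plain strings for simplicity, and only the common `(..., M, K) @ (K, N)` case of matmul is covered; the spec's 1-D handling, batch broadcasting and type promotion are deliberately left out.

```python
from dataclasses import dataclass
from typing import Tuple


@dataclass(frozen=True)
class MetaArray:
    """Hypothetical data-free array: only metadata is stored, never values."""

    shape: Tuple[int, ...]
    dtype: str = "float64"  # dtypes as plain strings, a simplification
    device: str = "cpu"

    @property
    def ndim(self) -> int:
        return len(self.shape)


def ones(shape, dtype: str = "float64", device: str = "cpu") -> MetaArray:
    # Nothing is allocated; the requested metadata is simply recorded.
    return MetaArray(tuple(shape), dtype, device)


def matmul(x1: MetaArray, x2: MetaArray) -> MetaArray:
    # Only (..., M, K) @ (K, N) is sketched here.
    if x1.ndim < 2 or x2.ndim != 2:
        raise NotImplementedError("sketch only handles (..., M, K) @ (K, N)")
    if x1.shape[-1] != x2.shape[0]:
        raise ValueError(
            f"matmul: contracted dimensions differ ({x1.shape[-1]} != {x2.shape[0]})"
        )
    if x1.device != x2.device:
        raise ValueError("matmul: inputs must be on the same device")
    if x1.dtype != x2.dtype:
        raise NotImplementedError("type promotion is omitted in this sketch")
    # The result keeps the leading dims of x1 and replaces K with N.
    return MetaArray(x1.shape[:-1] + x2.shape[1:], x1.dtype, x1.device)


data = ones((1000, 1000, 100), device="cpu")
kernel = ones((100, 10), device="cpu")
result = matmul(data, kernel)
print(result.shape)  # (1000, 1000, 10)
print(result.dtype)  # float64
```

Because nothing is allocated, the large input costs nothing, and shape or device mismatches surface as exceptions before any real computation would run.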

@rgommers
Member

This would be super useful indeed. It's not a small amount of work I suspect. For indexing there is https://github.com/Quansight-Labs/ndindex/, which basically implements this "meta" idea. That's probably one of the most hairy parts to do, and a good start. But correctly doing all shape calculations for all functions in the API is also a large job. Perhaps others know of reusable functionality elsewhere for this? For PyTorch I believe it's too much baked into the library to be able to reuse it standalone.
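As an illustration of the kind of shape computation involved, the sketch below derives the result shape of basic indexing (integers, slices and a single Ellipsis) without touching any data. index_shape is a hypothetical helper written for illustration, not part of ndindex's API; ndindex covers far more, including integer-array and boolean indices and None.

```python
def index_shape(shape, index):
    """Result shape of x[index] for basic indexing, computed without data.

    Hypothetical helper: handles ints, slices and one Ellipsis only
    (no None/newaxis, integer-array or boolean indices).
    """
    if not isinstance(index, tuple):
        index = (index,)
    if index.count(Ellipsis) > 1:
        raise IndexError("an index can contain at most one Ellipsis")
    if Ellipsis in index:
        # Replace the Ellipsis with as many full slices as needed.
        i = index.index(Ellipsis)
        n_fill = len(shape) - (len(index) - 1)
        index = index[:i] + (slice(None),) * n_fill + index[i + 1:]
    if len(index) > len(shape):
        raise IndexError("too many indices for array")
    # Axes that are not mentioned are kept unchanged.
    index = index + (slice(None),) * (len(shape) - len(index))

    out = []
    for n, idx in zip(shape, index):
        if isinstance(idx, slice):
            # slice.indices() resolves None and negative bounds against n.
            out.append(len(range(*idx.indices(n))))
        elif isinstance(idx, int) and not isinstance(idx, bool):
            if not -n <= idx < n:
                raise IndexError(f"index {idx} is out of bounds for axis of size {n}")
            # An integer index removes the axis, so nothing is appended.
        else:
            raise TypeError(f"unsupported index in this sketch: {idx!r}")
    return tuple(out)


print(index_shape((10, 20, 30), (slice(2, 8), 0)))  # (6, 30)
print(index_shape((10, 20, 30), (..., slice(None, None, 2))))  # (10, 20, 15)
```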

@adonath
Author

adonath commented Jan 12, 2024

Thanks a lot @rgommers for the response!

> It's not a small amount of work I suspect.

I only partly agree, because there is nothing particularly difficult about it. The expected behavior is well defined and the API already exists, so there is no tricky code to be figured out. It is just a matter of implementing the already defined behavior "diligently".

> For indexing there is https://github.com/Quansight-Labs/ndindex/, which basically implements this "meta" idea.

This is indeed a great start, I was not aware of this project.

> But correctly doing all shape calculations for all functions in the API is also a large job.

I think the effort could actually be limited, because looking at https://github.com/numpy/numpy/tree/main/numpy/array_api, the files already group the API into operations with the same shape-computation behavior, e.g. elementwise, indexing, searching, statistical, etc. For each group the behavior only needs to be defined once; the rest is filling in boilerplate code. In addition there are broadcasting and indexing, which apply throughout. I'm less sure about dtype promotion, but that must already have been coded somewhere as well.
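A sketch of the "define the rule once per group" idea, assuming shapes are plain tuples of ints: one broadcasting rule serves all elementwise functions, and one reduction rule serves the statistical functions that take axis and keepdims. All names here, and the SHAPE_RULES registry in particular, are hypothetical.

```python
from typing import Sequence, Tuple, Union

Shape = Tuple[int, ...]


def broadcast_shapes(*shapes: Shape) -> Shape:
    """Broadcasting: right-align the shapes; each size must match or be 1."""
    ndim = max((len(s) for s in shapes), default=0)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for sizes in zip(*padded):
        non_one = {n for n in sizes if n != 1}
        if len(non_one) > 1:
            raise ValueError(f"shapes {shapes} cannot be broadcast together")
        out.append(non_one.pop() if non_one else 1)
    return tuple(out)


def reduction_shape(
    shape: Shape, axis: Union[int, Sequence[int], None] = None, keepdims: bool = False
) -> Shape:
    """Shared rule for reductions such as sum, mean and max."""
    ndim = len(shape)
    if axis is None:
        axes = set(range(ndim))
    else:
        axes = set()
        for a in ((axis,) if isinstance(axis, int) else axis):
            if not -ndim <= a < ndim:
                raise ValueError(f"axis {a} is out of range for {ndim} dimensions")
            axes.add(a % ndim)
    if keepdims:
        return tuple(1 if i in axes else n for i, n in enumerate(shape))
    return tuple(n for i, n in enumerate(shape) if i not in axes)


# Hypothetical registry: each function just points at its group's rule.
SHAPE_RULES = {
    "add": broadcast_shapes,
    "multiply": broadcast_shapes,
    "sin": broadcast_shapes,
    "sum": reduction_shape,
    "mean": reduction_shape,
    "max": reduction_shape,
}

print(SHAPE_RULES["add"]((8, 1, 6, 1), (7, 1, 5)))  # (8, 7, 6, 5)
print(SHAPE_RULES["mean"]((3, 4, 5), axis=1, keepdims=True))  # (3, 1, 5)
```

Adding another elementwise or reduction function is then a one-line registry entry, which is the boilerplate part.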

> Perhaps others know of reusable functionality elsewhere for this? For PyTorch I believe it's too much baked into the library to be able to reuse it standalone.

I agree, PyTorch is already too large a dependency. From a quick search I only found https://github.com/NeuralEnsemble/lazyarray, which seems to be unmaintained. It also takes a different approach: it builds a graph and then delays the evaluation.

@adonath
Author

adonath commented Jan 22, 2024

I'd like to get a better idea of the actual implementation effort and just share some more thoughts on this idea.

  • The main use case for this would be shape and dtype inference. I guess it makes sense to really consider it an implementation of the standard, with a corresponding array object, operators and functions. Shapes and dtypes are computed directly; graph-like handling / delayed execution is out of scope.
  • Data-dependent functions, such as unique, cannot really be supported, unless one allows "un-initialized" shapes (see below).
  • Sometimes it might be useful to work with "un-initialized" shapes. For example, the length of a "batch" axis might not be known, and often one would not care about it either. As "batches" are treated independently, (None, 2, 3, 4) could be handled as a valid shape. After specific operations, such as the mean along the batch axis, the shape becomes known.
    However it is not a valid shape in the Array-API.
  • Users / developers would typically switch from the lazy implementation to one backed by actual data. This matters because they would want to work with the dtypes and devices supported by that actual implementation, so having some backend-specific device validation, even in the lazy implementation, would be good.
  • In general I guess one would use https://github.com/numpy/numpy/tree/main/numpy/array_api as a starting point, delete all actual functionality and use it as an API template. The type promotion table and return type handling could be kept.
  • ndindex is great! It also already provides https://quansight-labs.github.io/ndindex/api.html#ndindex.broadcast_shapes and handles indexing, so it will definitely be a dependency.
  • ndindex has NumPy and Cython as optional dependencies. I think the situation would be exactly the same for such a lazy implementation; I don't think anything else would be needed.
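The type promotion rules mentioned above can also be expressed compactly by rule rather than as a full table. The sketch below is my own reading of the standard's promotion lattice for pairs of dtypes, not code taken from numpy.array_api; promote is a hypothetical name and dtypes are plain strings. It returns None wherever the standard leaves the result unspecified (bool mixed with numeric types, integer mixed with floating-point, uint64 mixed with signed integers).

```python
from typing import Optional

_INT = {"int8": 8, "int16": 16, "int32": 32, "int64": 64}
_UINT = {"uint8": 8, "uint16": 16, "uint32": 32, "uint64": 64}
# Precision of the real components of floating-point dtypes.
_REAL = {"float32": 32, "float64": 64}
_COMPLEX = {"complex64": 32, "complex128": 64}


def promote(a: str, b: str) -> Optional[str]:
    """Hypothetical pairwise promotion; None where the standard is silent."""
    if a == b:
        return a
    if a in _INT and b in _INT:
        return f"int{max(_INT[a], _INT[b])}"
    if a in _UINT and b in _UINT:
        return f"uint{max(_UINT[a], _UINT[b])}"
    if {a, b} <= set(_INT) | set(_UINT):
        # One signed, one unsigned: the signed result must hold all unsigned values.
        s, u = (a, b) if a in _INT else (b, a)
        bits = max(_INT[s], 2 * _UINT[u])
        return f"int{bits}" if bits <= 64 else None
    fp = {**_REAL, **_COMPLEX}
    if a in fp and b in fp:
        bits = max(fp[a], fp[b])
        if a in _COMPLEX or b in _COMPLEX:
            return f"complex{2 * bits}"
        return f"float{bits}"
    # Mixed kinds (bool/numeric, integer/floating): not specified by the standard.
    return None


print(promote("uint8", "int8"))  # int16
print(promote("float64", "complex64"))  # complex128
```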

@lucascolley
Contributor

> Sometimes it might useful to work with "un-initialized shapes". For example the length of a "batch axes" might not be known and often one would not care about it either. As "batches" are treated independent (None, 2, 3, 4) could be handled as a valid shape. After specific operations, such as the mean along the batch axis, the shape becomes known.
> However it is not a valid shape in the Array-API.

Is this not a valid shape? The spec for shape says:

> out (Tuple[Optional[int], …]) – array dimensions. An array dimension must be None if and only if a dimension is unknown.

@adonath
Author

adonath commented Jan 22, 2024

Indeed this is already part of the spec. I missed that before!
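Since the spec allows None for unknown dimensions, the shape rules only need a small extension to propagate them. A hedged sketch with hypothetical helper names: an unknown size broadcast against a known size n (n != 1) yields n, since valid data could only have size 1 or n there; reducing over an unknown batch axis makes the result fully known; and a data-dependent function such as unique_values yields a 1-D shape of unknown length.

```python
from typing import Optional, Tuple

Shape = Tuple[Optional[int], ...]  # None marks an unknown dimension


def broadcast_dim(a: Optional[int], b: Optional[int]) -> Optional[int]:
    if a == 1:
        return b
    if b == 1:
        return a
    if a is None or b is None:
        # One side known (and != 1): the result is that size.
        # Both unknown: the result stays unknown.
        return b if a is None else a
    if a != b:
        raise ValueError(f"cannot broadcast dimensions {a} and {b}")
    return a


def broadcast_shapes(*shapes: Shape) -> Shape:
    ndim = max((len(s) for s in shapes), default=0)
    padded = [(1,) * (ndim - len(s)) + tuple(s) for s in shapes]
    out = []
    for sizes in zip(*padded):
        dim: Optional[int] = 1
        for n in sizes:
            dim = broadcast_dim(dim, n)
        out.append(dim)
    return tuple(out)


def reduce_shape(shape: Shape, axis: int) -> Shape:
    # Reducing over an axis drops it, including an unknown batch axis.
    if not -len(shape) <= axis < len(shape):
        raise ValueError(f"axis {axis} is out of range")
    axis %= len(shape)
    return shape[:axis] + shape[axis + 1:]


def unique_values_shape(shape: Shape) -> Shape:
    # Data-dependent: the number of unique elements cannot be known without data.
    return (None,)


batch = (None, 2, 3, 4)
print(broadcast_shapes(batch, (3, 4)))  # (None, 2, 3, 4)
print(reduce_shape(batch, axis=0))  # (2, 3, 4)
print(unique_values_shape((5, 6)))  # (None,)
```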

@adonath
Author

adonath commented Feb 5, 2024

Actually I'd be interested in starting a repo and playing around with this a bit. Any preference for a name @rgommers or @lucascolley? What about ndshape or xpshape for example?

data-apis locked and limited conversation to collaborators Apr 4, 2024
kgryte converted this issue into discussion #777 Apr 4, 2024
