This issue was moved to a discussion.
You can continue the conversation there. Go to discussion →
A "lazy" / "meta" implementation of the array api? #728
Comments
This would be super useful indeed. It's not a small amount of work, I suspect. For indexing there is https://github.com/Quansight-Labs/ndindex/, which basically implements this "meta" idea. That's probably one of the hairiest parts to do, and a good start. But correctly doing all shape calculations for all functions in the API is also a large job. Perhaps others know of reusable functionality elsewhere for this? For PyTorch I believe it's too deeply baked into the library to be reusable standalone.
Thanks a lot @rgommers for the response!
I only partly agree, because there is nothing particularly difficult about it. The expected behavior is well defined and the API is already defined, so there is no tricky code to be figured out. It is just a matter of implementing the already defined behavior diligently.
This is indeed a great start, I was not aware of this project.
I think the effort could actually be limited, because looking at https://github.com/numpy/numpy/tree/main/numpy/array_api, the files already pre-group the API into operations with the same behavior in terms of shape computation, i.e. element-wise, indexing, searching, statistical, etc. For each group the behavior only needs to be defined once; the rest is filling in boilerplate code. In addition there is broadcasting and indexing, which always apply. I'm less sure about the dtype promotion, but this must have been coded somewhere already as well.
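For the element-wise group, the shape rule reduces to the standard's broadcasting algorithm. A minimal pure-Python sketch (the function name is illustrative, not from any existing library):

```python
def broadcast_shapes(shape1, shape2):
    """Compute the result shape of broadcasting two shapes,
    following the array API standard's broadcasting rules."""
    n1, n2 = len(shape1), len(shape2)
    n = max(n1, n2)
    result = []
    for i in range(n):
        # Align dimensions from the right; missing dimensions act as size 1.
        d1 = shape1[n1 - n + i] if n1 - n + i >= 0 else 1
        d2 = shape2[n2 - n + i] if n2 - n + i >= 0 else 1
        if d1 == 1:
            result.append(d2)
        elif d2 == 1 or d1 == d2:
            result.append(d1)
        else:
            raise ValueError(f"shapes {shape1} and {shape2} are not broadcastable")
    return tuple(result)
```

For example, `broadcast_shapes((3, 1), (1, 4))` gives `(3, 4)`, and incompatible shapes raise an error, all without touching any data.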
I agree PyTorch is already too large a dependency. From a quick search I only found https://github.com/NeuralEnsemble/lazyarray, which seems to be unmaintained. It also takes a different approach: building a graph and then delaying the evaluation.
I'd like to get a better idea of the actual implementation effort and just share some more thoughts on this idea.
Is this not a valid shape? The spec for
Indeed this is already part of the spec. I missed that before!
Actually I'd be interested in starting a repo and playing around with this a bit. Any preference for a name @rgommers or @lucascolley? What about
In addition to the already available implementations of the array API, I think it could be interesting to have a lazy / meta implementation of the standard. What I mean is a small, minimal-dependency, standalone library, compatible with the array API, that provides inference of the shape and dtype of resulting arrays without ever initializing the data or executing any FLOPs.
PyTorch already has something like this with the "meta" device. However, this misses device handling, for example, as the device is constrained to "meta". I presume that dask must have something very similar. JAX must also have something very similar for jitted computations, although I think it is only exposed to users with a different API via jax.eval_shape(), and not via a "meta" array object.
and not via an "meta" array object.Similarly to the torch example one would use a hypothetical library
lazy_array_api
:The use case I have in mind is mostly debugging, validation and testing of computational intense algorithms ("dry runs"). For now I just wanted to share the idea and bring it up for discussion.