correctly infer backend for functions like concatenate #3
Yes, this is a good point that I've run into a few times. Regarding your approaches:

A dispatcher might look like:

```python
def einsum_dispatcher(*args, **kwargs):
    # einsum may be called as einsum('ij,jk->ik', A, B):
    # skip the equation string and infer from the first array
    if isinstance(args[0], str):
        return args[1]
    return args[0]

register_dispatch('einsum', einsum_dispatcher)
```

and then, in the `if like is None:` branch:

```python
dispatch_arg = _DISPATCHERS.get(fn, default_dispatcher)(*args, **kwargs)
backend = infer_backend(dispatch_arg)
```

with

```python
def default_dispatcher(*args, **kwargs):
    return args[0]
```

Hopefully just having the one extra dict lookup wouldn't have too much overhead.

On the manually specifying side, I've been wondering if something like

```python
with set_backend(like):
    ...
```

to force a backend for a block would be sufficiently convenient to add, but the above solution is, I think, cleaner for this particular problem. Depending on your thoughts, I'm very happy to implement the custom dispatcher approach.
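A minimal, self-contained sketch of how these pieces could fit together. The `register_dispatch` / `_DISPATCHERS` / `default_dispatcher` names are the ones proposed above; the simplified `infer_backend` and the `resolve_backend` wrapper are illustrative stand-ins, not autoray's actual internals:

```python
# Sketch of the proposed dispatcher registry, wired into a simplified
# do()-style backend resolution step (not autoray's actual implementation).

_DISPATCHERS = {}

def register_dispatch(fn_name, dispatcher):
    """Register a rule for which argument determines the backend of fn_name."""
    _DISPATCHERS[fn_name] = dispatcher

def default_dispatcher(*args, **kwargs):
    # by default, the first positional argument decides the backend
    return args[0]

def einsum_dispatcher(*args, **kwargs):
    # einsum('ij,jk->ik', A, B): skip the equation string
    if isinstance(args[0], str):
        return args[1]
    return args[0]

register_dispatch('einsum', einsum_dispatcher)

def infer_backend(x):
    # simplified stand-in for autoray.infer_backend
    return type(x).__module__.split('.')[0]

def resolve_backend(fn, *args, like=None, **kwargs):
    """The `if like is None` branch, with the one extra dict lookup."""
    if like is None:
        dispatch_arg = _DISPATCHERS.get(fn, default_dispatcher)(*args, **kwargs)
        return infer_backend(dispatch_arg)
    return like if isinstance(like, str) else infer_backend(like)

if __name__ == "__main__":
    import numpy as np
    A = np.ones((2, 2))
    print(resolve_backend('einsum', 'ij,jk->ik', A, A))   # -> 'numpy'
    print(resolve_backend('tensordot', A, A, axes=1))     # -> 'numpy' (default dispatcher)
```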
A good point, my second suggestion will indeed not work. I'll try and see if I can get the custom dispatcher approach to work; it doesn't seem like it would be too difficult. I really like this project, so I would love to contribute. I use it extensively in a project of mine, and doing a PR is a nice excuse to dive a bit deeper into the workings of your code. I also have a bunch of extra translations I needed for my project, but I will come back to those once the project is nearer to being finished.
Resolved by #4.
When using `concatenate`, the result is always converted to numpy arrays. This can be mitigated by explicitly specifying the backend instead, but that is a bit unwieldy. The problem is that the argument `(A, B)` is a tuple, which belongs to the backend `builtins`, which in turn always gets inferred as `numpy` by `infer_backend`.

This problem applies to any function whose first argument is a list/tuple of arrays. I know at least that this applies to `concatenate`, `einsum` and `stack`. For `einsum` I just opted to call `opt_einsum` directly, which does correctly infer the backend in this case, but that is beside the point.
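For concreteness, a minimal sketch of the behaviour being described, assuming `torch` tensors as the non-numpy inputs (the snippet is an illustration, not the original report's exact example):

```python
import torch
from autoray import do

A = torch.randn(2, 3)
B = torch.randn(2, 3)

# The tuple (A, B) belongs to the 'builtins' backend, which gets treated
# as numpy, so (as reported) the result comes back as a numpy array
# rather than a torch tensor:
C = do('concatenate', (A, B))

# Mitigation: specify the backend explicitly so it is not inferred from
# the tuple itself (assuming `like=` accepts an example array or a
# backend name):
C = do('concatenate', (A, B), like=A)   # or like='torch'
```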
I can see several possible approaches:

1. In `ar.register_function`, the user should also be able to indicate that the function is of this type (i.e. takes a list/tuple of arrays as its first argument).
2. In `_infer_class_backend_cached`, make a specific check for `builtins`: we check if the item is iterable, and if so we check the backend of the first element. If it is again `builtins`, then leave it as is, but if it is something else then return that backend instead (sketched below).

I'm partial to the second option, as I don't expect it to have too many side-effects. If you want I can do a PR.
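As a rough illustration of what the second option could look like, here is a simplified stand-in for the backend-inference step (the function name and the bare `__module__` lookup are assumptions; the real `_infer_class_backend_cached` handles more cases):

```python
def infer_backend_with_builtins_check(x):
    """Backend inference with the proposed 'peek inside builtins
    containers' rule (option 2 above), in simplified form."""
    backend = type(x).__module__.split('.')[0]
    if backend == 'builtins':
        # x is e.g. a tuple or list: try the backend of its first element
        try:
            first = next(iter(x))
        except (TypeError, StopIteration):
            return backend              # not iterable, or empty: leave as is
        inner = type(first).__module__.split('.')[0]
        if inner != 'builtins':
            return inner                # e.g. 'numpy', 'torch', ...
    return backend
```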