-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add basic JAX support #84
Changes from 3 commits
2ee6902
12b5294
6bf5dad
583f6bb
9c8bed6
6d59ae8
ce07cd9
ddb313e
6004b97
701a5ef
049d557
db667ea
aafbbaa
919ec41
fa758f7
6c338ca
244462f
bff9bf2
264e6c3
e7aff0f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -55,9 +55,9 @@ def is_jax_array(x): | |
if 'jax' not in sys.modules: | ||
return False | ||
|
||
import jax.numpy | ||
import jax | ||
|
||
return isinstance(x, jax.numpy.ndarray) | ||
return isinstance(x, jax.Array) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should this be more guarded? e.g. what if someone has a module named There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't know. That sort of thing tends to just break everything anyway. I've never really felt that libraries should protect against that sort of thing. Anyway, the whole point of this function is to be guarded. It won't import There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. More precisely, it won't import A similar issue exists for every other package name referenced in this module. My feeling is: it costs virtually nothing to wrap this all in an appropriate There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm still not so sure this is a good idea. Usually if you have that sort of thing it will be an error for a lot of things, not just array_api_compat. My worry here is that guarding isn't as straightforward as it might seem. Wrapping everything in try/except could mean we end up silencing legitimate errors. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. OK, sounds good. |
||
|
||
def is_array_api_obj(x): | ||
""" | ||
|
@@ -153,7 +153,7 @@ def _check_device(xp, device): | |
if device not in ["cpu", None]: | ||
raise ValueError(f"Unsupported device for NumPy: {device!r}") | ||
|
||
# device() is not on numpy.ndarray and and to_device() is not on numpy.ndarray | ||
# device() is not on numpy.ndarray and to_device() is not on numpy.ndarray | ||
# or cupy.ndarray. They are not included in array objects of this library | ||
# because this library just reuses the respective ndarray classes without | ||
# wrapping or subclassing them. These helper functions can be used instead of | ||
|
@@ -230,12 +230,6 @@ def _torch_to_device(x, device, /, stream=None): | |
raise NotImplementedError | ||
return x.to(device) | ||
|
||
def _jax_to_device(x, device, /, stream=None): | ||
import jax | ||
if stream is not None: | ||
raise NotImplementedError | ||
return jax.device_put(x, device) | ||
|
||
def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, Any]]" = None) -> "Array": | ||
""" | ||
Copy the array from the device on which it currently resides to the specified ``device``. | ||
|
@@ -276,7 +270,9 @@ def to_device(x: "Array", device: "Device", /, *, stream: "Optional[Union[int, A | |
return x | ||
raise ValueError(f"Unsupported device {device!r}") | ||
elif is_jax_array(x): | ||
return _jax_to_device(x, device, stream=stream) | ||
# This import adds to_device to x | ||
import jax.experimental.array_api | ||
return x.to_device(device, stream=stream) | ||
return x.to_device(device, stream=stream) | ||
|
||
def size(x): | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@jakevdp, I could mostly use your review for the changes in this file.