-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Better documentation for jnp.load #24403
Conversation
@@ -320,11 +321,43 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array: | |||
return clip(val, min_val, max_val).astype(dtype) | |||
|
|||
|
|||
@util.implements(np.load, update_doc=False) | |||
def load(*args: Any, **kwargs: Any) -> Array: | |||
def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OOC can we use os.PathLike[str]
here or do we want to allow both str
and bytes
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm honestly not sure, I thought PathLike[Any]
would be the safest annotation here.
This function is a simple wrapper of :func:`numpy.load`, but in the case of | ||
``.npy`` files created with :func:`numpy.save` or :func:`jax.numpy.save`, | ||
the output will be returned as a :class:`jax.Array`, and ``bfloat16`` data | ||
types will be restored. For ``.npz`` files, results will be returned as |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I suspect it should be possible to return jax.Array
s for .npz
files as well. Is there a reason why we don't do that?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
loading an npz
returns a lazy view of the file buffer, with various methods to actually load the (NumPy) arrays. Having it return JAX arrays would require re-implementing that object. Not impossible, but not something anyone has sat down and done.
Part of #21461