-
Notifications
You must be signed in to change notification settings - Fork 2.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Better documentation for jnp.load #24403
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -32,10 +32,11 @@ | |
import importlib | ||
import math | ||
import operator | ||
import os | ||
import string | ||
import types | ||
from typing import ( Any, Literal, NamedTuple, | ||
Protocol, TypeVar, Union,overload) | ||
from typing import (Any, IO, Literal, NamedTuple, | ||
Protocol, TypeVar, Union, overload) | ||
import warnings | ||
|
||
import jax | ||
|
@@ -320,11 +321,43 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array: | |
return clip(val, min_val, max_val).astype(dtype) | ||
|
||
|
||
@util.implements(np.load, update_doc=False) | ||
def load(*args: Any, **kwargs: Any) -> Array: | ||
def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array: | ||
"""Load JAX arrays from npy files. | ||
|
||
JAX wrapper of :func:`numpy.load`. | ||
|
||
This function is a simple wrapper of :func:`numpy.load`, but in the case of | ||
``.npy`` files created with :func:`numpy.save` or :func:`jax.numpy.save`, | ||
the output will be returned as a :class:`jax.Array`, and ``bfloat16`` data | ||
types will be restored. For ``.npz`` files, results will be returned as | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I suspect it should be possible to return There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. loading an |
||
normal NumPy arrays. | ||
|
||
This function requires concrete array inputs, and is not compatible with | ||
transformations like :func:`jax.jit` or :func:`jax.vmap`. | ||
|
||
Args: | ||
file: string, bytes, or path-like object containing the array data. | ||
args, kwargs: for additional arguments, see :func:`numpy.load` | ||
|
||
Returns: | ||
the array stored in the file. | ||
|
||
See also: | ||
- :func:`jax.numpy.save`: save an array to a file. | ||
|
||
Examples: | ||
>>> import io | ||
>>> f = io.BytesIO() # use an in-memory file-like object. | ||
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16') | ||
>>> jnp.save(f, x) | ||
>>> f.seek(0) | ||
0 | ||
>>> jnp.load(f) | ||
Array([2, 4, 6, 8], dtype=bfloat16) | ||
""" | ||
# The main purpose of this wrapper is to recover bfloat16 data types. | ||
# Note: this will only work for files created via np.save(), not np.savez(). | ||
out = np.load(*args, **kwargs) | ||
out = np.load(file, *args, **kwargs) | ||
if isinstance(out, np.ndarray): | ||
# numpy does not recognize bfloat16, so arrays are serialized as void16 | ||
if out.dtype == 'V2': | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OOC can we use
os.PathLike[str]
here or do we want to allow bothstr
andbytes
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm honestly not sure, I thought
PathLike[Any]
would be the safest annotation here.