Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better documentation for jnp.load #24403

Merged
merged 1 commit into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 38 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,11 @@
import importlib
import math
import operator
import os
import string
import types
from typing import ( Any, Literal, NamedTuple,
Protocol, TypeVar, Union,overload)
from typing import (Any, IO, Literal, NamedTuple,
Protocol, TypeVar, Union, overload)
import warnings

import jax
Expand Down Expand Up @@ -320,11 +321,43 @@ def _convert_and_clip_integer(val: ArrayLike, dtype: DType) -> Array:
return clip(val, min_val, max_val).astype(dtype)


@util.implements(np.load, update_doc=False)
def load(*args: Any, **kwargs: Any) -> Array:
def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OOC can we use os.PathLike[str] here or do we want to allow both str and bytes?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm honestly not sure, I thought PathLike[Any] would be the safest annotation here.

"""Load JAX arrays from npy files.

JAX wrapper of :func:`numpy.load`.

This function is a simple wrapper of :func:`numpy.load`, but in the case of
``.npy`` files created with :func:`numpy.save` or :func:`jax.numpy.save`,
the output will be returned as a :class:`jax.Array`, and ``bfloat16`` data
types will be restored. For ``.npz`` files, results will be returned as
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect it should be possible to return jax.Arrays for .npz files as well. Is there a reason why we don't do that?

Copy link
Collaborator Author

@jakevdp jakevdp Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

loading an npz returns a lazy view of the file buffer, with various methods to actually load the (NumPy) arrays. Having it return JAX arrays would require re-implementing that object. Not impossible, but not something anyone has sat down and done.

normal NumPy arrays.

This function requires concrete array inputs, and is not compatible with
transformations like :func:`jax.jit` or :func:`jax.vmap`.

Args:
file: string, bytes, or path-like object containing the array data.
args, kwargs: for additional arguments, see :func:`numpy.load`

Returns:
the array stored in the file.

See also:
- :func:`jax.numpy.save`: save an array to a file.

Examples:
>>> import io
>>> f = io.BytesIO() # use an in-memory file-like object.
>>> x = jnp.array([2, 4, 6, 8], dtype='bfloat16')
>>> jnp.save(f, x)
>>> f.seek(0)
0
>>> jnp.load(f)
Array([2, 4, 6, 8], dtype=bfloat16)
"""
# The main purpose of this wrapper is to recover bfloat16 data types.
# Note: this will only work for files created via np.save(), not np.savez().
out = np.load(*args, **kwargs)
out = np.load(file, *args, **kwargs)
if isinstance(out, np.ndarray):
# numpy does not recognize bfloat16, so arrays are serialized as void16
if out.dtype == 'V2':
Expand Down
5 changes: 3 additions & 2 deletions jax/numpy/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ from __future__ import annotations

import builtins
from collections.abc import Callable, Sequence
from typing import Any, Literal, NamedTuple, Protocol, TypeVar, Union, overload
import os
from typing import Any, IO, Literal, NamedTuple, Protocol, TypeVar, Union, overload

from jax._src import core as _core
from jax._src import dtypes as _dtypes
Expand Down Expand Up @@ -577,7 +578,7 @@ def linspace(start: ArrayLike, stop: ArrayLike, num: int = 50,
axis: int = 0,
*, device: _Device | _Sharding | None = ...) -> Union[Array, tuple[Array, Array]]: ...

def load(*args: Any, **kwargs: Any) -> Array: ...
def load(file: IO[bytes] | str | os.PathLike[Any], *args: Any, **kwargs: Any) -> Array: ...
def log(x: ArrayLike, /) -> Array: ...
def log10(x: ArrayLike, /) -> Array: ...
def log1p(x: ArrayLike, /) -> Array: ...
Expand Down
Loading