
Add streamlined _ensure_arraylike utility #14921

Open
jakevdp opened this issue Mar 10, 2023 · 2 comments
Labels
cleanup Code cleanups and improvements


@jakevdp
Collaborator

jakevdp commented Mar 10, 2023

Over time we've landed on the following pattern for user-facing APIs in jax.numpy, jax.scipy, jax.random, etc.:

```python
import jax.numpy as jnp
from jax._src.numpy.util import _check_arraylike
from jax.typing import ArrayLike
from jax import Array

def func(x: ArrayLike) -> Array:  # other parameters omitted
  _check_arraylike("func", x)  # this is essentially a runtime check of isinstance(x, ArrayLike)
  x = jnp.asarray(x)  # this actually converts x to Array so that, e.g., it can be passed to lax functions
  out = x * 2  # placeholder: do some computation with x
  return out
```
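For context on why the runtime check exists at all: unlike NumPy, public jax.numpy functions reject Python lists instead of silently converting them. A quick illustration, assuming a recent JAX release (the exact error text may differ between versions):

```python
import numpy as np
import jax.numpy as jnp

print(np.sum([1, 2, 3]))  # NumPy converts the list implicitly: 6

try:
  jnp.sum([1, 2, 3])  # jax.numpy runs an arraylike check and rejects the list
except TypeError as err:
  print("TypeError:", err)

print(jnp.sum(jnp.asarray([1, 2, 3])))  # explicit conversion first works: 6
```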

There are a few improvements we could make here:

  1. We essentially never want to call _check_arraylike without following it with jnp.asarray on the inputs. We should combine these into a single utility! This would prevent failures like the one reported in #14906 ("jnp.sort and jnp.vdot raise error under jax.disable_jit()").
  2. jnp.asarray is a pretty heavy function; if we know that an input passes _check_arraylike, we can convert to an array much more efficiently.
  3. This unified utility might also handle the stackable checks that are currently in jax.numpy, via an extra keyword like allow_stackable=True. It could perhaps also incorporate check_no_float0s via allow_float0s=False.
  4. _check_arraylike lives in jax._src.numpy.util, but it is far more generally applicable. It should probably live in jax._util, or perhaps even be a public utility so that we can recommend this pattern for downstream APIs (e.g. here: https://jax.readthedocs.io/en/latest/jax.typing.html#jax-typing-best-practices).
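A minimal sketch of what the combined utility from points 1 to 3 might look like. Everything here is hypothetical: the name ensure_arraylike, the allow_stackable/allow_float0s keywords, the return convention (a single Array for one input, a tuple otherwise), and the type tuple standing in for jax.typing.ArrayLike are illustrative assumptions, not an existing JAX API. The only "efficient" path shown is skipping conversion for inputs that are already jax.Array; it does not attempt any faster conversion for NumPy or scalar inputs.

```python
from typing import Any

import numpy as np
import jax
import jax.numpy as jnp

# Assumption: mirrors the members of the jax.typing.ArrayLike union
# (JAX arrays, NumPy arrays and scalars, Python scalars).
_ARRAYLIKE_TYPES = (jax.Array, np.ndarray, np.bool_, np.number,
                    bool, int, float, complex)


def _is_arraylike(x: Any, allow_stackable: bool) -> bool:
  if isinstance(x, _ARRAYLIKE_TYPES):
    return True
  # Optionally accept (possibly nested) lists/tuples of array-likes.
  if allow_stackable and isinstance(x, (list, tuple)):
    return all(_is_arraylike(item, allow_stackable) for item in x)
  return False


def ensure_arraylike(fun_name: str, *args: Any,
                     allow_stackable: bool = False,
                     allow_float0s: bool = False):
  """Hypothetical helper: validate args as ArrayLike, then convert to jax.Array."""
  arrays = []
  for pos, arg in enumerate(args):
    if not _is_arraylike(arg, allow_stackable):
      raise TypeError(f"{fun_name} requires array-like arguments, "
                      f"got {type(arg).__name__} at position {pos}.")
    # Reject float0 inputs unless explicitly allowed (top-level inputs only).
    dtype = getattr(arg, "dtype", None)
    if not allow_float0s and dtype is not None and dtype == jax.dtypes.float0:
      raise TypeError(f"{fun_name} does not accept float0 inputs "
                      f"(position {pos}).")
    # Fast path: inputs that are already jax.Array need no conversion.
    arrays.append(arg if isinstance(arg, jax.Array) else jnp.asarray(arg))
  return arrays[0] if len(arrays) == 1 else tuple(arrays)
```

With a helper like this, the body of func above would start with `x = ensure_arraylike("func", x)`, so the check and the conversion can no longer drift apart.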
jakevdp added the cleanup (Code cleanups and improvements) label on Mar 10, 2023
jakevdp self-assigned this on Mar 10, 2023
@soraros

soraros commented Mar 10, 2023

@jakevdp Thanks for investigating this!

Wow, I didn't know this existed (jax.readthedocs.io/en/latest/jax.typing.html#jax-typing-best-practices). It's such a natural place for all the array-API-related tools (type checking, dtype checking, dtype promotion, staticness, etc.) to go. Do you think we could open an umbrella issue to track the creation of such a set of end-user-facing tools?

@jakevdp
Collaborator Author

jakevdp commented Mar 10, 2023

That doc is only a week or two old. #12049 is a relevant tracking issue.

Development

No branches or pull requests

2 participants