Over time we've landed on the following pattern for user-facing APIs in jax.numpy, jax.scipy, jax.random, etc.:
```python
import jax.numpy as jnp
from jax import Array
from jax._src.numpy.util import _check_arraylike  # private helper
from jax.typing import ArrayLike

def func(x: ArrayLike) -> Array:
  _check_arraylike("func", x)  # essentially a runtime check of isinstance(x, ArrayLike)
  x = jnp.asarray(x)  # actually converts x to Array so that, e.g., it can be passed to lax functions
  out = x  # placeholder: do some computation with x
  return out
```
There are a couple of improvements we could make here:
We essentially never want to call _check_arraylike without following it with jnp.asarray on the inputs. We should combine these into a single utility! This would prevent failures like the one reported in #14906 ("jnp.sort and jnp.vdot raise error under jax.disable_jit()").
jnp.asarray is a pretty heavy function; if we know that an input passes _check_arraylike, we can convert to an array much more efficiently.
This unified utility might also handle the stackable checks that are currently in numpy, via an extra keyword like allow_stackable=True. It could perhaps also incorporate check_no_float0s via allow_float0s=False.
_check_arraylike lives in jax._src.numpy.util, but it is far more generally applicable. It should probably live in jax._util, or perhaps even be a public utility so that we can recommend this pattern for downstream APIs (e.g. here: https://jax.readthedocs.io/en/latest/jax.typing.html#jax-typing-best-practices).

Wow, I didn't know this existed (jax.readthedocs.io/en/latest/jax.typing.html#jax-typing-best-practices). It's such a natural place for all the array-API-related tools (type checking, dtype checking, dtype promotion, staticness, etc.) to go. Do you think we could open an umbrella issue to track the development of such a set of (end-user-facing) tools?
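The proposal above combines the runtime check and the conversion into one call, with optional allow_stackable and allow_float0s keywords. The following is a minimal sketch of how such a helper might look, using only public JAX and NumPy APIs. It is not JAX's implementation, and several parts are assumptions. The name check_and_convert is hypothetical. The accepted types are spelled out explicitly to mirror the documented members of jax.typing.ArrayLike (JAX arrays, NumPy arrays and scalars, Python scalars). "Stackable" is assumed here to mean a list or tuple whose elements are each array-like. The "cheaper conversion" is approximated by passing existing jax.Array values (including tracers) through unchanged.

```python
from typing import Any

import numpy as np
import jax
import jax.numpy as jnp
from jax import dtypes

# Runtime counterpart of jax.typing.ArrayLike: JAX arrays, NumPy arrays,
# NumPy scalars, and Python scalars, listed explicitly for isinstance.
_ARRAYLIKE_TYPES = (jax.Array, np.ndarray, np.bool_, np.number,
                    bool, int, float, complex)


def _is_arraylike(x: Any) -> bool:
  return isinstance(x, _ARRAYLIKE_TYPES)


def check_and_convert(fun_name: str, *args: Any,
                      allow_stackable: bool = False,
                      allow_float0s: bool = False) -> tuple[jax.Array, ...]:
  """Hypothetical helper: validate each argument, then convert it to jax.Array."""
  out = []
  for pos, arg in enumerate(args):
    # Assumption: "stackable" means a list/tuple whose elements are all array-like.
    stackable = (allow_stackable and isinstance(arg, (list, tuple))
                 and all(_is_arraylike(a) for a in arg))
    if not (_is_arraylike(arg) or stackable):
      raise TypeError(f"{fun_name} requires ndarray or scalar arguments, "
                      f"got {type(arg)} at position {pos}.")
    # Reject float0 inputs unless explicitly allowed (only checks top-level dtype).
    dtype = getattr(arg, "dtype", None)
    if not allow_float0s and dtype is not None and dtype == dtypes.float0:
      raise TypeError(f"{fun_name} does not accept float0 inputs "
                      f"(argument at position {pos}).")
    # Values that are already jax.Array (including tracers under jit) are
    # passed through unchanged; everything else goes through jnp.asarray.
    out.append(arg if isinstance(arg, jax.Array) else jnp.asarray(arg))
  return tuple(out)
```

With a helper like this, the two-step pattern in func collapses to a single line such as `x, = check_and_convert("func", x)`. Because the check and the conversion can no longer be separated, a function cannot accidentally skip the conversion step. Returning a tuple keeps multi-argument call sites uniform, at the cost of unpacking in the single-argument case.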