diff --git a/benchmarks/api_benchmark.py b/benchmarks/api_benchmark.py
index df9528ada9ff..70b8a975fade 100644
--- a/benchmarks/api_benchmark.py
+++ b/benchmarks/api_benchmark.py
@@ -23,8 +23,8 @@
 from jax import lax
 from jax._src.api_util import shaped_abstractify  # technically not an api fn
 from jax._src.ad_checkpoint import checkpoint  # new jax.remat implementation
+from jax._src import core
 from jax._src.lib import xla_client as xc
-from jax.interpreters import xla
 from jax._src import array
 from jax._src import op_shardings
 from jax._src.pjit import pjit_check_aval_sharding
@@ -427,7 +427,7 @@ def bench_shaped_abstractify(state):
 
 def _run_benchmark_for_xla_abstractify(arg, state):
   while state:
-    xla.abstractify(arg)
+    core.abstractify(arg)
 
 def bench_xla_abstractify():
   _abstractify_args = [
diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py
index dbfbd3ff942b..b028f178603c 100644
--- a/jax/_src/abstract_arrays.py
+++ b/jax/_src/abstract_arrays.py
@@ -50,23 +50,49 @@ def canonical_concrete_aval(val, weak_type=None):
   sharding = core._get_abstract_sharding(val)
   return ShapedArray(np.shape(val), dtype, weak_type=weak_type,
                      sharding=sharding)
 
+
 def masked_array_error(*args, **kwargs):
   raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
                    "Use arr.filled() to convert the value to a standard numpy array.")
 
 core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error
 
-for t in array_types:
+
+def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
+  dtype = x.dtype
+  dtypes.check_valid_dtype(dtype)
+  return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
+
+core.pytype_aval_mappings[np.ndarray] = canonical_concrete_aval
+core.xla_pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
+
+
+def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
+  dtype = np.dtype(x)
+  dtypes.check_valid_dtype(dtype)
+  return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
+
+for t in numpy_scalar_types:
   core.pytype_aval_mappings[t] = canonical_concrete_aval
+  core.xla_pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar
 
 core.literalable_types.update(array_types)
 
+
 def _make_concrete_python_scalar(t, x):
   dtype = dtypes._scalar_type_to_dtype(t, x)
   weak_type = dtypes.is_weakly_typed(x)
   return canonical_concrete_aval(np.array(x, dtype=dtype), weak_type=weak_type)
 
+
+def _make_abstract_python_scalar(typ, val):
+  # Note: all python scalar types are weak except bool, because bool only
+  # comes in a single width.
+  return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
+                     weak_type=typ is not bool)
+
 for t in dtypes.python_scalar_dtypes:
   core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
+  core.xla_pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)
 
 core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
diff --git a/jax/_src/api_util.py b/jax/_src/api_util.py
index ad0d6c01dc8d..97052c955c69 100644
--- a/jax/_src/api_util.py
+++ b/jax/_src/api_util.py
@@ -600,7 +600,7 @@ def _shaped_abstractify_slow(x):
                     "does not have a dtype attribute")
   return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type)
 
-# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
+# TODO(mattjj,yashkatariya): replace core.abstractify with this, same behavior
 def shaped_abstractify(x):
   handler = _shaped_abstractify_handlers.get(type(x), None)
   return handler(x) if handler is not None else _shaped_abstractify_slow(x)
diff --git a/jax/_src/array.py b/jax/_src/array.py
index cf2cc6248bd8..f07f3bc963fe 100644
--- a/jax/_src/array.py
+++ b/jax/_src/array.py
@@ -1029,7 +1029,7 @@ def make_array_from_single_device_arrays(
 
 core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
-xla.pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
+core.xla_pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
 xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
 
 def _get_aval_array(self):
   if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
diff --git a/jax/_src/core.py b/jax/_src/core.py
index 1250a3f4d954..2ad1f7b0bdf7 100644
--- a/jax/_src/core.py
+++ b/jax/_src/core.py
@@ -1400,7 +1400,16 @@ def check_valid_jaxtype(x):
         f"Value {x!r} of type {type(x)} is not a valid JAX type")
 
 
+# TODO(jakevdp): merge concrete_aval and abstractify to the extent possible.
+# This is tricky because concrete_aval includes sharding information, and
+# abstractify does not; further, because abstractify is in the dispatch path,
+# performance is important and simply adding sharding there is not an option.
 def concrete_aval(x):
+  # This differs from abstractify below in that the abstract values
+  # include sharding where applicable. Historically (before stackless)
+  # the returned avals were concrete, but after the stackless change
+  # this returns ShapedArray like abstractify.
+  # Rules are registered in pytype_aval_mappings.
   for typ in type(x).__mro__:
     handler = pytype_aval_mappings.get(typ)
     if handler: return handler(x)
@@ -1410,6 +1419,22 @@
                   "type")
 
 
+def abstractify(x):
+  # Historically, this was called xla.abstractify. It differs from
+  # concrete_aval in that it excludes sharding information, and
+  # uses a more performant path for accessing avals. Rules are
+  # registered in xla_pytype_aval_mappings.
+  typ = type(x)
+  aval_fn = xla_pytype_aval_mappings.get(typ)
+  if aval_fn: return aval_fn(x)
+  for typ in typ.__mro__:
+    aval_fn = xla_pytype_aval_mappings.get(typ)
+    if aval_fn: return aval_fn(x)
+  if hasattr(x, '__jax_array__'):
+    return abstractify(x.__jax_array__())
+  raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
+
+
 def get_aval(x):
   if isinstance(x, Tracer):
     return x.aval
@@ -1810,6 +1835,7 @@ def to_tangent_aval(self):
                          self.weak_type)
 
 pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
+xla_pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
 
 
 class DArray:
@@ -1866,6 +1892,7 @@ def data(self):
 
 pytype_aval_mappings[DArray] = \
     lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
+xla_pytype_aval_mappings[DArray] = lambda x: x._aval
 
 @dataclass(frozen=True)
 class bint(dtypes.ExtendedDType):
@@ -1898,6 +1925,7 @@ def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
   def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
   def __repr__(self) -> str: return 'Mutable' + repr(self[...])
 pytype_aval_mappings[MutableArray] = lambda x: x._aval
+xla_pytype_aval_mappings[MutableArray] = lambda x: x._aval
 
 def mutable_array(init_val):
   return mutable_array_p.bind(init_val)
@@ -1951,6 +1979,7 @@ def __init__(self, buf):
   def block_until_ready(self):
     self._buf.block_until_ready()
 pytype_aval_mappings[Token] = lambda _: abstract_token
+xla_pytype_aval_mappings[Token] = lambda _: abstract_token
 
 
 # TODO(dougalm): Deprecate these. They're just here for backwards compat.
diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py
index 7b94d8be63be..f011e756da31 100644
--- a/jax/_src/dispatch.py
+++ b/jax/_src/dispatch.py
@@ -457,7 +457,7 @@ def _device_put_impl(
           " please provide a concrete Sharding with memory_kind.")
 
   try:
-    aval = xla.abstractify(x)
+    aval = core.abstractify(x)
   except TypeError as err:
     raise TypeError(
         f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py
index 05260063ddb5..c0899d6bc6a6 100644
--- a/jax/_src/export/shape_poly.py
+++ b/jax/_src/export/shape_poly.py
@@ -35,7 +35,6 @@
 import opt_einsum
 
 import jax
-from jax.interpreters import xla
 
 from jax._src import config
 from jax._src import core
@@ -1206,7 +1205,7 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
       f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")
 
 core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
-xla.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
+core.xla_pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
 dtypes._weak_types.append(_DimExpr)
 
 def _convertible_to_int(p: DimSize) -> bool:
diff --git a/jax/_src/interpreters/mlir.py b/jax/_src/interpreters/mlir.py
index 0e3fdea02301..5923cfe00ca3 100644
--- a/jax/_src/interpreters/mlir.py
+++ b/jax/_src/interpreters/mlir.py
@@ -1825,7 +1825,7 @@ def read(v: core.Atom) -> IrValues:
 
   def aval(v: core.Atom) -> core.AbstractValue:
     if type(v) is core.Literal:
-      return xla.abstractify(v.val)
+      return core.abstractify(v.val)
     else:
       return v.aval
diff --git a/jax/_src/interpreters/pxla.py b/jax/_src/interpreters/pxla.py
index 8f27bb7a55d3..a1936d213c60 100644
--- a/jax/_src/interpreters/pxla.py
+++ b/jax/_src/interpreters/pxla.py
@@ -349,7 +349,7 @@ def _emap_apply_fn(*args):
         donated_invars=donated_invars,
         is_explicit_global_axis_size=is_explicit_global_axis_size)
     return _emap_apply_fn
-  abstract_args = unsafe_map(xla.abstractify, args)
+  abstract_args = unsafe_map(core.abstractify, args)
   compiled_fun, fingerprint = parallel_callable(
       fun, backend, axis_name, axis_size, global_axis_size, devices, name,
       in_axes, out_axes_thunk, donated_invars,
@@ -360,7 +360,7 @@ def _emap_apply_fn(*args):
   distributed_debug_log(("Running pmapped function", name),
                         ("python function", fun.f),
                         ("devices", devices),
-                        ("abstract args", map(xla.abstractify, args)),
+                        ("abstract args", map(core.abstractify, args)),
                         ("fingerprint", fingerprint))
   return compiled_fun
@@ -598,7 +598,7 @@ def __init__(self, trace: MapTrace, val, shard_axes: dict[core.AxisName, int]):
 
   @property
   def aval(self):
-    aval = xla.abstractify(self.val)
+    aval = core.abstractify(self.val)
     shard_axes = dict(self.shard_axes)
     for axis_idx in sorted(shard_axes.values())[::-1]:
       aval = core.mapped_aval(aval.shape[axis_idx], axis_idx, aval)
@@ -1145,7 +1145,7 @@ def xla_extension_executable(self):
 
   @profiler.annotate_function
   def call(self, *args):
     # TODO(frostig): do we need to check sharding and sharded avals?
-    arg_avals = map(xla.abstractify, args)
+    arg_avals = map(core.abstractify, args)
     check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
     return self.unsafe_call(*args)  # pylint: disable=not-callable
@@ -3092,7 +3092,7 @@ def call(self, *args):
       ref_avals = self._all_args_info.in_avals
       debug_info = self._all_args_info.debug_info
 
-    all_arg_avals = map(xla.abstractify, kept_args)
+    all_arg_avals = map(core.abstractify, kept_args)
     check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info)
     check_array_xla_sharding_layout_match(
         args_after_dce, self._in_shardings, self._xla_in_layouts, debug_info,
diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py
index 46bc7bef7ca7..565ae39492a0 100644
--- a/jax/_src/interpreters/xla.py
+++ b/jax/_src/interpreters/xla.py
@@ -146,44 +146,12 @@ def _canonicalize_python_scalar_dtype(typ, x):
 canonicalize_dtype_handlers[core.DArray] = identity
 canonicalize_dtype_handlers[core.MutableArray] = identity
 
+# TODO(jakevdp): deprecate and remove this.
 def abstractify(x) -> Any:
-  typ = type(x)
-  aval_fn = pytype_aval_mappings.get(typ)
-  if aval_fn: return aval_fn(x)
-  for typ in typ.__mro__:
-    aval_fn = pytype_aval_mappings.get(typ)
-    if aval_fn: return aval_fn(x)
-  if hasattr(x, '__jax_array__'):
-    return abstractify(x.__jax_array__())
-  raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")
-
-def _make_abstract_python_scalar(typ, val):
-  # Note: all python scalar types are weak except bool, because bool only
-  # comes in a single width.
-  return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
-                     weak_type=typ is not bool)
-
-def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
-  dtype = np.dtype(x)
-  dtypes.check_valid_dtype(dtype)
-  return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))
-
-def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
-  dtype = x.dtype
-  dtypes.check_valid_dtype(dtype)
-  return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))
-
-
-pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = {}
-pytype_aval_mappings[core.DArray] = lambda x: x._aval
-pytype_aval_mappings[core.MutableArray] = lambda x: x._aval
-pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
-pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
-                            for t in numpy_scalar_types)
-pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
-pytype_aval_mappings.update(
-    (t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
+  return core.abstractify(x)
 
+# TODO(jakevdp): deprecate and remove this.
+pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = core.xla_pytype_aval_mappings
 
 initial_style_primitives: set[core.Primitive] = set()
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index 976cdd4965af..4a08558bdff2 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -1689,7 +1689,7 @@ def _pjit_call_impl_python(
                            ("out_shardings", out_shardings),
                            ("in_layouts", in_layouts),
                            ("out_layouts", out_layouts),
-                           ("abstract args", map(xla.abstractify, args)),
+                           ("abstract args", map(core.abstractify, args)),
                            ("fingerprint", fingerprint))
   try:
     return compiled.unsafe_call(*args), compiled, pgle_profiler
diff --git a/jax/_src/prng.py b/jax/_src/prng.py
index 2256e12da1d4..f31e38f537ff 100644
--- a/jax/_src/prng.py
+++ b/jax/_src/prng.py
@@ -463,7 +463,7 @@ def __hash__(self) -> int:
 
 
 core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
-xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
+core.xla_pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
 
 xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x
diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py
index 700cb07ca847..9f0a7bf0bce9 100644
--- a/jax/experimental/jax2tf/jax2tf.py
+++ b/jax/experimental/jax2tf/jax2tf.py
@@ -38,7 +38,6 @@
 from jax import sharding
 from jax import export
 from jax.experimental.jax2tf import impl_no_xla
-from jax.interpreters import xla
 
 from jax._src import ad_checkpoint
 from jax._src import ad_util
@@ -1153,7 +1152,7 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
     else:
       return val, jax_dtype
   else:  # A constant
-    jax_dtype = jax_dtype or xla.abstractify(val).dtype
+    jax_dtype = jax_dtype or core.abstractify(val).dtype
    # TODO(document): We assume that the value of a constant does not
    # change through the scope of the function. But it may be an ndarray, ...
    # JAX has the same problem when generating HLO.
diff --git a/jax/experimental/roofline/roofline.py b/jax/experimental/roofline/roofline.py
index 42f72f005034..5fb52b778ed8 100644
--- a/jax/experimental/roofline/roofline.py
+++ b/jax/experimental/roofline/roofline.py
@@ -26,7 +26,6 @@
 from jax._src import util
 from jax._src.api import make_jaxpr
 from jax._src.interpreters.partial_eval import dce_jaxpr
-from jax._src.interpreters.xla import abstractify
 from jax._src.mesh import AbstractMesh, Mesh
 from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map
 from jax.experimental import shard_map
@@ -142,14 +141,14 @@ def write(v: core.Var, node: RooflineShape):
 
   def read(v: core.Atom) -> RooflineShape:
     if type(v) is core.Literal:
-      return RooflineShape.from_aval(abstractify(v.val))
+      return RooflineShape.from_aval(core.abstractify(v.val))
     else:
       assert isinstance(v, core.Var)
       return env[v]
 
   def aval(v: core.Atom) -> core.AbstractValue:
     if type(v) is core.Literal:
-      return abstractify(v.val)
+      return core.abstractify(v.val)
     else:
       return v.aval
diff --git a/tests/lax_test.py b/tests/lax_test.py
index 5da58b38aab7..522c396ce5b0 100644
--- a/tests/lax_test.py
+++ b/tests/lax_test.py
@@ -3959,7 +3959,7 @@ def setUp(self):
     core.pytype_aval_mappings[FooArray] = \
         lambda x: core.ShapedArray(x.shape, FooTy())
     xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
-    xla.pytype_aval_mappings[FooArray] = \
+    core.xla_pytype_aval_mappings[FooArray] = \
         lambda x: core.ShapedArray(x.shape, FooTy())
     pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler
     mlir._constant_handlers[FooArray] = foo_array_constant_handler
@@ -3973,7 +3973,7 @@ def setUp(self):
   def tearDown(self):
     del core.pytype_aval_mappings[FooArray]
     del xla.canonicalize_dtype_handlers[FooArray]
-    del xla.pytype_aval_mappings[FooArray]
+    del core.xla_pytype_aval_mappings[FooArray]
     del mlir._constant_handlers[FooArray]
     del mlir._lowerings[make_p]
     del mlir._lowerings[bake_p]
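Usage sketch (not part of the patch above): the snippet below illustrates the split this change establishes between the fast, sharding-free core.abstractify path and the sharding-aware core.concrete_aval path. It assumes a JAX build containing this change; both functions live in the internal jax._src.core module, and the exact dtypes printed depend on the jax_enable_x64 setting.

import numpy as np
import jax.numpy as jnp
from jax._src import core  # internal module; no public-API stability guarantees

# Python scalars become weak-typed avals; bool is the exception because it
# only comes in a single width (see _make_abstract_python_scalar above).
print(core.abstractify(3.0))              # ShapedArray(float32[], weak_type=True)
print(core.abstractify(True))             # ShapedArray(bool[])

# numpy scalars and arrays become strong-typed avals with canonicalized
# dtypes (float64 -> float32 when jax_enable_x64 is off).
print(core.abstractify(np.float32(3.0)))  # ShapedArray(float32[])
print(core.abstractify(np.ones((2, 3))))  # ShapedArray(float32[2,3])

# jax.Array values hit the op.attrgetter('aval') rule registered for
# ArrayImpl, so abstractify is a plain attribute read on the hot dispatch path.
x = jnp.ones((2, 3))
print(core.abstractify(x))                # ShapedArray(float32[2,3])

# concrete_aval resolves through core.pytype_aval_mappings instead, and may
# attach sharding information where applicable.
print(core.concrete_aval(x))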