diff --git a/CHANGELOG.md b/CHANGELOG.md index 71fe5ac06556..7822eb8f0a7b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,11 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG. ## Unreleased +* Deprecations + * From {mod}`jax.interpreters.xla`, `abstractify` and `pytype_aval_mappings` + are now deprecated, having been replaced by symbols of the same name + in {mod}`jax.core`. + ## jax 0.4.38 (Dec 17, 2024) * Changes: diff --git a/jax/_src/abstract_arrays.py b/jax/_src/abstract_arrays.py index 2094f36dc918..8ddc33fd8983 100644 --- a/jax/_src/abstract_arrays.py +++ b/jax/_src/abstract_arrays.py @@ -49,7 +49,6 @@ def masked_array_error(*args, **kwargs): "Use arr.filled() to convert the value to a standard numpy array.") core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error -core.xla_pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: @@ -58,7 +57,6 @@ def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray: return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype)) core.pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array -core.xla_pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray: @@ -68,7 +66,6 @@ def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray: for t in numpy_scalar_types: core.pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar - core.xla_pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar core.literalable_types.update(array_types) @@ -81,6 +78,5 @@ def _make_abstract_python_scalar(typ, val): for t in dtypes.python_scalar_dtypes: core.pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t) - core.xla_pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t) core.literalable_types.update(dtypes.python_scalar_dtypes.keys()) diff --git a/jax/_src/array.py b/jax/_src/array.py index 29bcbac58f25..e5d6902d1d1b 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1038,7 +1038,6 @@ def _get_aval_array(self): api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array core.pytype_aval_mappings[ArrayImpl] = _get_aval_array -core.xla_pytype_aval_mappings[ArrayImpl] = _get_aval_array # TODO(jakevdp) replace this with true inheritance at the C++ level. basearray.Array.register(ArrayImpl) diff --git a/jax/_src/core.py b/jax/_src/core.py index 3acd5b83c60d..59783466b151 100644 --- a/jax/_src/core.py +++ b/jax/_src/core.py @@ -1388,7 +1388,7 @@ def lattice_join(x, y): def valid_jaxtype(x) -> bool: try: - concrete_aval(x) + abstractify(x) except TypeError: return False else: @@ -1400,35 +1400,9 @@ def check_valid_jaxtype(x): f"Value {x!r} of type {type(x)} is not a valid JAX type") -# TODO(jakevdp): merge concrete_aval and abstractify to the extent possible. -# This is tricky because concrete_aval includes sharding information, and -# abstractify does not; further, because abstractify is in the dispatch path, -# performance is important and simply adding sharding there is not an option. -def concrete_aval(x): - # This differs from abstractify below in that the abstract values - # include sharding where applicable. Historically (before stackless) - # the returned avals were concrete, but after the stackless change - # this returns ShapedArray like abstractify. - # Rules are registered in pytype_aval_mappings. - for typ in type(x).__mro__: - handler = pytype_aval_mappings.get(typ) - if handler: return handler(x) - if hasattr(x, '__jax_array__'): - return concrete_aval(x.__jax_array__()) - raise TypeError(f"Value {x!r} with type {type(x)} is not a valid JAX " - "type") - - def abstractify(x): - # Historically, this was called xla.abstractify. It differs from - # concrete_aval in that it excludes sharding information, and - # uses a more performant path for accessing avals. Rules are - # registered in xla_pytype_aval_mappings. - typ = type(x) - aval_fn = xla_pytype_aval_mappings.get(typ) - if aval_fn: return aval_fn(x) - for typ in typ.__mro__: - aval_fn = xla_pytype_aval_mappings.get(typ) + for typ in type(x).__mro__: + aval_fn = pytype_aval_mappings.get(typ) if aval_fn: return aval_fn(x) if hasattr(x, '__jax_array__'): return abstractify(x.__jax_array__()) @@ -1439,7 +1413,7 @@ def get_aval(x): if isinstance(x, Tracer): return x.aval else: - return concrete_aval(x) + return abstractify(x) get_type = get_aval @@ -1835,7 +1809,6 @@ def to_tangent_aval(self): self.weak_type) pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {} -xla_pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {} class DArray: @@ -1892,7 +1865,6 @@ def data(self): pytype_aval_mappings[DArray] = \ lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type) -xla_pytype_aval_mappings[DArray] = lambda x: x._aval @dataclass(frozen=True) class bint(dtypes.ExtendedDType): @@ -1925,7 +1897,6 @@ def __getitem__(self, idx): return get_aval(self)._getitem(self, idx) def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x) def __repr__(self) -> str: return 'Mutable' + repr(self[...]) pytype_aval_mappings[MutableArray] = lambda x: x._aval -xla_pytype_aval_mappings[MutableArray] = lambda x: x._aval def mutable_array(init_val): return mutable_array_p.bind(init_val) @@ -1979,7 +1950,6 @@ def __init__(self, buf): def block_until_ready(self): self._buf.block_until_ready() pytype_aval_mappings[Token] = lambda _: abstract_token -xla_pytype_aval_mappings[Token] = lambda _: abstract_token # TODO(dougalm): Deprecate these. They're just here for backwards compat. diff --git a/jax/_src/export/shape_poly.py b/jax/_src/export/shape_poly.py index c0899d6bc6a6..5462723c8335 100644 --- a/jax/_src/export/shape_poly.py +++ b/jax/_src/export/shape_poly.py @@ -1205,7 +1205,6 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool: f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}") core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval -core.xla_pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval dtypes._weak_types.append(_DimExpr) def _convertible_to_int(p: DimSize) -> bool: diff --git a/jax/_src/interpreters/xla.py b/jax/_src/interpreters/xla.py index 565ae39492a0..33a8992a8be4 100644 --- a/jax/_src/interpreters/xla.py +++ b/jax/_src/interpreters/xla.py @@ -146,13 +146,6 @@ def _canonicalize_python_scalar_dtype(typ, x): canonicalize_dtype_handlers[core.DArray] = identity canonicalize_dtype_handlers[core.MutableArray] = identity -# TODO(jakevdp): deprecate and remove this. -def abstractify(x) -> Any: - return core.abstractify(x) - -# TODO(jakevdp): deprecate and remove this. -pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = core.xla_pytype_aval_mappings - initial_style_primitives: set[core.Primitive] = set() def register_initial_style_primitive(prim: core.Primitive): diff --git a/jax/_src/prng.py b/jax/_src/prng.py index f31e38f537ff..d0f3b644b926 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -463,7 +463,6 @@ def __hash__(self) -> int: core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval -core.xla_pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x diff --git a/jax/core.py b/jax/core.py index 10a8e00a3b36..54bbdac51c87 100644 --- a/jax/core.py +++ b/jax/core.py @@ -122,7 +122,7 @@ _src_core.call_p), "closed_call_p": ("jax.core.closed_call_p is deprecated. Use jax.extend.core.primitives.closed_call_p", _src_core.closed_call_p), - "concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.concrete_aval), + "concrete_aval": ("jax.core.concrete_aval is deprecated.", _src_core.abstractify), "dedup_referents": ("jax.core.dedup_referents is deprecated.", _src_core.dedup_referents), "escaped_tracer_error": ("jax.core.escaped_tracer_error is deprecated.", _src_core.escaped_tracer_error), @@ -207,7 +207,7 @@ axis_frame = _src_core.axis_frame call_p = _src_core.call_p closed_call_p = _src_core.closed_call_p - concrete_aval = _src_core.concrete_aval + concrete_aval = _src_core.abstractify dedup_referents = _src_core.dedup_referents escaped_tracer_error = _src_core.escaped_tracer_error extend_axis_env_nd = _src_core.extend_axis_env_nd diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index b3a470f5e049..bd3b83e37d24 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -13,10 +13,8 @@ # limitations under the License. from jax._src.interpreters.xla import ( - abstractify as abstractify, canonicalize_dtype as canonicalize_dtype, canonicalize_dtype_handlers as canonicalize_dtype_handlers, - pytype_aval_mappings as pytype_aval_mappings, ) from jax._src.dispatch import ( @@ -27,8 +25,19 @@ Backend = _xc._xla.Client del _xc +from jax._src import core as _src_core + # Deprecations _deprecations = { + # Added 2024-12-17 + "abstractify": ( + "jax.interpreters.xla.abstractify is deprecated.", + _src_core.abstractify + ), + "pytype_aval_mappings": ( + "jax.interpreters.xla.pytype_aval_mappings is deprecated.", + _src_core.pytype_aval_mappings + ), # Finalized 2024-10-24; remove after 2025-01-24 "xb": ( ("jax.interpreters.xla.xb was removed in JAX v0.4.36. " @@ -44,6 +53,13 @@ ), } +import typing as _typing from jax._src.deprecations import deprecation_getattr as _deprecation_getattr -__getattr__ = _deprecation_getattr(__name__, _deprecations) +if _typing.TYPE_CHECKING: + abstractify = _src_core.abstractify + pytype_aval_mappings = _src_core.pytype_aval_mappings +else: + __getattr__ = _deprecation_getattr(__name__, _deprecations) del _deprecation_getattr +del _typing +del _src_core diff --git a/tests/lax_test.py b/tests/lax_test.py index 522c396ce5b0..2f2cc76fb1a3 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -3959,8 +3959,6 @@ def setUp(self): core.pytype_aval_mappings[FooArray] = \ lambda x: core.ShapedArray(x.shape, FooTy()) xla.canonicalize_dtype_handlers[FooArray] = lambda x: x - core.xla_pytype_aval_mappings[FooArray] = \ - lambda x: core.ShapedArray(x.shape, FooTy()) pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler mlir._constant_handlers[FooArray] = foo_array_constant_handler mlir.register_lowering(make_p, mlir.lower_fun(make_lowering, False)) @@ -3973,7 +3971,6 @@ def setUp(self): def tearDown(self): del core.pytype_aval_mappings[FooArray] del xla.canonicalize_dtype_handlers[FooArray] - del core.xla_pytype_aval_mappings[FooArray] del mlir._constant_handlers[FooArray] del mlir._lowerings[make_p] del mlir._lowerings[bake_p]