From 704da2153601d02311041ba853c6aac987a32fe7 Mon Sep 17 00:00:00 2001 From: Jake VanderPlas Date: Tue, 17 Dec 2024 13:43:53 -0800 Subject: [PATCH] Always use the same code for array avals --- jax/_src/array.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/jax/_src/array.py b/jax/_src/array.py index f07f3bc963fe..d668deeb11cf 100644 --- a/jax/_src/array.py +++ b/jax/_src/array.py @@ -1027,10 +1027,8 @@ def make_array_from_single_device_arrays( return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays), committed=True) - -core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval -core.xla_pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval') xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity + def _get_aval_array(self): if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding): return self.aval.update(sharding=NamedSharding( @@ -1038,7 +1036,11 @@ def _get_aval_array(self): self.sharding.spec._normalized_spec(self.ndim))) else: return self.aval + api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array +core.pytype_aval_mappings[ArrayImpl] = _get_aval_array +core.xla_pytype_aval_mappings[ArrayImpl] = _get_aval_array + # TODO(jakevdp) replace this with true inheritance at the C++ level. basearray.Array.register(ArrayImpl)