Skip to content

Commit

Permalink
Always use the same code for array avals
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Dec 17, 2024
1 parent 05ad393 commit 704da21
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1027,18 +1027,20 @@ def make_array_from_single_device_arrays(
return ArrayImpl(aval, sharding, cast(Sequence[ArrayImpl], arrays),
committed=True)


core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
core.xla_pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity

def _get_aval_array(self):
if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
return self.aval.update(sharding=NamedSharding(
self.sharding.mesh.abstract_mesh,
self.sharding.spec._normalized_spec(self.ndim)))
else:
return self.aval

api_util._shaped_abstractify_handlers[ArrayImpl] = _get_aval_array
core.pytype_aval_mappings[ArrayImpl] = _get_aval_array
core.xla_pytype_aval_mappings[ArrayImpl] = _get_aval_array

# TODO(jakevdp) replace this with true inheritance at the C++ level.
basearray.Array.register(ArrayImpl)

Expand Down

0 comments on commit 704da21

Please sign in to comment.