Skip to content

Commit

Permalink
Merge pull request #25456 from jakevdp:xla-abstractify
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 707175097
  • Loading branch information
Google-ML-Automation committed Dec 17, 2024
2 parents 7fe2579 + 2c722d9 commit 0fa5419
Show file tree
Hide file tree
Showing 15 changed files with 79 additions and 59 deletions.
4 changes: 2 additions & 2 deletions benchmarks/api_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,8 @@
from jax import lax
from jax._src.api_util import shaped_abstractify # technically not an api fn
from jax._src.ad_checkpoint import checkpoint # new jax.remat implementation
from jax._src import core
from jax._src.lib import xla_client as xc
from jax.interpreters import xla
from jax._src import array
from jax._src import op_shardings
from jax._src.pjit import pjit_check_aval_sharding
Expand Down Expand Up @@ -427,7 +427,7 @@ def bench_shaped_abstractify(state):

def _run_benchmark_for_xla_abstractify(arg, state):
while state:
xla.abstractify(arg)
core.abstractify(arg)

def bench_xla_abstractify():
_abstractify_args = [
Expand Down
28 changes: 27 additions & 1 deletion jax/_src/abstract_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,23 +50,49 @@ def canonical_concrete_aval(val, weak_type=None):
sharding = core._get_abstract_sharding(val)
return ShapedArray(np.shape(val), dtype, weak_type=weak_type, sharding=sharding)


def masked_array_error(*args, **kwargs):
raise ValueError("numpy masked arrays are not supported as direct inputs to JAX functions. "
"Use arr.filled() to convert the value to a standard numpy array.")

core.pytype_aval_mappings[np.ma.MaskedArray] = masked_array_error

for t in array_types:

def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))

core.pytype_aval_mappings[np.ndarray] = canonical_concrete_aval
core.xla_pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array


def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))

for t in numpy_scalar_types:
core.pytype_aval_mappings[t] = canonical_concrete_aval
core.xla_pytype_aval_mappings[t] = _make_shaped_array_for_numpy_scalar

core.literalable_types.update(array_types)


def _make_concrete_python_scalar(t, x):
dtype = dtypes._scalar_type_to_dtype(t, x)
weak_type = dtypes.is_weakly_typed(x)
return canonical_concrete_aval(np.array(x, dtype=dtype), weak_type=weak_type)


def _make_abstract_python_scalar(typ, val):
# Note: all python scalar types are weak except bool, because bool only
# comes in a single width.
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
weak_type=typ is not bool)

for t in dtypes.python_scalar_dtypes:
core.pytype_aval_mappings[t] = partial(_make_concrete_python_scalar, t)
core.xla_pytype_aval_mappings[t] = partial(_make_abstract_python_scalar, t)

core.literalable_types.update(dtypes.python_scalar_dtypes.keys())
2 changes: 1 addition & 1 deletion jax/_src/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ def _shaped_abstractify_slow(x):
"does not have a dtype attribute")
return core.ShapedArray(np.shape(x), dtype, weak_type=weak_type)

# TODO(mattjj,yashkatariya): replace xla.abstractify with this, same behavior
# TODO(mattjj,yashkatariya): replace core.abstractify with this, same behavior
def shaped_abstractify(x):
handler = _shaped_abstractify_handlers.get(type(x), None)
return handler(x) if handler is not None else _shaped_abstractify_slow(x)
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,7 @@ def make_array_from_single_device_arrays(


core.pytype_aval_mappings[ArrayImpl] = abstract_arrays.canonical_concrete_aval
xla.pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
core.xla_pytype_aval_mappings[ArrayImpl] = op.attrgetter('aval')
xla.canonicalize_dtype_handlers[ArrayImpl] = pxla.identity
def _get_aval_array(self):
if config.sharding_in_types.value and isinstance(self.sharding, NamedSharding):
Expand Down
29 changes: 29 additions & 0 deletions jax/_src/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1400,7 +1400,16 @@ def check_valid_jaxtype(x):
f"Value {x!r} of type {type(x)} is not a valid JAX type")


# TODO(jakevdp): merge concrete_aval and abstractify to the extent possible.
# This is tricky because concrete_aval includes sharding information, and
# abstractify does not; further, because abstractify is in the dispatch path,
# performance is important and simply adding sharding there is not an option.
def concrete_aval(x):
# This differs from abstractify below in that the abstract values
# include sharding where applicable. Historically (before stackless)
# the returned avals were concrete, but after the stackless change
# this returns ShapedArray like abstractify.
# Rules are registered in pytype_aval_mappings.
for typ in type(x).__mro__:
handler = pytype_aval_mappings.get(typ)
if handler: return handler(x)
Expand All @@ -1410,6 +1419,22 @@ def concrete_aval(x):
"type")


def abstractify(x):
# Historically, this was called xla.abstractify. It differs from
# concrete_aval in that it excludes sharding information, and
# uses a more performant path for accessing avals. Rules are
# registered in xla_pytype_aval_mappings.
typ = type(x)
aval_fn = xla_pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
for typ in typ.__mro__:
aval_fn = xla_pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")


def get_aval(x):
if isinstance(x, Tracer):
return x.aval
Expand Down Expand Up @@ -1810,6 +1835,7 @@ def to_tangent_aval(self):
self.weak_type)

pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}
xla_pytype_aval_mappings: dict[type, Callable[[Any], AbstractValue]] = {}


class DArray:
Expand Down Expand Up @@ -1866,6 +1892,7 @@ def data(self):

pytype_aval_mappings[DArray] = \
lambda x: DShapedArray(x._aval.shape, x._aval.dtype, x._aval.weak_type)
xla_pytype_aval_mappings[DArray] = lambda x: x._aval

@dataclass(frozen=True)
class bint(dtypes.ExtendedDType):
Expand Down Expand Up @@ -1898,6 +1925,7 @@ def __getitem__(self, idx): return get_aval(self)._getitem(self, idx)
def __setitem__(self, idx, x): return get_aval(self)._setitem(self, idx, x)
def __repr__(self) -> str: return 'Mutable' + repr(self[...])
pytype_aval_mappings[MutableArray] = lambda x: x._aval
xla_pytype_aval_mappings[MutableArray] = lambda x: x._aval

def mutable_array(init_val):
return mutable_array_p.bind(init_val)
Expand Down Expand Up @@ -1951,6 +1979,7 @@ def __init__(self, buf):
def block_until_ready(self):
self._buf.block_until_ready()
pytype_aval_mappings[Token] = lambda _: abstract_token
xla_pytype_aval_mappings[Token] = lambda _: abstract_token


# TODO(dougalm): Deprecate these. They're just here for backwards compat.
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def _device_put_impl(
" please provide a concrete Sharding with memory_kind.")

try:
aval = xla.abstractify(x)
aval = core.abstractify(x)
except TypeError as err:
raise TypeError(
f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
Expand Down
3 changes: 1 addition & 2 deletions jax/_src/export/shape_poly.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
import opt_einsum

import jax
from jax.interpreters import xla

from jax._src import config
from jax._src import core
Expand Down Expand Up @@ -1206,7 +1205,7 @@ def _geq_decision(e1: DimSize, e2: DimSize, cmp_str: Callable[[], str]) -> bool:
f"Symbolic dimension comparison {cmp_str()} is inconclusive.{describe_scope}")

core.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
xla.pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
core.xla_pytype_aval_mappings[_DimExpr] = _DimExpr._get_aval
dtypes._weak_types.append(_DimExpr)

def _convertible_to_int(p: DimSize) -> bool:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/interpreters/mlir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1825,7 +1825,7 @@ def read(v: core.Atom) -> IrValues:

def aval(v: core.Atom) -> core.AbstractValue:
if type(v) is core.Literal:
return xla.abstractify(v.val)
return core.abstractify(v.val)
else:
return v.aval

Expand Down
10 changes: 5 additions & 5 deletions jax/_src/interpreters/pxla.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def _emap_apply_fn(*args):
donated_invars=donated_invars,
is_explicit_global_axis_size=is_explicit_global_axis_size)
return _emap_apply_fn
abstract_args = unsafe_map(xla.abstractify, args)
abstract_args = unsafe_map(core.abstractify, args)
compiled_fun, fingerprint = parallel_callable(
fun, backend, axis_name, axis_size, global_axis_size, devices, name,
in_axes, out_axes_thunk, donated_invars,
Expand All @@ -360,7 +360,7 @@ def _emap_apply_fn(*args):
distributed_debug_log(("Running pmapped function", name),
("python function", fun.f),
("devices", devices),
("abstract args", map(xla.abstractify, args)),
("abstract args", map(core.abstractify, args)),
("fingerprint", fingerprint))
return compiled_fun

Expand Down Expand Up @@ -598,7 +598,7 @@ def __init__(self, trace: MapTrace, val, shard_axes: dict[core.AxisName, int]):

@property
def aval(self):
aval = xla.abstractify(self.val)
aval = core.abstractify(self.val)
shard_axes = dict(self.shard_axes)
for axis_idx in sorted(shard_axes.values())[::-1]:
aval = core.mapped_aval(aval.shape[axis_idx], axis_idx, aval)
Expand Down Expand Up @@ -1145,7 +1145,7 @@ def xla_extension_executable(self):
@profiler.annotate_function
def call(self, *args):
# TODO(frostig): do we need to check sharding and sharded avals?
arg_avals = map(xla.abstractify, args)
arg_avals = map(core.abstractify, args)
check_arg_avals_for_call(self.in_avals, arg_avals, self._jaxpr_debug_info)
return self.unsafe_call(*args) # pylint: disable=not-callable

Expand Down Expand Up @@ -3092,7 +3092,7 @@ def call(self, *args):
ref_avals = self._all_args_info.in_avals
debug_info = self._all_args_info.debug_info

all_arg_avals = map(xla.abstractify, kept_args)
all_arg_avals = map(core.abstractify, kept_args)
check_arg_avals_for_call(ref_avals, all_arg_avals, debug_info)
check_array_xla_sharding_layout_match(
args_after_dce, self._in_shardings, self._xla_in_layouts, debug_info,
Expand Down
40 changes: 4 additions & 36 deletions jax/_src/interpreters/xla.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,44 +146,12 @@ def _canonicalize_python_scalar_dtype(typ, x):
canonicalize_dtype_handlers[core.DArray] = identity
canonicalize_dtype_handlers[core.MutableArray] = identity

# TODO(jakevdp): deprecate and remove this.
def abstractify(x) -> Any:
typ = type(x)
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
for typ in typ.__mro__:
aval_fn = pytype_aval_mappings.get(typ)
if aval_fn: return aval_fn(x)
if hasattr(x, '__jax_array__'):
return abstractify(x.__jax_array__())
raise TypeError(f"Argument '{x}' of type '{type(x)}' is not a valid JAX type")

def _make_abstract_python_scalar(typ, val):
# Note: all python scalar types are weak except bool, because bool only
# comes in a single width.
return ShapedArray((), dtypes._scalar_type_to_dtype(typ, val),
weak_type=typ is not bool)

def _make_shaped_array_for_numpy_scalar(x: np.generic) -> ShapedArray:
dtype = np.dtype(x)
dtypes.check_valid_dtype(dtype)
return ShapedArray(np.shape(x), dtypes.canonicalize_dtype(dtype))

def _make_shaped_array_for_numpy_array(x: np.ndarray) -> ShapedArray:
dtype = x.dtype
dtypes.check_valid_dtype(dtype)
return ShapedArray(x.shape, dtypes.canonicalize_dtype(dtype))


pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = {}
pytype_aval_mappings[core.DArray] = lambda x: x._aval
pytype_aval_mappings[core.MutableArray] = lambda x: x._aval
pytype_aval_mappings[core.Token] = lambda _: core.abstract_token
pytype_aval_mappings.update((t, _make_shaped_array_for_numpy_scalar)
for t in numpy_scalar_types)
pytype_aval_mappings[np.ndarray] = _make_shaped_array_for_numpy_array
pytype_aval_mappings.update(
(t, partial(_make_abstract_python_scalar, t)) for t in _scalar_types)
return core.abstractify(x)

# TODO(jakevdp): deprecate and remove this.
pytype_aval_mappings: dict[Any, Callable[[Any], core.AbstractValue]] = core.xla_pytype_aval_mappings

initial_style_primitives: set[core.Primitive] = set()

Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -1689,7 +1689,7 @@ def _pjit_call_impl_python(
("out_shardings", out_shardings),
("in_layouts", in_layouts),
("out_layouts", out_layouts),
("abstract args", map(xla.abstractify, args)),
("abstract args", map(core.abstractify, args)),
("fingerprint", fingerprint))
try:
return compiled.unsafe_call(*args), compiled, pgle_profiler
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/prng.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,7 @@ def __hash__(self) -> int:


core.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
xla.pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval
core.xla_pytype_aval_mappings[PRNGKeyArray] = lambda x: x.aval

xla.canonicalize_dtype_handlers[PRNGKeyArray] = lambda x: x

Expand Down
3 changes: 1 addition & 2 deletions jax/experimental/jax2tf/jax2tf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
from jax import sharding
from jax import export
from jax.experimental.jax2tf import impl_no_xla
from jax.interpreters import xla

from jax._src import ad_checkpoint
from jax._src import ad_util
Expand Down Expand Up @@ -1153,7 +1152,7 @@ def _tfval_to_tensor_jax_dtype(val: TfVal,
else:
return val, jax_dtype
else: # A constant
jax_dtype = jax_dtype or xla.abstractify(val).dtype
jax_dtype = jax_dtype or core.abstractify(val).dtype
# TODO(document): We assume that the value of a constant does not
# change through the scope of the function. But it may be an ndarray, ...
# JAX has the same problem when generating HLO.
Expand Down
5 changes: 2 additions & 3 deletions jax/experimental/roofline/roofline.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from jax._src import util
from jax._src.api import make_jaxpr
from jax._src.interpreters.partial_eval import dce_jaxpr
from jax._src.interpreters.xla import abstractify
from jax._src.mesh import AbstractMesh, Mesh
from jax._src.tree_util import broadcast_prefix, tree_flatten, tree_unflatten, tree_map
from jax.experimental import shard_map
Expand Down Expand Up @@ -142,14 +141,14 @@ def write(v: core.Var, node: RooflineShape):

def read(v: core.Atom) -> RooflineShape:
if type(v) is core.Literal:
return RooflineShape.from_aval(abstractify(v.val))
return RooflineShape.from_aval(core.abstractify(v.val))
else:
assert isinstance(v, core.Var)
return env[v]

def aval(v: core.Atom) -> core.AbstractValue:
if type(v) is core.Literal:
return abstractify(v.val)
return core.abstractify(v.val)
else:
return v.aval

Expand Down
4 changes: 2 additions & 2 deletions tests/lax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3959,7 +3959,7 @@ def setUp(self):
core.pytype_aval_mappings[FooArray] = \
lambda x: core.ShapedArray(x.shape, FooTy())
xla.canonicalize_dtype_handlers[FooArray] = lambda x: x
xla.pytype_aval_mappings[FooArray] = \
core.xla_pytype_aval_mappings[FooArray] = \
lambda x: core.ShapedArray(x.shape, FooTy())
pxla.shard_arg_handlers[FooArray] = shard_foo_array_handler
mlir._constant_handlers[FooArray] = foo_array_constant_handler
Expand All @@ -3973,7 +3973,7 @@ def setUp(self):
def tearDown(self):
del core.pytype_aval_mappings[FooArray]
del xla.canonicalize_dtype_handlers[FooArray]
del xla.pytype_aval_mappings[FooArray]
del core.xla_pytype_aval_mappings[FooArray]
del mlir._constant_handlers[FooArray]
del mlir._lowerings[make_p]
del mlir._lowerings[bake_p]
Expand Down

0 comments on commit 0fa5419

Please sign in to comment.