diff --git a/build/build_wheel.py b/build/build_wheel.py index f112be41f347..987511194a33 100644 --- a/build/build_wheel.py +++ b/build/build_wheel.py @@ -169,7 +169,7 @@ def prepare_wheel(sources_path): copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py") copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}") copy_to_jaxlib("__main__/jaxlib/lapack.py") - copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py") + copy_to_jaxlib("__main__/jaxlib/hlo_helpers.py") copy_to_jaxlib("__main__/jaxlib/ducc_fft.py") copy_to_jaxlib("__main__/jaxlib/gpu_prng.py") copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py") diff --git a/jax/_src/ad_checkpoint.py b/jax/_src/ad_checkpoint.py index a1b3c41e0093..ed2a8a70f20f 100644 --- a/jax/_src/ad_checkpoint.py +++ b/jax/_src/ad_checkpoint.py @@ -34,7 +34,7 @@ from jax._src.lax import lax as lax_internal from jax._src.lax import convolution as lax_convolution from jax._src.lib import xla_client as xc -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.traceback_util import api_boundary from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map, safe_zip, merge_lists, weakref_lru_cache) @@ -623,9 +623,9 @@ def _optimization_barrier_lowering_rule(ctx, *args): flat_barrier_types = util.flatten(barrier_types) flat_args = mlir.flatten_lowering_ir_args(args) if xc.mlir_api_version < 40: - barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args) + barrier_op = hlo.OptimizationBarrierOp(flat_barrier_types, flat_args) else: - barrier_op = mhlo.OptimizationBarrierOp(flat_args) + barrier_op = hlo.OptimizationBarrierOp(flat_args) return util.unflatten(barrier_op.results, map(len, barrier_types)) def _optimization_barrier(arg): diff --git a/jax/_src/custom_derivatives.py b/jax/_src/custom_derivatives.py index 93f5ca36732e..be004e5d0423 100644 --- a/jax/_src/custom_derivatives.py +++ b/jax/_src/custom_derivatives.py @@ -1205,8 +1205,8 @@ def unreachable_impl(*_, out_avals, exc_type, message): # Translation raises an exception # TODO(frostig,mattjj): We have no good way to translate a function -# that errs. Since MHLO lowering over-approximates concrete evaluation, -# we err on MHLO lowering for the time being. +# that errs. Since MLIR lowering over-approximates concrete evaluation, +# we err on MLIR lowering for the time being. mlir.register_lowering(unreachable_p, unreachable_impl) # Abstract evaluation proceeds without issue, to allow for staging diff --git a/jax/_src/debugging.py b/jax/_src/debugging.py index d515c9aff065..90e050d302ea 100644 --- a/jax/_src/debugging.py +++ b/jax/_src/debugging.py @@ -38,7 +38,7 @@ from jax._src.lax import control_flow as lcf from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo import jax.numpy as jnp import numpy as np @@ -335,15 +335,15 @@ def _hlo_sharding_callback(hlo_sharding): # partitioner runs so we keep it alive by attaching it to the executable. 
ctx.module_context.add_keepalive(sharding_callback_info) - mhlo.CustomCallOp([value.type], [value], - call_target_name=ir.StringAttr.get( - _INSPECT_SHARDING_CALL_NAME), - has_side_effect=ir.BoolAttr.get(True), - api_version=mlir.i32_attr(1), - called_computations=ir.ArrayAttr.get([]), - backend_config=ir.StringAttr.get(key), - operand_layouts=None, - result_layouts=None) + hlo.CustomCallOp([value.type], [value], + call_target_name=ir.StringAttr.get( + _INSPECT_SHARDING_CALL_NAME), + has_side_effect=ir.BoolAttr.get(True), + api_version=mlir.i32_attr(1), + called_computations=ir.ArrayAttr.get([]), + backend_config=ir.StringAttr.get(key), + operand_layouts=None, + result_layouts=None) return [] mlir.register_lowering(inspect_sharding_p, _inspect_sharding_lowering_rule) diff --git a/jax/_src/dispatch.py b/jax/_src/dispatch.py index e92fe80a3660..ef66c30e7886 100644 --- a/jax/_src/dispatch.py +++ b/jax/_src/dispatch.py @@ -67,7 +67,7 @@ flags.DEFINE_string( 'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''), - help="Path to which HLO/MHLO IR that is emitted by JAX as input to the " + help="Path to which the IR that is emitted by JAX as input to the " "compiler should be dumped as text files. Optional. If omitted, JAX " "will not dump IR.") diff --git a/jax/_src/lax/control_flow/conditionals.py b/jax/_src/lax/control_flow/conditionals.py index c1b002733540..df74981162c7 100644 --- a/jax/_src/lax/control_flow/conditionals.py +++ b/jax/_src/lax/control_flow/conditionals.py @@ -41,7 +41,7 @@ from jax._src.util import (safe_map, extend_name_stack, split_list, partition_list) from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo import numpy as np from jax._src.lax.control_flow.common import ( @@ -806,11 +806,11 @@ def _cond_lowering(ctx, index, *args, branches, linear): *output_token_types, *map(mlir.aval_to_ir_types, ctx.avals_out)] flat_output_types = util.flatten(output_types) - # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks + # CaseOp takes a single argument 'index' and the corresponding blocks # have no arguments; the computation within the block uses implicit # captures. 
- case_op = mhlo.CaseOp(flat_output_types, index=index, - num_branches=len(branches)) + case_op = hlo.CaseOp(flat_output_types, index=index, + num_branches=len(branches)) name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond') for i, jaxpr in enumerate(branches): branch = case_op.regions[i].blocks.append() @@ -824,7 +824,7 @@ def _cond_lowering(ctx, index, *args, branches, linear): dim_var_values=ctx.dim_var_values) out_tokens = [tokens_out.get(eff) for eff in ordered_effects] out_vals = [*out_tokens, *out_vals] - mhlo.ReturnOp(util.flatten(out_vals)) + hlo.ReturnOp(util.flatten(out_vals)) tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types)) tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens]) diff --git a/jax/_src/lax/control_flow/loops.py b/jax/_src/lax/control_flow/loops.py index 067f8c9abf8e..6dcddf14b005 100644 --- a/jax/_src/lax/control_flow/loops.py +++ b/jax/_src/lax/control_flow/loops.py @@ -43,7 +43,7 @@ from jax._src.lax import windowed_reductions from jax._src.lib import xla_client as xc from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.ufuncs import logaddexp from jax._src.traceback_util import api_boundary from jax._src.util import ( @@ -146,7 +146,7 @@ def scan(f, init, xs, length=None): output arrays. (None is actually an empty pytree.) Also unlike that Python version, :func:`~scan` is a JAX primitive and is - lowered to a single XLA While HLO. That makes it useful for reducing + lowered to a single WhileOp. That makes it useful for reducing compilation times for JIT-compiled functions, since native Python loop constructs in an :func:`~jax.jit` function are unrolled, leading to large XLA computations. @@ -1041,7 +1041,7 @@ def while_loop(cond_fun, body_fun, init_val): return val Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered - to a single XLA While HLO. That makes it useful for reducing compilation times + to a single WhileOp. That makes it useful for reducing compilation times for jit-compiled functions, since native Python loop constructs in an ``@jit`` function are unrolled, leading to large XLA computations. @@ -1420,7 +1420,7 @@ def _while_transpose_error(*_, **kwargs): # break # token, x = body(token, x) # ``` -# Unfortunately, with an MHLO while we can't (1) return multiple values +# Unfortunately, with a WhileOp we can't (1) return multiple values # from a `cond` and (2) can't break a while loop. 
We thus adopt the # following rewrite strategy: # ``` @@ -1471,7 +1471,7 @@ def fun(*args): args = [*tokens, *args] flat_args = mlir.flatten_lowering_ir_args(args) - while_op = mhlo.WhileOp(flat_loop_carry_types, flat_args) + while_op = hlo.WhileOp(flat_loop_carry_types, flat_args) # Loop condition cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types) @@ -1498,12 +1498,12 @@ def fun(*args): tokens_in=mlir.TokenSet(), tokens_out=None) pred, = lax._unary_reduce_lower( - mhlo.OrOp, + hlo.OrOp, lambda dtype: np.array(False, dtype), pred_ctx, pred, axes=tuple(range(len(pred_aval.shape)))) - mhlo.ReturnOp([pred]) + hlo.ReturnOp([pred]) # Loop body body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types) @@ -1531,11 +1531,11 @@ def fun(*args): _map(mlir.ir_constants, cond_jaxpr.consts), *(x + z), dim_var_values=ctx.dim_var_values) new_z = _map( - partial(_pred_bcast_select_mhlo, ctx, pred_aval, body_pred), new_z, z, + partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z, body_jaxpr.out_avals) - mhlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x), - *util.flatten(y), *util.flatten(new_z)]) + hlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x), + *util.flatten(y), *util.flatten(new_z)]) outputs = util.unflatten(while_op.results, _map(len, loop_carry_types)) tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts]) @@ -1566,16 +1566,16 @@ def _while_typecheck(*in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts, core.custom_typechecks[while_p] = _while_typecheck -def _pred_bcast_select_mhlo(ctx, +def _pred_bcast_select_hlo(ctx, pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value], ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]: if x_y_aval is core.abstract_token: x, = xs y, = ys if xc.mlir_api_version < 40: - return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result] + return [hlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result] else: - return [mhlo.AfterAllOp([x, y]).result] + return [hlo.AfterAllOp([x, y]).result] else: assert isinstance(x_y_aval, core.ShapedArray), x_y_aval x, = xs @@ -1589,7 +1589,7 @@ def _pred_bcast_select_mhlo(ctx, x_y_shape = x_y_aval.shape bcast_pred = mlir.broadcast_in_dim(ctx, pred, core.DShapedArray(x_y_shape, np.dtype(np.bool_)), broadcast_dimensions=list(range(len(pred_aval.shape)))) - return mhlo.SelectOp(bcast_pred, x, y).results + return hlo.SelectOp(bcast_pred, x, y).results ### fori_loop diff --git a/jax/_src/lax/convolution.py b/jax/_src/lax/convolution.py index c6da685457fd..5f1dc9e80d7d 100644 --- a/jax/_src/lax/convolution.py +++ b/jax/_src/lax/convolution.py @@ -26,7 +26,7 @@ from jax.interpreters import batching from jax.interpreters import mlir from jax._src import util -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.lib import xla_client _max = builtins.max @@ -688,7 +688,7 @@ def _conv_general_dilated_lower( return complex_conv(ctx, lhs, rhs) lhs_spec, rhs_spec, out_spec = dimension_numbers - dnums = mhlo.ConvDimensionNumbers.get( + dnums = hlo.ConvDimensionNumbers.get( input_batch_dimension=lhs_spec[0], input_feature_dimension=lhs_spec[1], input_spatial_dimensions=list(lhs_spec[2:]), @@ -703,7 +703,7 @@ def _conv_general_dilated_lower( padding = np.zeros((0, 2), dtype=np.int64) window_reversal = mlir.dense_bool_elements([False] * num_spatial_dims) return [ - mhlo.ConvolutionOp( + hlo.ConvolutionOp( mlir.aval_to_ir_type(aval_out), lhs, rhs, diff --git 
a/jax/_src/lax/fft.py b/jax/_src/lax/fft.py index 28ecfd449cc7..ccd2aac7ba62 100644 --- a/jax/_src/lax/fft.py +++ b/jax/_src/lax/fft.py @@ -26,7 +26,7 @@ from jax import lax from jax.interpreters import ad from jax.interpreters import batching -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.lib import xla_client from jax._src.lib import ducc_fft from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact @@ -104,8 +104,8 @@ def fft_abstract_eval(x, fft_type, fft_lengths): def _fft_lowering(ctx, x, *, fft_type, fft_lengths): return [ - mhlo.FftOp(x, mhlo.FftTypeAttr.get(fft_type.name), - mlir.dense_int_elements(fft_lengths)).result + hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name), + mlir.dense_int_elements(fft_lengths)).result ] @@ -113,7 +113,11 @@ def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths): if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)): raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778") x_aval, = ctx.avals_in - return [ducc_fft.ducc_fft_mhlo(x, x_aval.dtype, fft_type=fft_type, + if xla_client.mlir_api_version < 41: + return [ducc_fft.ducc_fft_mhlo(x, x_aval.dtype, fft_type=fft_type, + fft_lengths=fft_lengths)] + else: + return [ducc_fft.ducc_fft_hlo(x, x_aval.dtype, fft_type=fft_type, fft_lengths=fft_lengths)] def _naive_rfft(x, fft_lengths): diff --git a/jax/_src/lax/lax.py b/jax/_src/lax/lax.py index 6fa8b8ac5cb2..ca62b2f09959 100644 --- a/jax/_src/lax/lax.py +++ b/jax/_src/lax/lax.py @@ -58,7 +58,7 @@ from jax._src.lib import xla_client from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.lax.utils import ( _input_dtype, standard_abstract_eval, @@ -1635,10 +1635,10 @@ def _maybe_broadcast(target_shape, x): squeeze_shape = [x_shape[i] for i in dims] return broadcast_in_dim(reshape(x, squeeze_shape), target_shape, dims) -def broadcast_mhlo( +def broadcast_hlo( aval_out: core.ShapedArray, avals: Sequence[core.ShapedArray], args: Sequence[ir.Value]) -> Sequence[ir.Value]: - """Broadcasts MHLO values with broadcast-compatible shapes to the same shape. + """Broadcasts HLO values with broadcast-compatible shapes to the same shape. """ out = [] for aval, arg in zip(avals, args): @@ -1647,24 +1647,23 @@ def broadcast_mhlo( dims = mlir.dense_int_elements( range(len(aval_out.shape) - len(aval.shape), len(aval_out.shape))) if any(isinstance(d, ir.Value) for d in aval_out.shape): - arg = mhlo.DynamicBroadcastInDimOp( + arg = hlo.DynamicBroadcastInDimOp( mlir.aval_to_ir_type(aval_out), arg, mlir.shape_tensor(aval_out.shape), dims).result else: - arg = mhlo.BroadcastInDimOp( + arg = hlo.BroadcastInDimOp( mlir.aval_to_ir_type(aval.update(shape=aval_out.shape)), arg, dims).result out.append(arg) return out -def _nary_lower_mhlo(op: Callable, ctx, - *args: Union[ir.Value, Sequence[ir.Value]], - explicit_type=False, **params): - """Lowers an elementwise operator to its MHLO/CHLO equivalent. +def _nary_lower_hlo(op: Callable, ctx, + *args: Union[ir.Value, Sequence[ir.Value]], + explicit_type=False, **params): + """Lowers an elementwise operator to its MLIR equivalent. Args: - explicit_type: does the MHLO/CHLO operator require its output type to be - provided? + explicit_type: does the MLIR op require its output type to be provided? 
""" del params avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out @@ -1696,76 +1695,76 @@ def _substitute_axis_sizes_in_aval( neg_p = standard_unop(_num, 'neg') ad.deflinear2(neg_p, lambda t, operand: [neg(t)]) -mlir.register_lowering(neg_p, partial(_nary_lower_mhlo, mhlo.NegOp)) +mlir.register_lowering(neg_p, partial(_nary_lower_hlo, hlo.NegOp)) sign_p = standard_unop(_num, 'sign') ad.defjvp_zero(sign_p) -def _sign_lower_mhlo(ctx, x): +def _sign_lower_hlo(ctx, x): x_aval, = ctx.avals_in if dtypes.issubdtype(x_aval.dtype, np.unsignedinteger): - return mhlo.SelectOp( - mlir.compare_mhlo(x, mlir.full_like_aval(ctx, 0, x_aval), 'EQ', - 'UNSIGNED').result, + return hlo.SelectOp( + mlir.compare_hlo(x, mlir.full_like_aval(ctx, 0, x_aval), 'EQ', + 'UNSIGNED').result, mlir.full_like_aval(ctx, 0, x_aval), mlir.full_like_aval(ctx, 1, x_aval)).results - return mhlo.SignOp(x).results + return hlo.SignOp(x).results -mlir.register_lowering(sign_p, _sign_lower_mhlo) +mlir.register_lowering(sign_p, _sign_lower_hlo) nextafter_p = standard_naryop([_float, _float], 'nextafter') -mlir.register_lowering(nextafter_p, partial(_nary_lower_mhlo, chlo.NextAfterOp)) +mlir.register_lowering(nextafter_p, partial(_nary_lower_hlo, chlo.NextAfterOp)) floor_p = standard_unop(_float, 'floor') ad.defjvp_zero(floor_p) -mlir.register_lowering(floor_p, partial(_nary_lower_mhlo, mhlo.FloorOp)) +mlir.register_lowering(floor_p, partial(_nary_lower_hlo, hlo.FloorOp)) ceil_p = standard_unop(_float, 'ceil') ad.defjvp_zero(ceil_p) -mlir.register_lowering(ceil_p, partial(_nary_lower_mhlo, mhlo.CeilOp)) +mlir.register_lowering(ceil_p, partial(_nary_lower_hlo, hlo.CeilOp)) round_p = standard_unop(_float, 'round') ad.defjvp_zero(round_p) def _round_lower(ctx, x, *, rounding_method): if rounding_method is RoundingMethod.AWAY_FROM_ZERO: - return mhlo.RoundOp(x).results + return hlo.RoundOp(x).results else: assert rounding_method is RoundingMethod.TO_NEAREST_EVEN - return mhlo.RoundNearestEvenOp(x).results + return hlo.RoundNearestEvenOp(x).results mlir.register_lowering(round_p, _round_lower) is_finite_p = unop(_fixed_dtype(np.bool_), _float, 'is_finite') ad.defjvp_zero(is_finite_p) -mlir.register_lowering(is_finite_p, partial(_nary_lower_mhlo, mhlo.IsFiniteOp)) +mlir.register_lowering(is_finite_p, partial(_nary_lower_hlo, hlo.IsFiniteOp)) exp_p = standard_unop(_float | _complex, 'exp') ad.defjvp2(exp_p, lambda g, ans, x: mul(g, ans)) # For exp_p it is more efficient to use the reconstructed output for the vjp # rule instead of computing it again from the input. 
-mlir.register_lowering(exp_p, partial(_nary_lower_mhlo, mhlo.ExpOp)) +mlir.register_lowering(exp_p, partial(_nary_lower_hlo, hlo.ExpOp)) log_p = standard_unop(_float | _complex, 'log') ad.defjvp(log_p, lambda g, x: div(g, x)) -mlir.register_lowering(log_p, partial(_nary_lower_mhlo, mhlo.LogOp)) +mlir.register_lowering(log_p, partial(_nary_lower_hlo, hlo.LogOp)) expm1_p = standard_unop(_float | _complex, 'expm1') ad.defjvp2(expm1_p, lambda g, ans, x: mul(g, add(ans, _one(ans)))) -mlir.register_lowering(expm1_p, partial(_nary_lower_mhlo, mhlo.Expm1Op)) +mlir.register_lowering(expm1_p, partial(_nary_lower_hlo, hlo.Expm1Op)) log1p_p = standard_unop(_float | _complex, 'log1p') ad.defjvp(log1p_p, lambda g, x: div(g, add(x, _one(x)))) -mlir.register_lowering(log1p_p, partial(_nary_lower_mhlo, mhlo.Log1pOp)) +mlir.register_lowering(log1p_p, partial(_nary_lower_hlo, hlo.Log1pOp)) tanh_p = standard_unop(_float | _complex, 'tanh') ad.defjvp2(tanh_p, lambda g, ans, x: mul(add(g, mul(g, ans)), sub(_one(x), ans))) -mlir.register_lowering(tanh_p, partial(_nary_lower_mhlo, mhlo.TanhOp)) +mlir.register_lowering(tanh_p, partial(_nary_lower_hlo, hlo.TanhOp)) logistic_p = standard_unop(_float | _complex, 'logistic') ad.defjvp2(logistic_p, lambda g, ans, x: mul(g, mul(ans, sub(_one(ans), ans)))) -# TODO(phawkins): switch to mhlo.logistic lowering; debug numerical problems. -# mlir.register_lowering(logistic_p, partial(_nary_lower_mhlo, mhlo.LogisticOp)) +# TODO(phawkins): switch to LogisticOp lowering; debug numerical problems. +# mlir.register_lowering(logistic_p, partial(_nary_lower_hlo, hlo.LogisticOp)) def logistic_impl(x): one = _const(x, 1) @@ -1776,11 +1775,11 @@ def logistic_impl(x): sin_p = standard_unop(_float | _complex, 'sin') ad.defjvp(sin_p, lambda g, x: mul(g, cos(x))) -mlir.register_lowering(sin_p, partial(_nary_lower_mhlo, mhlo.SineOp)) +mlir.register_lowering(sin_p, partial(_nary_lower_hlo, hlo.SineOp)) cos_p = standard_unop(_float | _complex, 'cos') ad.defjvp(cos_p, lambda g, x: neg(mul(g, sin(x)))) -mlir.register_lowering(cos_p, partial(_nary_lower_mhlo, mhlo.CosineOp)) +mlir.register_lowering(cos_p, partial(_nary_lower_hlo, hlo.CosineOp)) @_upcast_fp16_for_computation def _tan_impl(x): @@ -1788,7 +1787,7 @@ def _tan_impl(x): tan_p = standard_unop(_float | _complex, 'tan') ad.defjvp2(tan_p, lambda g, ans, x: mul(g, _const(x, 1) + square(ans))) -mlir.register_lowering(tan_p, partial(_nary_lower_mhlo, chlo.TanOp)) +mlir.register_lowering(tan_p, partial(_nary_lower_hlo, chlo.TanOp)) def asin_impl(x): if dtypes.issubdtype(_dtype(x), np.complexfloating): @@ -1799,7 +1798,7 @@ def asin_impl(x): asin_p = standard_unop(_float | _complex, 'asin') ad.defjvp(asin_p, lambda g, x: mul(g, rsqrt(_const(x, 1) - square(x)))) -mlir.register_lowering(asin_p, partial(_nary_lower_mhlo, chlo.AsinOp)) +mlir.register_lowering(asin_p, partial(_nary_lower_hlo, chlo.AsinOp)) def acos_impl(x): if dtypes.issubdtype(_dtype(x), np.complexfloating): @@ -1828,35 +1827,35 @@ def atan_impl(x): atan_p = standard_unop(_float | _complex, 'atan') ad.defjvp(atan_p, lambda g, x: div(g, _const(x, 1) + square(x))) -mlir.register_lowering(atan_p, partial(_nary_lower_mhlo, chlo.AtanOp)) +mlir.register_lowering(atan_p, partial(_nary_lower_hlo, chlo.AtanOp)) atan2_p = standard_naryop([_float | _complex, _float | _complex], 'atan2') ad.defjvp(atan2_p, lambda g, x, y: g * (y / (square(x) + square(y))), lambda g, x, y: g * -x / (square(x) + square(y))) -mlir.register_lowering(atan2_p, partial(_nary_lower_mhlo, mhlo.Atan2Op)) 
+mlir.register_lowering(atan2_p, partial(_nary_lower_hlo, hlo.Atan2Op)) sinh_p = standard_unop(_float | _complex, 'sinh') ad.defjvp(sinh_p, lambda g, x: mul(g, cosh(x))) -mlir.register_lowering(sinh_p, partial(_nary_lower_mhlo, chlo.SinhOp)) +mlir.register_lowering(sinh_p, partial(_nary_lower_hlo, chlo.SinhOp)) cosh_p = standard_unop(_float | _complex, 'cosh') ad.defjvp(cosh_p, lambda g, x: mul(g, sinh(x))) -mlir.register_lowering(cosh_p, partial(_nary_lower_mhlo, chlo.CoshOp)) +mlir.register_lowering(cosh_p, partial(_nary_lower_hlo, chlo.CoshOp)) asinh_p = standard_unop(_float | _complex, 'asinh') ad.defjvp(asinh_p, lambda g, x: mul(g, rsqrt(square(x) + _one(x)))) -mlir.register_lowering(asinh_p, partial(_nary_lower_mhlo, chlo.AsinhOp)) +mlir.register_lowering(asinh_p, partial(_nary_lower_hlo, chlo.AsinhOp)) acosh_p = standard_unop(_float | _complex, 'acosh') ad.defjvp(acosh_p, lambda g, x: mul(g, rsqrt((x - _one(x)) * (x + _one(x))))) -mlir.register_lowering(acosh_p, partial(_nary_lower_mhlo, chlo.AcoshOp)) +mlir.register_lowering(acosh_p, partial(_nary_lower_hlo, chlo.AcoshOp)) atanh_p = standard_unop(_float | _complex, 'atanh') ad.defjvp(atanh_p, lambda g, x: mul(reciprocal(_one(x) + x), div(g, (_one(x) - x)))) -mlir.register_lowering(atanh_p, partial(_nary_lower_mhlo, chlo.AtanhOp)) +mlir.register_lowering(atanh_p, partial(_nary_lower_hlo, chlo.AtanhOp)) regularized_incomplete_beta_p = standard_naryop( [_float, _float, _float], 'regularized_incomplete_beta') @@ -1880,10 +1879,10 @@ def betainc_grad_not_implemented(g, a, b, x): lgamma_p = standard_unop(_float, 'lgamma') ad.defjvp(lgamma_p, lambda g, x: mul(g, digamma(x))) -mlir.register_lowering(lgamma_p, partial(_nary_lower_mhlo, chlo.LgammaOp)) +mlir.register_lowering(lgamma_p, partial(_nary_lower_hlo, chlo.LgammaOp)) digamma_p = standard_unop(_float, 'digamma') -mlir.register_lowering(digamma_p, partial(_nary_lower_mhlo, chlo.DigammaOp)) +mlir.register_lowering(digamma_p, partial(_nary_lower_hlo, chlo.DigammaOp)) igamma_p = standard_naryop([_float, _float], 'igamma') xla.register_translation(igamma_p, partial(_broadcast_translate, xops.Igamma)) @@ -1919,7 +1918,7 @@ def igammac_grada(g, a, x): bessel_i1e_p = standard_unop(_float, 'bessel_i1e') mlir.register_lowering(bessel_i1e_p, - partial(_nary_lower_mhlo, chlo.BesselI1eOp)) + partial(_nary_lower_hlo, chlo.BesselI1eOp)) def _bessel_i1e_jvp(g, y, x): eps = dtypes.finfo(_dtype(x)).eps @@ -1933,12 +1932,12 @@ def _bessel_i1e_jvp(g, y, x): erf_p = standard_unop(_float, 'erf') ad.defjvp(erf_p, lambda g, x: mul(_const(x, 2. / np.sqrt(np.pi)), mul(g, exp(neg(square(x)))))) -mlir.register_lowering(erf_p, partial(_nary_lower_mhlo, chlo.ErfOp)) +mlir.register_lowering(erf_p, partial(_nary_lower_hlo, chlo.ErfOp)) erfc_p = standard_unop(_float, 'erfc') ad.defjvp(erfc_p, lambda g, x: mul(_const(x, -2. 
/ np.sqrt(np.pi)), mul(g, exp(neg(square(x)))))) -mlir.register_lowering(erfc_p, partial(_nary_lower_mhlo, chlo.ErfcOp)) +mlir.register_lowering(erfc_p, partial(_nary_lower_hlo, chlo.ErfcOp)) erf_inv_p = standard_unop(_float, 'erf_inv') ad.defjvp2(erf_inv_p, lambda g, ans, x: mul(_const(x, np.sqrt(np.pi) / 2.), @@ -1947,11 +1946,11 @@ def _bessel_i1e_jvp(g, y, x): real_p = unop(_complex_basetype, _complex, 'real') ad.deflinear2(real_p, lambda t, _: [complex(t, np.zeros((), _dtype(t)))]) -mlir.register_lowering(real_p, partial(_nary_lower_mhlo, mhlo.RealOp)) +mlir.register_lowering(real_p, partial(_nary_lower_hlo, hlo.RealOp)) imag_p = unop(_complex_basetype, _complex, 'imag') ad.deflinear2(imag_p, lambda t, _: [complex(np.zeros((), _dtype(t)), neg(t))]) -mlir.register_lowering(imag_p, partial(_nary_lower_mhlo, mhlo.ImagOp)) +mlir.register_lowering(imag_p, partial(_nary_lower_hlo, hlo.ImagOp)) def _complex_transpose_rule(t, x, y): @@ -1976,7 +1975,7 @@ def _complex_transpose_rule(t, x, y): complex_p = naryop(_complex_dtype, [_complex_elem_types, _complex_elem_types], 'complex') ad.deflinear2(complex_p, _complex_transpose_rule) -mlir.register_lowering(complex_p, partial(_nary_lower_mhlo, mhlo.ComplexOp)) +mlir.register_lowering(complex_p, partial(_nary_lower_hlo, hlo.ComplexOp)) conj_p = unop(_complex_dtype, _complex_elem_types | _complex, 'conj') @@ -2001,7 +2000,7 @@ def _conj_transpose_rule(t, x, *, input_dtype): ad.primitive_transposes[conj_p] = _conj_transpose_rule abs_p = unop(_complex_basetype, _num, 'abs') -mlir.register_lowering(abs_p, partial(_nary_lower_mhlo, mhlo.AbsOp)) +mlir.register_lowering(abs_p, partial(_nary_lower_hlo, hlo.AbsOp)) def _abs_jvp_rule(g, ans, x): if _iscomplex(x): @@ -2015,18 +2014,18 @@ def _abs_jvp_rule(g, ans, x): sqrt_p = standard_unop(_float | _complex, 'sqrt') ad.defjvp2(sqrt_p, lambda g, ans, x: mul(g, div(_const(x, 0.5), ans))) -mlir.register_lowering(sqrt_p, partial(_nary_lower_mhlo, mhlo.SqrtOp)) +mlir.register_lowering(sqrt_p, partial(_nary_lower_hlo, hlo.SqrtOp)) rsqrt_p = standard_unop(_float | _complex, 'rsqrt') ad.defjvp2(rsqrt_p, lambda g, ans, x: mul(g, mul(_const(x, -0.5), div(ans, x)))) -mlir.register_lowering(rsqrt_p, partial(_nary_lower_mhlo, mhlo.RsqrtOp)) +mlir.register_lowering(rsqrt_p, partial(_nary_lower_hlo, hlo.RsqrtOp)) cbrt_p = standard_unop(_float, 'cbrt') ad.defjvp2(cbrt_p, lambda g, ans, x: mul(g, mul(_const(x, 1/3), integer_pow(ans, -2)))) -mlir.register_lowering(cbrt_p, partial(_nary_lower_mhlo, mhlo.CbrtOp)) +mlir.register_lowering(cbrt_p, partial(_nary_lower_hlo, hlo.CbrtOp)) pow_p = standard_naryop([_float | _complex, _float | _complex], 'pow') @@ -2037,7 +2036,7 @@ def _pow_jvp_rhs(g, ans, x, y): return mul(g, mul(log(_replace_zero(x)), ans)) ad.defjvp2(pow_p, _pow_jvp_lhs, _pow_jvp_rhs) -mlir.register_lowering(pow_p, partial(_nary_lower_mhlo, mhlo.PowOp)) +mlir.register_lowering(pow_p, partial(_nary_lower_hlo, hlo.PowOp)) def _integer_pow_dtype_rule(x, *, y): @@ -2077,8 +2076,8 @@ def _integer_pow(x, *, y): def _integer_pow_lowering(ctx, x, *, y): lowering = mlir.lower_fun(_integer_pow, multiple_results=False) # TODO(b/217551391): emitting an out-of-line call leads to a large - # expansion when the MHLO is lowered to HLO, because the HLO lowering - # clones the callee. Consider unconditionally caching when the MHLO->HLO + # expansion when the MLIR is lowered to HLO, because the HLO lowering + # clones the callee. Consider unconditionally caching when the MLIR->HLO # lowering doesn't expand the program. 
if y >= 4: lowering = mlir.cache_lowering(lowering) @@ -2090,26 +2089,26 @@ def _integer_pow_lowering(ctx, x, *, y): not_p = standard_unop(_bool_or_int, 'not') ad.defjvp_zero(not_p) -mlir.register_lowering(not_p, partial(_nary_lower_mhlo, mhlo.NotOp)) +mlir.register_lowering(not_p, partial(_nary_lower_hlo, hlo.NotOp)) and_p = standard_naryop([_bool_or_int, _bool_or_int], 'and') ad.defjvp_zero(and_p) -mlir.register_lowering(and_p, partial(_nary_lower_mhlo, mhlo.AndOp)) +mlir.register_lowering(and_p, partial(_nary_lower_hlo, hlo.AndOp)) or_p = standard_naryop([_bool_or_int, _bool_or_int], 'or') ad.defjvp_zero(or_p) -mlir.register_lowering(or_p, partial(_nary_lower_mhlo, mhlo.OrOp)) +mlir.register_lowering(or_p, partial(_nary_lower_hlo, hlo.OrOp)) xor_p = standard_naryop([_bool_or_int, _bool_or_int], 'xor') ad.defjvp_zero(xor_p) -mlir.register_lowering(xor_p, partial(_nary_lower_mhlo, mhlo.XorOp)) +mlir.register_lowering(xor_p, partial(_nary_lower_hlo, hlo.XorOp)) population_count_p = standard_unop(_int, 'population_count') mlir.register_lowering(population_count_p, - partial(_nary_lower_mhlo, mhlo.PopulationCountOp)) + partial(_nary_lower_hlo, hlo.PopulationCountOp)) clz_p = standard_unop(_int, 'clz') -mlir.register_lowering(clz_p, partial(_nary_lower_mhlo, mhlo.ClzOp)) +mlir.register_lowering(clz_p, partial(_nary_lower_hlo, hlo.ClzOp)) def _add_jvp(primals, tangents): x, y = primals @@ -2145,7 +2144,7 @@ def _add_inverse(r, x, y): add_p: Primitive = standard_naryop([_num, _num], 'add') ad.primitive_jvps[add_p] = _add_jvp ad.primitive_transposes[add_p] = _add_transpose -mlir.register_lowering(add_p, partial(_nary_lower_mhlo, mhlo.AddOp)) +mlir.register_lowering(add_p, partial(_nary_lower_hlo, hlo.AddOp)) def _sub_jvp(primals, tangents): x, y = primals @@ -2174,7 +2173,7 @@ def _sub_transpose(t, x, y): sub_p = standard_naryop([_num, _num], 'sub') ad.primitive_jvps[sub_p] = _sub_jvp ad.primitive_transposes[sub_p] = _sub_transpose -mlir.register_lowering(sub_p, partial(_nary_lower_mhlo, mhlo.SubtractOp)) +mlir.register_lowering(sub_p, partial(_nary_lower_hlo, hlo.SubtractOp)) def _mul_transpose(ct, x, y): @@ -2200,7 +2199,7 @@ def _mul_inverse(r, x, y): lambda xdot, x, y: mul(xdot, y), lambda ydot, x, y: mul(x, ydot)) ad.primitive_transposes[mul_p] = _mul_transpose -mlir.register_lowering(mul_p, partial(_nary_lower_mhlo, mhlo.MulOp)) +mlir.register_lowering(mul_p, partial(_nary_lower_hlo, hlo.MulOp)) def _div_transpose_rule(cotangent, x, y): assert ad.is_undefined_primal(x) and not ad.is_undefined_primal(y) @@ -2213,14 +2212,14 @@ def _div_transpose_rule(cotangent, x, y): lambda g, x, y: div(g, y), lambda g, x, y: mul(mul(neg(g), x), integer_pow(y, -2))) ad.primitive_transposes[div_p] = _div_transpose_rule -mlir.register_lowering(div_p, partial(_nary_lower_mhlo, mhlo.DivOp)) +mlir.register_lowering(div_p, partial(_nary_lower_hlo, hlo.DivOp)) rem_p = standard_naryop([_int | _float, _int | _float], 'rem') ad.defjvp( rem_p, lambda g, x, y: _maybe_broadcast(broadcast_shapes(np.shape(x), np.shape(y)), g), lambda g, x, y: mul(neg(g), mul(sign(div(x, y)), floor(abs(div(x, y)))))) -mlir.register_lowering(rem_p, partial(_nary_lower_mhlo, mhlo.RemOp)) +mlir.register_lowering(rem_p, partial(_nary_lower_hlo, hlo.RemOp)) def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x): result_shape = broadcast_shapes(np.shape(x), np.shape(y)) @@ -2236,29 +2235,29 @@ def _minmax_complex_lowering(x, y, *, lax_cmp_pick_x): ad.defjvp2(max_p, lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, 
_balanced_eq(y, ans, x))) -mlir.register_lowering(max_p, partial(_nary_lower_mhlo, mlir.max_mhlo)) +mlir.register_lowering(max_p, partial(_nary_lower_hlo, mlir.max_hlo)) min_p: core.Primitive = standard_naryop([_any, _any], 'min') ad.defjvp2(min_p, lambda g, ans, x, y: mul(g, _balanced_eq(x, ans, y)), lambda g, ans, x, y: mul(g, _balanced_eq(y, ans, x))) -mlir.register_lowering(min_p, partial(_nary_lower_mhlo, mlir.min_mhlo)) +mlir.register_lowering(min_p, partial(_nary_lower_hlo, mlir.min_hlo)) shift_left_p = standard_naryop([_int, _int], 'shift_left') ad.defjvp_zero(shift_left_p) -mlir.register_lowering(shift_left_p, partial(_nary_lower_mhlo, mhlo.ShiftLeftOp)) +mlir.register_lowering(shift_left_p, partial(_nary_lower_hlo, hlo.ShiftLeftOp)) shift_right_arithmetic_p = standard_naryop([_int, _int], 'shift_right_arithmetic') ad.defjvp_zero(shift_right_arithmetic_p) mlir.register_lowering(shift_right_arithmetic_p, - partial(_nary_lower_mhlo, mhlo.ShiftRightArithmeticOp)) + partial(_nary_lower_hlo, hlo.ShiftRightArithmeticOp)) shift_right_logical_p = standard_naryop([_int, _int], 'shift_right_logical') ad.defjvp_zero(shift_right_logical_p) mlir.register_lowering(shift_right_logical_p, - partial(_nary_lower_mhlo, mhlo.ShiftRightLogicalOp)) + partial(_nary_lower_hlo, hlo.ShiftRightLogicalOp)) -def _compare_lower_mhlo(direction: str, ctx, x, y): +def _compare_lower_hlo(direction: str, ctx, x, y): avals_in, (aval_out,) = ctx.avals_in, ctx.avals_out x_dtype = avals_in[0].dtype x, y = mlir.multi_broadcast_in_dim(ctx, (x, y), avals_in, aval_out.shape) @@ -2269,31 +2268,31 @@ def _compare_lower_mhlo(direction: str, ctx, x, y): compare_type = "SIGNED" else: compare_type = "UNSIGNED" - return mlir.compare_mhlo(x, y, direction, compare_type).results + return mlir.compare_hlo(x, y, direction, compare_type).results eq_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'eq') ad.defjvp_zero(eq_p) -mlir.register_lowering(eq_p, partial(_compare_lower_mhlo, "EQ")) +mlir.register_lowering(eq_p, partial(_compare_lower_hlo, "EQ")) ne_p = naryop(_fixed_dtype(np.bool_), [_any, _any], 'ne') ad.defjvp_zero(ne_p) -mlir.register_lowering(ne_p, partial(_compare_lower_mhlo, "NE")) +mlir.register_lowering(ne_p, partial(_compare_lower_hlo, "NE")) ge_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'ge') ad.defjvp_zero(ge_p) -mlir.register_lowering(ge_p, partial(_compare_lower_mhlo, "GE")) +mlir.register_lowering(ge_p, partial(_compare_lower_hlo, "GE")) gt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'gt') ad.defjvp_zero(gt_p) -mlir.register_lowering(gt_p, partial(_compare_lower_mhlo, "GT")) +mlir.register_lowering(gt_p, partial(_compare_lower_hlo, "GT")) le_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'le') ad.defjvp_zero(le_p) -mlir.register_lowering(le_p, partial(_compare_lower_mhlo, "LE")) +mlir.register_lowering(le_p, partial(_compare_lower_hlo, "LE")) lt_p = naryop(_fixed_dtype(np.bool_), [_ordered, _ordered], 'lt') ad.defjvp_zero(lt_p) -mlir.register_lowering(lt_p, partial(_compare_lower_mhlo, "LT")) +mlir.register_lowering(lt_p, partial(_compare_lower_hlo, "LT")) def _convert_element_type_shape_rule(operand, *, new_dtype, weak_type): @@ -2392,9 +2391,9 @@ def _convert_element_type_lower(ctx, operand, *, new_dtype, weak_type): aval_out, = ctx.avals_out if (dtypes.issubdtype(aval_in.dtype, np.complexfloating) and not dtypes.issubdtype(new_dtype, np.complexfloating)): - operand = mhlo.RealOp(operand).result + operand = hlo.RealOp(operand).result aval_in = 
aval_in.update(dtype=_real_dtype(aval_in.dtype)) - return [mlir.convert_mhlo(ctx, operand, aval_in, aval_out)] + return [mlir.convert_hlo(ctx, operand, aval_in, aval_out)] mlir.register_lowering(convert_element_type_p, _convert_element_type_lower) @@ -2419,7 +2418,7 @@ def _bitcast_convert_type_dtype_rule(operand, *, new_dtype): def _bitcast_convert_type_lower(ctx, operand, *, new_dtype): aval_out, = ctx.avals_out - return mhlo.BitcastConvertOp(mlir.aval_to_ir_type(aval_out), operand).results + return hlo.BitcastConvertOp(mlir.aval_to_ir_type(aval_out), operand).results mlir.register_lowering(bitcast_convert_type_p, _bitcast_convert_type_lower) @@ -2708,7 +2707,7 @@ def precision_attr(precision: PrecisionType) -> ir.ArrayAttr: else: full_precision = precision return ir.ArrayAttr.get( - [mhlo.PrecisionAttr.get(str(p)) for p in full_precision]) + [hlo.PrecisionAttr.get(str(p)) for p in full_precision]) @@ -2723,19 +2722,19 @@ def _dot_general_lower(ctx, lhs, rhs, *, dimension_numbers, if ctx.module_context.platform == "cpu": if lhs_aval.dtype == np.float16: f32 = mlir.dtype_to_ir_type(np.dtype(np.float32)) - lhs = mhlo.ConvertOp(ir.RankedTensorType.get(lhs_aval.shape, f32), - lhs).result + lhs = hlo.ConvertOp(ir.RankedTensorType.get(lhs_aval.shape, f32), + lhs).result if rhs_aval.dtype == np.float16: f32 = mlir.dtype_to_ir_type(np.dtype(np.float32)) - rhs = mhlo.ConvertOp(ir.RankedTensorType.get(rhs_aval.shape, f32), - rhs).result - dot_dnums = mhlo.DotDimensionNumbers.get( + rhs = hlo.ConvertOp(ir.RankedTensorType.get(rhs_aval.shape, f32), + rhs).result + dot_dnums = hlo.DotDimensionNumbers.get( lhs_batching_dimensions=list(lhs_batch), rhs_batching_dimensions=list(rhs_batch), lhs_contracting_dimensions=list(lhs_contracting), rhs_contracting_dimensions=list(rhs_contracting)) return [ - mhlo.DotGeneralOp( + hlo.DotGeneralOp( mlir.aval_to_ir_type(aval_out), lhs, rhs, @@ -3005,7 +3004,7 @@ def _clamp_batch_rule(batched_args, batch_dims, **params): select(lt(max, operand), g, _zeros(operand))) batching.primitive_batchers[clamp_p] = _clamp_batch_rule mlir.register_lowering( - clamp_p, partial(_nary_lower_mhlo, mhlo.ClampOp)) + clamp_p, partial(_nary_lower_hlo, hlo.ClampOp)) pe.def_trivial_padding(clamp_p) def _concatenate_shape_rule(*operands, **kwargs): @@ -3083,7 +3082,7 @@ def _concatenate_pad_rule(in_avals, out_avals, *operands, dimension): pe.padding_rules[concatenate_p] = _concatenate_pad_rule def _concatenate_lower(ctx, *xs, dimension): - return mhlo.ConcatenateOp(xs, mlir.i64_attr(dimension)).results + return hlo.ConcatenateOp(xs, mlir.i64_attr(dimension)).results mlir.register_lowering(concatenate_p, _concatenate_lower) @@ -3158,10 +3157,10 @@ def _pad_batch_rule(batched_args, batch_dims, *, padding_config): def _pad_lower(ctx, x, padding_value, *, padding_config): low, high, interior = util.unzip3(padding_config) - return mhlo.PadOp(x, padding_value, - mlir.dense_int_elements(low), - mlir.dense_int_elements(high), - mlir.dense_int_elements(interior)).results + return hlo.PadOp(x, padding_value, + mlir.dense_int_elements(low), + mlir.dense_int_elements(high), + mlir.dense_int_elements(interior)).results mlir.register_lowering(pad_p, _pad_lower) @@ -3302,7 +3301,7 @@ def _reshape_batch_rule(batched_args, batch_dims, *, new_sizes, dimensions): def _reshape_lower(ctx, x, *dyn_shape, new_sizes, dimensions): aval_out, = ctx.avals_out if dimensions is not None: - x = mhlo.TransposeOp(x, mlir.dense_int_elements(dimensions)).result + x = hlo.TransposeOp(x, 
mlir.dense_int_elements(dimensions)).result if dyn_shape: aval_out = aval_out.update(shape=_merge_dyn_shape(new_sizes, dyn_shape)) return [mlir.reshape(ctx, x, aval_out)] @@ -3346,7 +3345,7 @@ def _rev_batch_rule(batched_args, batch_dims, *, dimensions): batching.primitive_batchers[rev_p] = _rev_batch_rule def _rev_lower(ctx, x, *, dimensions): - return mhlo.ReverseOp(x, mlir.dense_int_elements(dimensions)).results + return hlo.ReverseOp(x, mlir.dense_int_elements(dimensions)).results mlir.register_lowering(rev_p, _rev_lower) @@ -3370,7 +3369,7 @@ def _transpose_lower(ctx, x, *, permutation): aval_out, = ctx.avals_out if core.is_opaque_dtype(aval_out.dtype): return [aval_out.dtype._rules.transpose_mlir(ctx, aval_out, x, permutation=permutation)] - return mhlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results + return hlo.TransposeOp(x, mlir.dense_int_elements(permutation)).results transpose_p = standard_primitive(_transpose_shape_rule, _input_dtype, 'transpose') @@ -3470,12 +3469,12 @@ def _select_jvp(primals, tangents): out_dot = select_n(which, *case_tangents) return out, out_dot -def _select_mhlo_lowering(ctx, which, *cases): +def _select_hlo_lowering(ctx, which, *cases): which_aval = ctx.avals_in[0] if which_aval.dtype == np.dtype(np.bool_): assert len(cases) <= 2 if len(cases) == 1: return cases - return mhlo.SelectOp(which, cases[1], cases[0]).results + return hlo.SelectOp(which, cases[1], cases[0]).results if dtypes.issubdtype(which_aval.dtype, np.signedinteger): compare_type = 'SIGNED' @@ -3488,11 +3487,11 @@ def _select(offset, cases): if len(cases) == 1: return cases[0] mid = len(cases) // 2 - pred = mlir.compare_mhlo(which, - mlir.full_like_aval(ctx, offset + mid, which_aval), - lt, compare_type) - return mhlo.SelectOp(pred, _select(offset, cases[:mid]), - _select(offset + mid, cases[mid:])).result + pred = mlir.compare_hlo(which, + mlir.full_like_aval(ctx, offset + mid, which_aval), + lt, compare_type) + return hlo.SelectOp(pred, _select(offset, cases[:mid]), + _select(offset + mid, cases[mid:])).result return [_select(0, cases)] @@ -3502,7 +3501,7 @@ def _select(offset, cases): ad.primitive_jvps[select_n_p] = _select_jvp ad.primitive_transposes[select_n_p] = _select_transpose_rule batching.primitive_batchers[select_n_p] = _select_batch_rule -mlir.register_lowering(select_n_p, _select_mhlo_lowering) +mlir.register_lowering(select_n_p, _select_hlo_lowering) pe.def_trivial_padding(select_n_p) @@ -3622,8 +3621,8 @@ def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions): assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in operands, init_values = util.split_list(values, [len(values) // 2]) init_value_avals = ctx.avals_in[len(values) // 2:] - op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - operands, init_values, mlir.dense_int_elements(dimensions)) + op = hlo.ReduceOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + operands, init_values, mlir.dense_int_elements(dimensions)) ir_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals] reducer = op.regions[0].blocks.append(*(ir_types + ir_types)) with ir.InsertionPoint(reducer): @@ -3633,7 +3632,7 @@ def _reduce_lower(ctx, *values, computation, jaxpr, consts, dimensions): out_nodes, _ = mlir.jaxpr_subcomp(reducer_ctx, jaxpr, mlir.TokenSet(), consts, *([a] for a in reducer.arguments), dim_var_values=ctx.dim_var_values) - mhlo.ReturnOp(util.flatten(out_nodes)) + hlo.ReturnOp(util.flatten(out_nodes)) return op.results 
mlir.register_lowering(reduce_p, _reduce_lower) @@ -3818,29 +3817,29 @@ def _reduce_logical_shape_rule(operand, *, axes): def _unary_reduce_lower(reducer, unit_factory, ctx, x, *, axes): aval_out, = ctx.avals_out dtype = aval_out.dtype - op = mhlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x], - mlir.ir_constants(unit_factory(aval_out.dtype)), - mlir.dense_int_elements(axes)) + op = hlo.ReduceOp([mlir.aval_to_ir_type(aval_out)], [x], + mlir.ir_constants(unit_factory(aval_out.dtype)), + mlir.dense_int_elements(axes)) scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), dtype)) reducer_region = op.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer_region): add = reducer(*reducer_region.arguments) - mhlo.ReturnOp(add.results) + hlo.ReturnOp(add.results) return op.results -mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, mhlo.AddOp, +mlir.register_lowering(reduce_sum_p, partial(_unary_reduce_lower, hlo.AddOp, _get_sum_identity)) -mlir.register_lowering(reduce_prod_p, partial(_unary_reduce_lower, mhlo.MulOp, +mlir.register_lowering(reduce_prod_p, partial(_unary_reduce_lower, hlo.MulOp, _get_prod_identity)) -mlir.register_lowering(reduce_or_p, partial(_unary_reduce_lower, mhlo.OrOp, +mlir.register_lowering(reduce_or_p, partial(_unary_reduce_lower, hlo.OrOp, _get_bitwise_or_identity)) -mlir.register_lowering(reduce_and_p, partial(_unary_reduce_lower, mhlo.AndOp, +mlir.register_lowering(reduce_and_p, partial(_unary_reduce_lower, hlo.AndOp, _get_bitwise_and_identity)) -mlir.register_lowering(reduce_xor_p, partial(_unary_reduce_lower, mhlo.XorOp, +mlir.register_lowering(reduce_xor_p, partial(_unary_reduce_lower, hlo.XorOp, _get_bitwise_or_identity)) -mlir.register_lowering(reduce_min_p, partial(_unary_reduce_lower, mlir.min_mhlo, +mlir.register_lowering(reduce_min_p, partial(_unary_reduce_lower, mlir.min_hlo, _get_min_identity)) -mlir.register_lowering(reduce_max_p, partial(_unary_reduce_lower, mlir.max_mhlo, +mlir.register_lowering(reduce_max_p, partial(_unary_reduce_lower, mlir.max_hlo, _get_max_identity)) @@ -3864,8 +3863,8 @@ def _reduce_precision_shape_rule(operand, *, exponent_bits, mantissa_bits): def _reduce_precision_lower(ctx, operand, *, exponent_bits, mantissa_bits): aval_out, = ctx.avals_out - return mhlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits), - mlir.i32_attr(mantissa_bits)).results + return hlo.ReducePrecisionOp(operand, mlir.i32_attr(exponent_bits), + mlir.i32_attr(mantissa_bits)).results mlir.register_lowering(reduce_precision_p, _reduce_precision_lower) @@ -4014,10 +4013,10 @@ def _sort_batch_rule(batched_args, batch_dims, *, dimension, is_stable, num_keys def _sort_lower(ctx, *operands, dimension, is_stable, num_keys): assert all(isinstance(x, core.ShapedArray) for x in ctx.avals_in), ctx.avals_in - sort = mhlo.SortOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], - mlir.flatten_lowering_ir_args(operands), - dimension=mlir.i64_attr(dimension), - is_stable=ir.BoolAttr.get(is_stable)) + sort = hlo.SortOp([mlir.aval_to_ir_type(aval) for aval in ctx.avals_out], + mlir.flatten_lowering_ir_args(operands), + dimension=mlir.i64_attr(dimension), + is_stable=ir.BoolAttr.get(is_stable)) scalar_avals = [aval.update(shape=()) for aval in ctx.avals_in] scalar_types = safe_map(mlir.aval_to_ir_type, scalar_avals) comparator = sort.comparator.blocks.append( @@ -4031,7 +4030,7 @@ def _sort_lower(ctx, *operands, dimension, is_stable, num_keys): out = lower_comparator(sub_ctx, *[[a] for a in comparator.arguments], 
num_keys=num_keys) - mhlo.ReturnOp(util.flatten(out)) + hlo.ReturnOp(util.flatten(out)) return sort.results mlir.register_lowering(sort_p, _sort_lower) @@ -4134,10 +4133,9 @@ def create_token(_=None): def _create_token_lowering(ctx, *operands): aval_out, = ctx.avals_out if xc.mlir_api_version < 40: - return mhlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results + return hlo.CreateTokenOp(mlir.aval_to_ir_type(aval_out)).results else: - return mhlo.CreateTokenOp().results - + return hlo.CreateTokenOp().results mlir.register_lowering(create_token_p, _create_token_lowering) @@ -4160,10 +4158,9 @@ def _after_all_abstract_eval(*operands): def _after_all_lowering(ctx, *operands): aval_out, = ctx.avals_out if xc.mlir_api_version < 40: - return mhlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results + return hlo.AfterAllOp(mlir.aval_to_ir_type(aval_out), operands).results else: - return mhlo.AfterAllOp(operands).results - + return hlo.AfterAllOp(operands).results mlir.register_lowering(after_all_p, _after_all_lowering) @@ -4215,8 +4212,8 @@ def _infeed_lowering(ctx, token, *, shapes, partitions): for i in range(len(aval.shape) - 1, -1, -1)]) for aval in shapes ]) - infeed = mhlo.InfeedOp( - flat_output_types + [mhlo.TokenType.get()], + infeed = hlo.InfeedOp( + flat_output_types + [hlo.TokenType.get()], token, infeed_config=ir.StringAttr.get(''), layout=layouts) @@ -4259,13 +4256,13 @@ def _outfeed_abstract_eval(token, *xs, partitions): def _outfeed_lowering(ctx, token, *xs, partitions): token_aval = ctx.avals_in[0] if xc.mlir_api_version < 40: - outfeed = mhlo.OutfeedOp( + outfeed = hlo.OutfeedOp( mlir.aval_to_ir_type(token_aval), mlir.flatten_lowering_ir_args(xs), token, outfeed_config=ir.StringAttr.get('')) else: - outfeed = mhlo.OutfeedOp( + outfeed = hlo.OutfeedOp( mlir.flatten_lowering_ir_args(xs), token, outfeed_config=ir.StringAttr.get('')) @@ -4308,8 +4305,8 @@ def _rng_uniform_lowering(ctx, a, b, *, shape): aval_out, = ctx.avals_out shape, = mlir.ir_constants(np.array(aval_out.shape, np.int64), canonicalize_types=False) - return mhlo.RngOp(a, b, shape, - mhlo.RngDistributionAttr.get('UNIFORM')).results + return hlo.RngOp(a, b, shape, + hlo.RngDistributionAttr.get('UNIFORM')).results mlir.register_lowering(rng_uniform_p, _rng_uniform_lowering) @@ -4331,11 +4328,11 @@ def _rng_bit_generator_weak_type_rule(key, *, shape, dtype, algorithm): def _rng_algorithm(algorithm: RandomAlgorithm): if algorithm == RandomAlgorithm.RNG_THREE_FRY: - return mhlo.RngAlgorithmAttr.get("THREE_FRY") + return hlo.RngAlgorithmAttr.get("THREE_FRY") elif algorithm == RandomAlgorithm.RNG_PHILOX: - return mhlo.RngAlgorithmAttr.get("PHILOX") + return hlo.RngAlgorithmAttr.get("PHILOX") elif algorithm == RandomAlgorithm.RNG_DEFAULT: - return mhlo.RngAlgorithmAttr.get("DEFAULT") + return hlo.RngAlgorithmAttr.get("DEFAULT") else: assert False @@ -4362,21 +4359,21 @@ def _rng_bit_generator_lowering( else: rbg_etype = u32_type if key_etype == u32_type: - key = mhlo.BitcastConvertOp( + key = hlo.BitcastConvertOp( ir.RankedTensorType.get([2], u64_type), - mhlo.ReshapeOp(ir.RankedTensorType.get([2, 2], u32_type), key)).result + hlo.ReshapeOp(ir.RankedTensorType.get([2, 2], u32_type), key)).result algorithm_attr = _rng_algorithm(algorithm) - out_key, out_vals = mhlo.RngBitGeneratorOp( + out_key, out_vals = hlo.RngBitGeneratorOp( key.type, ir.RankedTensorType.get(shape, rbg_etype), algorithm_attr, key).results if key_etype == u32_type: - out_key = mhlo.ReshapeOp( + out_key = hlo.ReshapeOp( 
ir.RankedTensorType.get([4], u32_type), - mhlo.BitcastConvertOp( + hlo.BitcastConvertOp( ir.RankedTensorType.get([2, 2], u32_type), out_key)).result if rbg_etype != etype: - out_vals = mhlo.ConvertOp( + out_vals = hlo.ConvertOp( ir.RankedTensorType.get(ir.RankedTensorType(out_vals.type).shape, etype), out_vals).result return [out_key, out_vals] @@ -4523,14 +4520,14 @@ def _iota_lower(ctx, *dyn_shape, dtype, shape, dimension): aval_out = aval_out.update(shape=_merge_dyn_shape(shape, dyn_shape)) if not core.is_constant_shape(aval_out.shape): shape = mlir.eval_dynamic_shape(ctx, aval_out.shape) - return mhlo.DynamicIotaOp( + return hlo.DynamicIotaOp( mlir.aval_to_ir_type(aval_out), mlir.shape_tensor(shape), mlir.i64_attr(dimension), ).results else: - return mhlo.IotaOp(mlir.aval_to_ir_type(aval_out), - mlir.i64_attr(dimension)).results + return hlo.IotaOp(mlir.aval_to_ir_type(aval_out), + mlir.i64_attr(dimension)).results mlir.register_lowering(iota_p, _iota_lower) def _iota_batching_rule(in_vals, in_dims, *, dtype, shape, dimension): diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py index 12ca08adee35..a6f3111ba401 100644 --- a/jax/_src/lax/linalg.py +++ b/jax/_src/lax/linalg.py @@ -52,7 +52,7 @@ from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import chlo -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.typing import Array, ArrayLike xops = xla_client.ops @@ -418,7 +418,7 @@ def _cholesky_batching_rule(batched_args, batch_dims): batching.primitive_batchers[cholesky_p] = _cholesky_batching_rule def _cholesky_lowering(ctx, x): - return mhlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results + return hlo.CholeskyOp(x, lower=ir.BoolAttr.get(True)).results mlir.register_lowering(cholesky_p, _cholesky_lowering) @@ -429,22 +429,28 @@ def _cholesky_cpu_gpu_lowering(potrf_impl, ctx, operand): out_aval, = ctx.avals_out batch_dims = operand_aval.shape[:-2] result, info = potrf_impl(operand_aval.dtype, operand, lower=True) - ok = mlir.compare_mhlo( + ok = mlir.compare_hlo( info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))), "EQ", "SIGNED") select_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) - return [_broadcasting_select_mhlo( + return [_broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_aval, broadcast_dimensions=range(len(batch_dims))), select_aval, - result, out_aval, _nan_like_mhlo(ctx, out_aval), out_aval)] + result, out_aval, _nan_like_hlo(ctx, out_aval), out_aval)] -mlir.register_lowering( - cholesky_p, - partial(_cholesky_cpu_gpu_lowering, lapack.potrf_mhlo), - platform='cpu') +if xla_client.mlir_api_version < 41: + mlir.register_lowering( + cholesky_p, + partial(_cholesky_cpu_gpu_lowering, lapack.potrf_mhlo), + platform='cpu') +else: + mlir.register_lowering( + cholesky_p, + partial(_cholesky_cpu_gpu_lowering, lapack.potrf_hlo), + platform='cpu') # Asymmetric eigendecomposition @@ -491,42 +497,47 @@ def _eig_cpu_lowering(ctx, operand, *, compute_left_eigenvectors, out_aval = ctx.avals_out[0] batch_dims = operand_aval.shape[:-2] - w, vl, vr, info = lapack.geev_mhlo(operand_aval.dtype, operand, - jobvl=compute_left_eigenvectors, - jobvr=compute_right_eigenvectors) + if xla_client.mlir_api_version < 41: + w, vl, vr, info = lapack.geev_mhlo(operand_aval.dtype, operand, + jobvl=compute_left_eigenvectors, + jobvr=compute_right_eigenvectors) + else: + w, vl, vr, info = lapack.geev_hlo(operand_aval.dtype, operand, + jobvl=compute_left_eigenvectors, + 
jobvr=compute_right_eigenvectors) - ok = mlir.compare_mhlo( + ok = mlir.compare_hlo( info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))), "EQ", "SIGNED") select_w_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_)) - w = _broadcasting_select_mhlo( + w = _broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_w_aval, broadcast_dimensions=range(len(batch_dims))), select_w_aval, - w, out_aval, _nan_like_mhlo(ctx, out_aval), out_aval) + w, out_aval, _nan_like_hlo(ctx, out_aval), out_aval) output = [w] if compute_left_eigenvectors: aval = ctx.avals_out[len(output)] select_vl_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) - vl = _broadcasting_select_mhlo( + vl = _broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_vl_aval, broadcast_dimensions=range(len(batch_dims))), select_vl_aval, - vl, aval, _nan_like_mhlo(ctx, aval), aval) + vl, aval, _nan_like_hlo(ctx, aval), aval) output.append(vl) if compute_right_eigenvectors: aval = ctx.avals_out[len(output)] select_vr_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) - vr = _broadcasting_select_mhlo( + vr = _broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_vr_aval, broadcast_dimensions=range(len(batch_dims))), select_vr_aval, - vr, aval, _nan_like_mhlo(ctx, aval), aval) + vr, aval, _nan_like_hlo(ctx, aval), aval) output.append(vr) return output @@ -645,21 +656,21 @@ def _eigh_cpu_gpu_lowering(syevd_impl, ctx, operand, *, lower, batch_dims = operand_aval.shape[:-2] v, w, info = syevd_impl(operand_aval.dtype, operand, lower=lower) zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) - ok = mlir.compare_mhlo(info, zeros, "EQ", "SIGNED") + ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED") select_v_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) - v = _broadcasting_select_mhlo( + v = _broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_v_aval, broadcast_dimensions=range(len(batch_dims))), select_v_aval, - v, v_aval, _nan_like_mhlo(ctx, v_aval), v_aval) + v, v_aval, _nan_like_hlo(ctx, v_aval), v_aval) select_w_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_)) - w = _broadcasting_select_mhlo( + w = _broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_w_aval, broadcast_dimensions=range(len(batch_dims))), select_w_aval, - w, w_aval, _nan_like_mhlo(ctx, w_aval), w_aval) + w, w_aval, _nan_like_hlo(ctx, w_aval), w_aval) return [v, w] def _eigh_tpu_impl(x, *, lower, sort_eigenvalues): @@ -742,9 +753,14 @@ def _eigh_batching_rule(batched_args, batch_dims, *, lower, sort_eigenvalues): ad.primitive_jvps[eigh_p] = _eigh_jvp_rule batching.primitive_batchers[eigh_p] = _eigh_batching_rule -mlir.register_lowering( - eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_mhlo), - platform='cpu') +if xla_client.mlir_api_version < 41: + mlir.register_lowering( + eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_mhlo), + platform='cpu') +else: + mlir.register_lowering( + eigh_p, partial(_eigh_cpu_gpu_lowering, lapack.syevd_hlo), + platform='cpu') if gpu_solver is not None: mlir.register_lowering( @@ -882,15 +898,15 @@ def _triangular_solve_lowering( else: transpose = "ADJOINT" if conjugate_a else "TRANSPOSE" if mlir_api_version < 36: - return mhlo.TriangularSolveOp( + return hlo.TriangularSolveOp( mlir.aval_to_ir_type(out_aval), a, b, ir.BoolAttr.get(left_side), ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), - mhlo.TransposeAttr.get(transpose)).results + 
hlo.TransposeAttr.get(transpose)).results else: - return mhlo.TriangularSolveOp( + return hlo.TriangularSolveOp( a, b, ir.BoolAttr.get(left_side), ir.BoolAttr.get(lower), ir.BoolAttr.get(unit_diagonal), - mhlo.TransposeAttr.get(transpose)).results + hlo.TransposeAttr.get(transpose)).results mlir.register_lowering(triangular_solve_p, _triangular_solve_lowering) @@ -904,9 +920,14 @@ def _triangular_solve_cpu_lower( conjugate_a = False if len(a_aval.shape) == 2 and np.dtype(a_aval.dtype) in _cpu_lapack_types: alpha = mlir.ir_constant(np.array(1, dtype=a_aval.dtype)) - return [lapack.trsm_mhlo( - a_aval.dtype, alpha, - a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)] + if xla_client.mlir_api_version < 41: + return [lapack.trsm_mhlo( + a_aval.dtype, alpha, + a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)] + else: + return [lapack.trsm_hlo( + a_aval.dtype, alpha, + a, b, left_side, lower, transpose_a, conjugate_a, unit_diagonal)] else: # Fall back to the HLO implementation for unsupported types or batching. # TODO: Consider swapping XLA for LAPACK in batched case @@ -915,15 +936,15 @@ def _triangular_solve_cpu_lower( else: transpose = "NO_TRANSPOSE" if mlir_api_version < 36: - return mhlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side), - ir.BoolAttr.get(lower), - ir.BoolAttr.get(unit_diagonal), - mhlo.TransposeAttr.get(transpose)).results + return hlo.TriangularSolveOp(b.type, a, b, ir.BoolAttr.get(left_side), + ir.BoolAttr.get(lower), + ir.BoolAttr.get(unit_diagonal), + hlo.TransposeAttr.get(transpose)).results else: - return mhlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side), - ir.BoolAttr.get(lower), - ir.BoolAttr.get(unit_diagonal), - mhlo.TransposeAttr.get(transpose)).results + return hlo.TriangularSolveOp(a, b, ir.BoolAttr.get(left_side), + ir.BoolAttr.get(lower), + ir.BoolAttr.get(unit_diagonal), + hlo.TransposeAttr.get(transpose)).results mlir.register_lowering(triangular_solve_p, _triangular_solve_cpu_lower, platform='cpu') @@ -1186,17 +1207,17 @@ def _lu_cpu_gpu_lowering(getrf_impl, ctx, operand): m = operand_aval.shape[-2] lu, pivot, info = getrf_impl(operand_aval.dtype, operand) # Subtract 1 from the pivot to get 0-based indices. 
- pivot = mhlo.SubtractOp(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)).result - ok = mlir.compare_mhlo( + pivot = hlo.SubtractOp(pivot, mlir.full_like_aval(ctx, 1, pivot_aval)).result + ok = mlir.compare_hlo( info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))), "GE", "SIGNED") select_lu_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) - lu = _broadcasting_select_mhlo( + lu = _broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_lu_aval, broadcast_dimensions=range(len(batch_dims))), select_lu_aval, - lu, out_aval, _nan_like_mhlo(ctx, out_aval), out_aval) + lu, out_aval, _nan_like_hlo(ctx, out_aval), out_aval) sub_ctx = ctx.replace(primitive=None, avals_in=[pivot_aval], avals_out=[perm_aval]) perm_fn = mlir.lower_fun(lambda x: lu_pivots_to_permutation(x, m), multiple_results=False) @@ -1216,9 +1237,14 @@ def _lu_tpu_translation_rule(ctx, avals_in, avals_out, operand): ad.primitive_jvps[lu_p] = _lu_jvp_rule batching.primitive_batchers[lu_p] = _lu_batching_rule -mlir.register_lowering(lu_p, - partial(_lu_cpu_gpu_lowering, lapack.getrf_mhlo), - platform='cpu') +if xla_client.mlir_api_version < 41: + mlir.register_lowering(lu_p, + partial(_lu_cpu_gpu_lowering, lapack.getrf_mhlo), + platform='cpu') +else: + mlir.register_lowering(lu_p, + partial(_lu_cpu_gpu_lowering, lapack.getrf_hlo), + platform='cpu') mlir.register_lowering( lu_p, partial(_lu_cpu_gpu_lowering, gpu_solver.cuda_getrf), @@ -1339,15 +1365,15 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a): else: a_out, taus, info_geqrf = geqrf_impl(a_aval.dtype, a) zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) - ok = mlir.compare_mhlo(info_geqrf, zeros, "EQ", "SIGNED") + ok = mlir.compare_hlo(info_geqrf, zeros, "EQ", "SIGNED") select_ok_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_)) ok_a = mlir.broadcast_in_dim(ctx, ok, select_ok_a_aval, broadcast_dimensions=range(len(batch_dims))) - a_out = _broadcasting_select_mhlo(ctx, ok_a, select_ok_a_aval, a_out, a_aval, _nan_like_mhlo(ctx, a_aval), a_aval) + a_out = _broadcasting_select_hlo(ctx, ok_a, select_ok_a_aval, a_out, a_aval, _nan_like_hlo(ctx, a_aval), a_aval) select_ok_taus_aval = ShapedArray(batch_dims + [1], np.dtype(np.bool_)) ok_taus = mlir.broadcast_in_dim(ctx, ok, select_ok_taus_aval, broadcast_dimensions=range(len(batch_dims))) - taus = _broadcasting_select_mhlo(ctx, ok_taus, select_ok_taus_aval, taus, taus_aval, _nan_like_mhlo(ctx, taus_aval), taus_aval) + taus = _broadcasting_select_hlo(ctx, ok_taus, select_ok_taus_aval, taus, taus_aval, _nan_like_hlo(ctx, taus_aval), taus_aval) return a_out, taus geqrf_p = Primitive('geqrf') @@ -1357,9 +1383,14 @@ def _geqrf_cpu_gpu_lowering(geqrf_impl, batched_geqrf_impl, ctx, a): batching.primitive_batchers[geqrf_p] = _geqrf_batching_rule xla.register_translation(geqrf_p, _geqrf_translation_rule) -mlir.register_lowering( - geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_mhlo, None), - platform='cpu') +if xla_client.mlir_api_version < 41: + mlir.register_lowering( + geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_mhlo, None), + platform='cpu') +else: + mlir.register_lowering( + geqrf_p, partial(_geqrf_cpu_gpu_lowering, lapack.geqrf_hlo, None), + platform='cpu') mlir.register_lowering( geqrf_p, partial(_geqrf_cpu_gpu_lowering, gpu_solver.cuda_geqrf, @@ -1425,11 +1456,11 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus): a, info_orgqr = orgqr_impl(a_aval.dtype, a, taus) zeros = 
mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) - ok = mlir.compare_mhlo(info_orgqr, zeros, "EQ", "SIGNED") + ok = mlir.compare_hlo(info_orgqr, zeros, "EQ", "SIGNED") select_a_aval = ShapedArray(batch_dims + [1, 1], np.dtype(np.bool_)) ok = mlir.broadcast_in_dim(ctx, ok, select_a_aval, broadcast_dimensions=range(len(batch_dims))) - a = _broadcasting_select_mhlo(ctx, ok, select_a_aval, a, a_aval, _nan_like_mhlo(ctx, a_aval), a_aval) + a = _broadcasting_select_hlo(ctx, ok, select_a_aval, a, a_aval, _nan_like_hlo(ctx, a_aval), a_aval) return [a] @@ -1439,10 +1470,16 @@ def _householder_product_cpu_gpu_lowering(orgqr_impl, ctx, a, taus): batching.primitive_batchers[householder_product_p] = _householder_product_batching_rule xla.register_translation(householder_product_p, _householder_product_translation_rule) -mlir.register_lowering( - householder_product_p, - partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_mhlo), - platform='cpu') +if xla_client.mlir_api_version < 41: + mlir.register_lowering( + householder_product_p, + partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_mhlo), + platform='cpu') +else: + mlir.register_lowering( + householder_product_p, + partial(_householder_product_cpu_gpu_lowering, lapack.orgqr_hlo), + platform='cpu') mlir.register_lowering( householder_product_p, partial(_householder_product_cpu_gpu_lowering, gpu_solver.cuda_orgqr), @@ -1642,32 +1679,32 @@ def _svd_cpu_gpu_lowering(gesvd_impl, ctx, operand, *, full_matrices, full_matrices=full_matrices, compute_uv=compute_uv) zeros = mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))) - ok = mlir.compare_mhlo(info, zeros, "EQ", "SIGNED") + ok = mlir.compare_hlo(info, zeros, "EQ", "SIGNED") select_s_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_)) - s = _broadcasting_select_mhlo( + s = _broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_s_aval, broadcast_dimensions=range(len(batch_dims))), select_s_aval, - s, s_aval, _nan_like_mhlo(ctx, s_aval), s_aval) + s, s_aval, _nan_like_hlo(ctx, s_aval), s_aval) result = [s] if compute_uv: u_aval, vt_aval = ctx.avals_out[1:] select_u_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) - u = _broadcasting_select_mhlo( + u = _broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_u_aval, broadcast_dimensions=range(len(batch_dims))), select_u_aval, - u, u_aval, _nan_like_mhlo(ctx, u_aval), u_aval) + u, u_aval, _nan_like_hlo(ctx, u_aval), u_aval) select_v_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) - vt = _broadcasting_select_mhlo( + vt = _broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_v_aval, broadcast_dimensions=range(len(batch_dims))), select_v_aval, - vt, vt_aval, _nan_like_mhlo(ctx, vt_aval), vt_aval) + vt, vt_aval, _nan_like_hlo(ctx, vt_aval), vt_aval) result += [u, vt] return result @@ -1715,9 +1752,14 @@ def _svd_batching_rule(batched_args, batch_dims, *, full_matrices, compute_uv): ad.primitive_jvps[svd_p] = _svd_jvp_rule batching.primitive_batchers[svd_p] = _svd_batching_rule -mlir.register_lowering( - svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo), - platform='cpu') +if xla_client.mlir_api_version < 41: + mlir.register_lowering( + svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_mhlo), + platform='cpu') +else: + mlir.register_lowering( + svd_p, partial(_svd_cpu_gpu_lowering, lapack.gesdd_hlo), + platform='cpu') mlir.register_lowering( svd_p, partial(_svd_cpu_gpu_lowering, gpu_solver.cuda_gesvd), platform='cuda') @@ 
-1877,33 +1919,40 @@ def _schur_cpu_lowering(ctx, operand, *, compute_schur_vectors, sort_eig_vals, operand_aval, = ctx.avals_in batch_dims = operand_aval.shape[:-2] - gees_result = lapack.gees_mhlo(operand_aval.dtype, operand, + if xla_client.mlir_api_version < 41: + gees_result = lapack.gees_mhlo(operand_aval.dtype, operand, + jobvs=compute_schur_vectors, + sort=sort_eig_vals, + select=select_callable) + else: + gees_result = lapack.gees_hlo(operand_aval.dtype, operand, jobvs=compute_schur_vectors, sort=sort_eig_vals, select=select_callable) + # Number of return values depends on value of sort_eig_vals. T, vs, *_, info = gees_result - ok = mlir.compare_mhlo( + ok = mlir.compare_hlo( info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))), "EQ", "SIGNED") select_T_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) - T = _broadcasting_select_mhlo( + T = _broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_T_aval, broadcast_dimensions=range(len(batch_dims))), select_T_aval, - T, ctx.avals_out[0],_nan_like_mhlo(ctx, ctx.avals_out[0]), ctx.avals_out[0]) + T, ctx.avals_out[0],_nan_like_hlo(ctx, ctx.avals_out[0]), ctx.avals_out[0]) output = [T] if compute_schur_vectors: select_vs_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) - vs = _broadcasting_select_mhlo( + vs = _broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_vs_aval, broadcast_dimensions=range(len(batch_dims))), select_vs_aval, - vs, ctx.avals_out[1], _nan_like_mhlo(ctx, ctx.avals_out[1]), ctx.avals_out[1]) + vs, ctx.avals_out[1], _nan_like_hlo(ctx, ctx.avals_out[1]), ctx.avals_out[1]) output.append(vs) @@ -1983,35 +2032,38 @@ def _hessenberg_batching_rule(batched_args, batch_dims): batching.primitive_batchers[hessenberg_p] = _hessenberg_batching_rule -def _hessenberg_cpu_mhlo(ctx, a): +def _hessenberg_cpu_hlo(ctx, a): # TODO(phawkins): remove this test after jaxlib 0.3.25 is the minimum. 
- if not hasattr(lapack, "gehrd_mhlo"): + if not hasattr(lapack, "gehrd_mhlo") and not hasattr(lapack, "gehrd_hlo"): raise RuntimeError("Hessenberg reduction on CPU requires jaxlib 0.3.25 or " "newer") a_aval, = ctx.avals_in batch_dims = a_aval.shape[:-2] - a, taus, info = lapack.gehrd_mhlo(a_aval.dtype, a) - ok = mlir.compare_mhlo( + if xla_client.mlir_api_version < 41: + a, taus, info = lapack.gehrd_mhlo(a_aval.dtype, a) + else: + a, taus, info = lapack.gehrd_hlo(a_aval.dtype, a) + ok = mlir.compare_hlo( info, mlir.full_like_aval(ctx, 0, ShapedArray(batch_dims, np.dtype(np.int32))), "EQ", "SIGNED") select_a_aval = ShapedArray(batch_dims + (1, 1), np.dtype(np.bool_)) select_taus_aval = ShapedArray(batch_dims + (1,), np.dtype(np.bool_)) return [ - _broadcasting_select_mhlo( + _broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_a_aval, broadcast_dimensions=range(len(batch_dims))), select_a_aval, - a, ctx.avals_out[0], _nan_like_mhlo(ctx, ctx.avals_out[0]), ctx.avals_out[0]), - _broadcasting_select_mhlo( + a, ctx.avals_out[0], _nan_like_hlo(ctx, ctx.avals_out[0]), ctx.avals_out[0]), + _broadcasting_select_hlo( ctx, mlir.broadcast_in_dim(ctx, ok, select_taus_aval, broadcast_dimensions=range(len(batch_dims))), select_taus_aval, - taus, ctx.avals_out[1], _nan_like_mhlo(ctx, ctx.avals_out[1]), ctx.avals_out[1]), + taus, ctx.avals_out[1], _nan_like_hlo(ctx, ctx.avals_out[1]), ctx.avals_out[1]), ] -mlir.register_lowering(hessenberg_p, _hessenberg_cpu_mhlo, platform='cpu') +mlir.register_lowering(hessenberg_p, _hessenberg_cpu_hlo, platform='cpu') # tridiagonal: Upper Hessenberg reduction @@ -2085,35 +2137,40 @@ def _tridiagonal_batching_rule(batched_args, batch_dims, *, lower): batching.primitive_batchers[tridiagonal_p] = _tridiagonal_batching_rule -def _tridiagonal_cpu_gpu_mhlo(sytrd_impl, ctx, a, *, lower): +def _tridiagonal_cpu_gpu_hlo(sytrd_impl, ctx, a, *, lower): a_aval, = ctx.avals_in a, d, e, taus, info = sytrd_impl(a_aval.dtype, a, lower=lower) return a, d, e, taus, info if jaxlib_version >= (0, 3, 25): + if xla_client.mlir_api_version < 41: + mlir.register_lowering( + tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_mhlo), + platform='cpu') + else: + mlir.register_lowering( + tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, lapack.sytrd_hlo), + platform='cpu') mlir.register_lowering( - tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, lapack.sytrd_mhlo), - platform='cpu') - mlir.register_lowering( - tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, gpu_solver.cuda_sytrd), + tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.cuda_sytrd), platform='cuda') mlir.register_lowering( - tridiagonal_p, partial(_tridiagonal_cpu_gpu_mhlo, gpu_solver.rocm_sytrd), + tridiagonal_p, partial(_tridiagonal_cpu_gpu_hlo, gpu_solver.rocm_sytrd), platform='rocm') # Utilities -def _nan_like_mhlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value: +def _nan_like_hlo(ctx: mlir.LoweringRuleContext, aval) -> ir.Value: if jnp.issubdtype(aval.dtype, np.complexfloating): return mlir.full_like_aval(ctx, np.nan + np.nan * 1j, aval) else: return mlir.full_like_aval(ctx, np.nan, aval) -def _broadcasting_select_mhlo(ctx, which, which_aval, x, x_aval, y, y_aval) -> ir.Value: +def _broadcasting_select_hlo(ctx, which, which_aval, x, x_aval, y, y_aval) -> ir.Value: """Wrapper around XLA `Select` that broadcasts its arguments.""" out_shapes = list(lax_internal.broadcast_shapes( tuple(which_aval.shape), tuple(x_aval.shape), tuple(y_aval.shape))) which, x, y = mlir.multi_broadcast_in_dim(ctx, 
(which, x, y), (which_aval, x_aval, y_aval), out_shapes) - return mhlo.SelectOp(which, x, y).result + return hlo.SelectOp(which, x, y).result diff --git a/jax/_src/lax/parallel.py b/jax/_src/lax/parallel.py index c4b0b74436f3..3a896a9fa383 100644 --- a/jax/_src/lax/parallel.py +++ b/jax/_src/lax/parallel.py @@ -39,7 +39,7 @@ from jax._src.util import unzip2, prod, canonicalize_axis, safe_map, safe_zip, moveaxis from jax._src.lib.mlir import ir from jax._src.lib import mlir_api_version -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo unsafe_map, map = map, safe_map # type: ignore @@ -219,7 +219,7 @@ def ppermute(x, axis_name, perm): If ``x`` is a pytree then the result is equivalent to mapping this function to each leaf in the tree. - This function is an analog of the CollectivePermute XLA HLO. + This function is an analog of the CollectivePermute HLO. Args: x: array(s) with a mapped axis named ``axis_name``. @@ -661,7 +661,7 @@ def _replica_groups(axis_env, axis_name, axis_index_groups): for axis_index_group in axis_index_groups] return replica_groups -def _replica_groups_mhlo(replica_groups: Sequence[Sequence[int]] +def _replica_groups_hlo(replica_groups: Sequence[Sequence[int]] ) -> ir.DenseIntElementsAttr: # Uneven replica groups are padded with -1. groups = np.array(list(itertools.zip_longest(*replica_groups, fillvalue=-1)), @@ -711,7 +711,7 @@ def _positional_reduce(aval, arg): if not named_axes: return args - replica_groups = _replica_groups_mhlo( + replica_groups = _replica_groups_hlo( _replica_groups(ctx.module_context.axis_env, named_axes, axis_index_groups)) axis_context = ctx.module_context.axis_context @@ -722,12 +722,12 @@ def all_reduce(aval, x): if is_spmd: channel = ctx.module_context.new_channel() other_args = dict( - channel_handle=mhlo.ChannelHandle.get( + channel_handle=hlo.ChannelHandle.get( channel, mlir.DEVICE_TO_DEVICE_TYPE), use_global_device_ids=ir.BoolAttr.get(True)) else: other_args = {} - op = mhlo.AllReduceOp( + op = hlo.AllReduceOp( x.type, x, replica_groups=replica_groups, **other_args) scalar_aval = core.ShapedArray((), aval.dtype) scalar_type = mlir.aval_to_ir_type(scalar_aval) @@ -738,7 +738,7 @@ def all_reduce(aval, x): avals_in=[scalar_aval] * 2, avals_out=[scalar_aval]) out_nodes = lower_reducer( reducer_ctx, *([a] for a in reducer_block.arguments)) - mhlo.ReturnOp(util.flatten(out_nodes)) + hlo.ReturnOp(util.flatten(out_nodes)) return op.result return [all_reduce(aval, x) for aval, x in zip(ctx.avals_in, args)] @@ -849,11 +849,11 @@ def _ppermute_lowering(ctx, x, *, axis_name, perm): if is_manual and mlir_api_version >= 35: channel = ctx.module_context.new_channel() other_args = dict( - channel_handle=mhlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)) + channel_handle=hlo.ChannelHandle.get(channel, mlir.DEVICE_TO_DEVICE_TYPE)) else: other_args = {} - return mhlo.CollectivePermuteOp( + return hlo.CollectivePermuteOp( x, mlir.dense_int_elements(full_perm), **other_args).results def _ppermute_transpose_rule(t, x, perm, axis_name): @@ -952,16 +952,16 @@ def _all_to_all_lowering(ctx, x, *, # of partitions - and XLA is configured with only a single replica. 
channel = ctx.module_context.new_channel() other_args = dict( - channel_handle=mhlo.ChannelHandle.get(channel, - mlir.DEVICE_TO_DEVICE_TYPE)) + channel_handle=hlo.ChannelHandle.get(channel, + mlir.DEVICE_TO_DEVICE_TYPE)) else: other_args = {} - return mhlo.AllToAllOp( + return hlo.AllToAllOp( operand, split_dimension=mlir.i64_attr(split_axis), concat_dimension=mlir.i64_attr(concat_axis), split_count=mlir.i64_attr(split_count), - replica_groups=_replica_groups_mhlo(replica_groups), + replica_groups=_replica_groups_hlo(replica_groups), **other_args).results else: warnings.warn( @@ -1184,7 +1184,7 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, new_shape = list(x_aval.shape) new_shape.insert(all_gather_dimension, 1) broadcast_dimensions = [i for i in range(len(new_shape)) if i != all_gather_dimension] - x = mhlo.BroadcastInDimOp( + x = hlo.BroadcastInDimOp( mlir.aval_to_ir_type(x_aval.update(shape=new_shape)), x, mlir.dense_int_elements(broadcast_dimensions)) replica_groups = _replica_groups(ctx.module_context.axis_env, axis_name, @@ -1195,15 +1195,15 @@ def _all_gather_lowering(ctx, x, *, all_gather_dimension, axis_name, # of partitions - and XLA is configured with only a single replica. channel = ctx.module_context.new_channel() other_args = dict( - channel_handle=mhlo.ChannelHandle.get( + channel_handle=hlo.ChannelHandle.get( channel, mlir.DEVICE_TO_DEVICE_TYPE), use_global_device_ids=ir.BoolAttr.get(True)) else: other_args = {} - return mhlo.AllGatherOp( + return hlo.AllGatherOp( mlir.aval_to_ir_type(out_aval), x, all_gather_dim=mlir.i64_attr(all_gather_dimension), - replica_groups=_replica_groups_mhlo(replica_groups), + replica_groups=_replica_groups_hlo(replica_groups), **other_args).results else: lowering = mlir.lower_fun(_all_gather_via_psum, multiple_results=False) @@ -1328,16 +1328,16 @@ def _reduce_scatter_lowering(prim, reducer, ctx, x, # of partitions - and XLA is configured with only a single replica. 
channel = ctx.module_context.new_channel() other_args = dict( - channel_handle=mhlo.ChannelHandle.get( + channel_handle=hlo.ChannelHandle.get( channel, mlir.DEVICE_TO_DEVICE_TYPE), use_global_device_ids=ir.BoolAttr.get(True)) else: other_args = {} - op = mhlo.ReduceScatterOp( + op = hlo.ReduceScatterOp( mlir.aval_to_ir_type(x_aval.update(shape=scatter_out_shape)), x, scatter_dimension=mlir.i64_attr(scatter_dimension), - replica_groups=_replica_groups_mhlo(replica_groups), + replica_groups=_replica_groups_hlo(replica_groups), **other_args) scalar_type = mlir.aval_to_ir_type(scalar_aval) reducer_block = op.regions[0].blocks.append(scalar_type, scalar_type) @@ -1348,12 +1348,12 @@ def _reduce_scatter_lowering(prim, reducer, ctx, x, avals_out=[scalar_aval]) out_nodes = lower_reducer( reducer_ctx, *([a] for a in reducer_block.arguments)) - mhlo.ReturnOp(util.flatten(out_nodes)) + hlo.ReturnOp(util.flatten(out_nodes)) if tiled: return op.results else: - return mhlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results + return hlo.ReshapeOp(mlir.aval_to_ir_type(aval_out), op.result).results else: return mlir.lower_fun(_reduce_scatter_via_reducer, multiple_results=False)( ctx, x, @@ -1522,7 +1522,7 @@ def psum_scatter(x, axis_name, *, scatter_dimension=0, axis_index_groups=None, t return tree_util.tree_map(bind, x) -def _build_axis_index_lowering_mhlo(ctx, axis_name, axis_env): +def _build_axis_index_lowering_hlo(ctx, axis_name, axis_env): if isinstance(axis_name, tuple): assert axis_name, 'empty axis name' if len(axis_name) > 1: @@ -1539,21 +1539,21 @@ def _build_axis_index_lowering_mhlo(ctx, axis_name, axis_env): (mlir.SPMDAxisContext, mlir.ShardingContext)) if is_spmd: if mlir_api_version >= 39: - device_id = mhlo.PartitionIdOp() + device_id = hlo.PartitionIdOp() else: - device_id = mhlo.PartitionIdOp( + device_id = hlo.PartitionIdOp( ir.RankedTensorType.get([], ir.IntegerType.get_unsigned(32))) else: - device_id = mhlo.ReplicaIdOp() - unsigned_index = mhlo.RemOp(mhlo.DivOp(device_id, div), mod) - return mhlo.ConvertOp( + device_id = hlo.ReplicaIdOp() + unsigned_index = hlo.RemOp(hlo.DivOp(device_id, div), mod) + return hlo.ConvertOp( ir.RankedTensorType.get([], ir.IntegerType.get_signless(32)), unsigned_index).result def _axis_index_lowering(ctx, *, axis_name): return [ - _build_axis_index_lowering_mhlo(ctx, axis_name, - ctx.module_context.axis_env) + _build_axis_index_lowering_hlo(ctx, axis_name, + ctx.module_context.axis_env) ] diff --git a/jax/_src/lax/slicing.py b/jax/_src/lax/slicing.py index 652cf1d1f8b4..f7440e37d826 100644 --- a/jax/_src/lax/slicing.py +++ b/jax/_src/lax/slicing.py @@ -36,7 +36,7 @@ from jax._src import util from jax._src.util import safe_map, safe_zip from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.lib import xla_bridge from jax._src.lib import xla_client from jax._src.typing import Array, ArrayLike, Shape @@ -956,7 +956,7 @@ def _dynamic_slice_lower(ctx, x, *starts_and_dyn_sizes, slice_sizes): # def _getslice_lower(ctx, x, lo, hi): # aval_out, = ctx.avals_out -# return mhlo.RealDynamicSliceOp( +# return hlo.RealDynamicSliceOp( # mlir.aval_to_ir_type(aval_out), x, # mlir.shape_tensor([lo]), mlir.shape_tensor([hi]), mlir.shape_tensor([1]) # ).results @@ -1393,7 +1393,7 @@ def _gather_lower(ctx, operand, indices, *, assert mode in (GatherScatterMode.PROMISE_IN_BOUNDS, GatherScatterMode.CLIP), mode - dnums = mhlo.GatherDimensionNumbers.get( + dnums = 
hlo.GatherDimensionNumbers.get( collapsed_slice_dims=list(dimension_numbers.collapsed_slice_dims), index_vector_dim=len(ctx.avals_in[1].shape) - 1, offset_dims=list(dimension_numbers.offset_dims), @@ -1402,7 +1402,7 @@ def _gather_lower(ctx, operand, indices, *, slice_sizes = mlir.eval_dynamic_shape(ctx, slice_sizes) # TODO(burmako): Fix overly conservative type inference of DynamicGatherOp. # For now use the build_generic so that we can specify the result type. - # return mhlo.DynamicGatherOp( + # return hlo.DynamicGatherOp( # operand, indices, mlir.shape_tensor(slice_sizes), # dnums, indices_are_sorted=ir.BoolAttr.get(indices_are_sorted)).results results = [mlir.aval_to_ir_type(aval_out)] @@ -1411,10 +1411,10 @@ def _gather_lower(ctx, operand, indices, *, "dimension_numbers": dnums, "indices_are_sorted": ir.BoolAttr.get(indices_are_sorted) } - return mhlo.DynamicGatherOp.build_generic( + return hlo.DynamicGatherOp.build_generic( results=results, operands=operands, attributes=attributes).results else: - return mhlo.GatherOp( + return hlo.GatherOp( operand, indices, dnums, @@ -2019,7 +2019,7 @@ def _scatter_lower(ctx, operand, indices, updates, *, aval_out, = ctx.avals_out dnums = dimension_numbers - scatter_dnums = mhlo.ScatterDimensionNumbers.get( + scatter_dnums = hlo.ScatterDimensionNumbers.get( update_window_dims=list(dnums.update_window_dims), inserted_window_dims=list(dnums.inserted_window_dims), scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), @@ -2027,7 +2027,7 @@ def _scatter_lower(ctx, operand, indices, updates, *, result = mlir.aval_to_ir_types(aval_out) operand = [operand] updates = [updates] - op = mhlo.ScatterOp( + op = hlo.ScatterOp( result, operand, indices, @@ -2045,7 +2045,7 @@ def _scatter_lower(ctx, operand, indices, updates, *, update_ctx, update_jaxpr, mlir.TokenSet(), update_consts, (update.arguments[0],), (update.arguments[1],), dim_var_values=ctx.dim_var_values) - mhlo.ReturnOp(util.flatten(out_nodes)) + hlo.ReturnOp(util.flatten(out_nodes)) return op.results mlir.register_lowering(scatter_p, _scatter_lower) @@ -2076,7 +2076,7 @@ def _scatter_add_lower_gpu(ctx, operand, indices, updates, aval_out, = ctx.avals_out dnums = dimension_numbers - scatter_dnums = mhlo.ScatterDimensionNumbers.get( + scatter_dnums = hlo.ScatterDimensionNumbers.get( update_window_dims=list(dnums.update_window_dims), inserted_window_dims=list(dnums.inserted_window_dims), scattered_dims_to_operand_dims=list(dnums.scatter_dims_to_operand_dims), @@ -2089,7 +2089,7 @@ def _scatter(operand_part, updates_part): operand_part = [operand_part] updates_part = [updates_part] - scatter = mhlo.ScatterOp( + scatter = hlo.ScatterOp( operand_type_part, operand_part, indices, @@ -2100,13 +2100,13 @@ def _scatter(operand_part, updates_part): scalar_type = mlir.aval_to_ir_type(core.ShapedArray((), real_dtype)) reducer = scatter.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer): - add = mhlo.AddOp(*reducer.arguments).result - mhlo.ReturnOp([add]) + add = hlo.AddOp(*reducer.arguments).result + hlo.ReturnOp([add]) return scatter.result - real = _scatter(mhlo.RealOp(operand).result, mhlo.RealOp(updates).result) - imag = _scatter(mhlo.ImagOp(operand).result, mhlo.ImagOp(updates).result) - return mhlo.ComplexOp(real, imag).results + real = _scatter(hlo.RealOp(operand).result, hlo.RealOp(updates).result) + imag = _scatter(hlo.ImagOp(operand).result, hlo.ImagOp(updates).result) + return hlo.ComplexOp(real, imag).results mlir.register_lowering(scatter_add_p, 
_scatter_add_lower_gpu, platform="gpu") diff --git a/jax/_src/lax/windowed_reductions.py b/jax/_src/lax/windowed_reductions.py index 6307fedb512f..3963cd5028cf 100644 --- a/jax/_src/lax/windowed_reductions.py +++ b/jax/_src/lax/windowed_reductions.py @@ -33,7 +33,7 @@ import jax._src.lax.convolution as convolution import jax._src.lax.slicing as slicing from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.ufuncs import logaddexp import jax._src.util as util @@ -316,7 +316,7 @@ def _generic_reduce_window_lower(ctx, *args, jaxpr, consts, operands, init_values = util.split_list(args, [len(args) // 2]) _, init_value_avals = util.split_list(ctx.avals_in, [len(operands)]) scalar_types = [mlir.aval_to_ir_type(aval) for aval in init_value_avals] - rw = mhlo.ReduceWindowOp( + rw = hlo.ReduceWindowOp( map(mlir.aval_to_ir_type, ctx.avals_out), operands, init_values, @@ -333,7 +333,7 @@ def _generic_reduce_window_lower(ctx, *args, jaxpr, consts, out_nodes, _ = mlir.jaxpr_subcomp(ctx.module_context, jaxpr, mlir.TokenSet(), consts, *([a] for a in reducer.arguments), dim_var_values=ctx.dim_var_values) - mhlo.ReturnOp(util.flatten(out_nodes)) + hlo.ReturnOp(util.flatten(out_nodes)) return rw.results mlir.register_lowering(reduce_window_p, _generic_reduce_window_lower) @@ -468,7 +468,7 @@ def _reduce_window_lower( operand_aval, = ctx.avals_in scalar_aval = operand_aval.update(shape=()) scalar_type = mlir.aval_to_ir_type(scalar_aval) - rw = mhlo.ReduceWindowOp( + rw = hlo.ReduceWindowOp( mlir.aval_to_ir_types(aval_out), [operand], [mlir.full_like_aval(ctx, init_value(scalar_aval.dtype), scalar_aval)], mlir.dense_int_elements(window_dimensions), @@ -479,15 +479,15 @@ def _reduce_window_lower( shape=(len(padding), 2))) reducer = rw.regions[0].blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(reducer): - mhlo.ReturnOp(reduce_op(*reducer.arguments)) + hlo.ReturnOp(reduce_op(*reducer.arguments)) return rw.results mlir.register_lowering(reduce_window_sum_p, partial( - _reduce_window_lower, mhlo.AddOp, lambda _: 0)) + _reduce_window_lower, hlo.AddOp, lambda _: 0)) mlir.register_lowering(reduce_window_min_p, partial( - _reduce_window_lower, mlir.min_mhlo, lax._get_min_identity)) + _reduce_window_lower, mlir.min_hlo, lax._get_min_identity)) mlir.register_lowering(reduce_window_max_p, partial( - _reduce_window_lower, mlir.max_mhlo, lax._get_max_identity)) + _reduce_window_lower, mlir.max_hlo, lax._get_max_identity)) @@ -514,7 +514,7 @@ def _select_and_scatter_lower( aval_out, = ctx.avals_out scalar_aval = operand_aval.update(shape=()) scalar_type = mlir.aval_to_ir_type(scalar_aval) - op = mhlo.SelectAndScatterOp( + op = hlo.SelectAndScatterOp( mlir.aval_to_ir_type(aval_out), operand, source, @@ -531,7 +531,7 @@ def _select_and_scatter_lower( mlir.TokenSet(), select_consts, *([a] for a in select.arguments), dim_var_values=ctx.dim_var_values) - mhlo.ReturnOp(util.flatten(out_nodes)) + hlo.ReturnOp(util.flatten(out_nodes)) scatter = op.scatter.blocks.append(scalar_type, scalar_type) with ir.InsertionPoint(scatter): if scatter_jaxpr.effects: @@ -540,7 +540,7 @@ def _select_and_scatter_lower( mlir.TokenSet(), scatter_consts, *([a] for a in scatter.arguments), dim_var_values=ctx.dim_var_values) - mhlo.ReturnOp(util.flatten(out_nodes)) + hlo.ReturnOp(util.flatten(out_nodes)) return op.results mlir.register_lowering(select_and_scatter_p, _select_and_scatter_lower) @@ -670,7 +670,7 @@ def _select_and_gather_add_lowering( 
canonicalize_types=False) def _broadcast(x, dims): - return mhlo.BroadcastOp(x, mlir.dense_int_elements(dims)) + return hlo.BroadcastOp(x, mlir.dense_int_elements(dims)) if double_word_reduction: # TODO(b/73062247): XLA doesn't yet implement ReduceWindow on tuples, so @@ -685,28 +685,28 @@ def _broadcast(x, dims): def pack(a, b): a_dims = ir.RankedTensorType(a.type).shape b_dims = ir.RankedTensorType(b.type).shape - a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a) - b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b) - a = mhlo.ConvertOp(ir.RankedTensorType.get(a_dims, double_word_type), a) - b = mhlo.ConvertOp(ir.RankedTensorType.get(b_dims, double_word_type), b) - a = mhlo.ShiftLeftOp(a, - _broadcast(const(double_word_dtype, nbits), a_dims)) - return mhlo.OrOp(a, b) + a = hlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a) + b = hlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b) + a = hlo.ConvertOp(ir.RankedTensorType.get(a_dims, double_word_type), a) + b = hlo.ConvertOp(ir.RankedTensorType.get(b_dims, double_word_type), b) + a = hlo.ShiftLeftOp(a, + _broadcast(const(double_word_dtype, nbits), a_dims)) + return hlo.OrOp(a, b) # Unpacks the first element of a tuple. def fst(t): dims = ir.RankedTensorType(t.type).shape - st = mhlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits)) - return mhlo.BitcastConvertOp( + st = hlo.ShiftRightLogicalOp(t, const(double_word_dtype, nbits)) + return hlo.BitcastConvertOp( ir.RankedTensorType.get(dims, etype), - mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), st)).result + hlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), st)).result # Unpacks the second element of a tuple. def snd(t): dims = ir.RankedTensorType(t.type).shape - return mhlo.BitcastConvertOp( + return hlo.BitcastConvertOp( ir.RankedTensorType.get(dims, etype), - mhlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), t)).result + hlo.ConvertOp(ir.RankedTensorType.get(dims, word_type), t)).result else: # The double-word trick above only works if we have a sufficiently large @@ -729,33 +729,33 @@ def snd(t): def pack(a, b): a_dims = ir.RankedTensorType(a.type).shape b_dims = ir.RankedTensorType(b.type).shape - a = mhlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp), - mantissa_bits=mlir.i32_attr(nmant)) - b = mhlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp), - mantissa_bits=mlir.i32_attr(nmant)) - a = mhlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a) - b = mhlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b) - b = mhlo.ShiftRightLogicalOp( + a = hlo.ReducePrecisionOp(a, exponent_bits=mlir.i32_attr(nexp), + mantissa_bits=mlir.i32_attr(nmant)) + b = hlo.ReducePrecisionOp(b, exponent_bits=mlir.i32_attr(nexp), + mantissa_bits=mlir.i32_attr(nmant)) + a = hlo.BitcastConvertOp(ir.RankedTensorType.get(a_dims, word_type), a) + b = hlo.BitcastConvertOp(ir.RankedTensorType.get(b_dims, word_type), b) + b = hlo.ShiftRightLogicalOp( b, _broadcast(const(word_dtype, r_nbits), b_dims)) - return mhlo.OrOp(a, b) + return hlo.OrOp(a, b) # Unpacks the first element of a tuple. def fst(t): - st = mhlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits)) - return mhlo.BitcastConvertOp(ir.RankedTensorType.get([], etype), - st).result + st = hlo.AndOp(t, const(word_dtype, ((1 << r_nbits) - 1) << r_nbits)) + return hlo.BitcastConvertOp(ir.RankedTensorType.get([], etype), + st).result # Unpacks the second element of a tuple. 
def snd(t): dims = ir.RankedTensorType(t.type).shape - return mhlo.BitcastConvertOp( + return hlo.BitcastConvertOp( ir.RankedTensorType.get(dims, etype), - mhlo.ShiftLeftOp(t, _broadcast(const(word_dtype, r_nbits), dims)) + hlo.ShiftLeftOp(t, _broadcast(const(word_dtype, r_nbits), dims)) ).result assert select_prim is lax.ge_p or select_prim is lax.le_p, select_prim init = -np.inf if select_prim is lax.ge_p else np.inf - rw = mhlo.ReduceWindowOp( + rw = hlo.ReduceWindowOp( [ir.RankedTensorType.get(out_aval.shape, double_word_type)], pack(operand, tangents), pack(const(dtype, init), const(dtype, 0)), @@ -771,8 +771,8 @@ def snd(t): x, y = reducer.arguments assert select_prim is lax.ge_p or select_prim is lax.le_p which = "GE" if select_prim is lax.ge_p else "LE" - out = mhlo.SelectOp(mlir.compare_mhlo(fst(x), fst(y), which), x, y) - mhlo.ReturnOp(out) + out = hlo.SelectOp(mlir.compare_hlo(fst(x), fst(y), which), x, y) + hlo.ReturnOp(out) return [snd(rw.result)] # TODO(phawkins): use this translation rule on all platforms. diff --git a/jax/_src/lib/mlir/dialects/__init__.py b/jax/_src/lib/mlir/dialects/__init__.py index cb5cd235af07..5e8297b79f5e 100644 --- a/jax/_src/lib/mlir/dialects/__init__.py +++ b/jax/_src/lib/mlir/dialects/__init__.py @@ -23,3 +23,8 @@ from jax.lib import xla_client if xla_client.mlir_api_version >= 37: import jaxlib.mlir.dialects.stablehlo as stablehlo + +# Alias that is set up to abstract away the transition from MHLO to StableHLO. +# At the moment, it points to MHLO, but in the future it will start to +# conditionally and then unconditionally point to StableHLO. +import jaxlib.mlir.dialects.mhlo as hlo diff --git a/jax/_src/lib/xla_bridge.py b/jax/_src/lib/xla_bridge.py index 86eca2812665..efbb23c569e2 100644 --- a/jax/_src/lib/xla_bridge.py +++ b/jax/_src/lib/xla_bridge.py @@ -280,7 +280,7 @@ def canonicalize_platform(platform: str) -> str: In particular, replaces "gpu" with either "cuda" or "rocm", depending on which hardware is actually present. We want to distinguish "cuda" and "rocm" for - purposes such as MHLO lowering rules, but in many cases we don't want to + purposes such as MLIR lowering rules, but in many cases we don't want to force users to care. 
""" platforms = _alias_to_platforms.get(platform, None) diff --git a/jax/_src/prng.py b/jax/_src/prng.py index 6823a2224db6..cdc6221bd6f4 100644 --- a/jax/_src/prng.py +++ b/jax/_src/prng.py @@ -40,7 +40,7 @@ from jax._src.api import jit, vmap from jax._src.lax import lax as lax_internal from jax._src.lax import utils as lax_utils -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.numpy import lax_numpy import jax._src.pretty_printer as pp from jax._src.util import canonicalize_axis, prod, safe_map, safe_zip @@ -443,7 +443,7 @@ def transpose_mlir(ctx, aval_out, x, *, permutation) -> mlir.ir.Value: key_shape = aval_out.dtype.impl.key_shape trailing_dims = [aval_out.ndim + i for i in range(len(key_shape))] perm = [*permutation, *trailing_dims] - return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).result + return hlo.TransposeOp(x, mlir.dense_int_elements(perm)).result @staticmethod def gather_mlir(ctx, avals_in, aval_out, x, indices, *, @@ -1041,27 +1041,27 @@ def bcast_iotas_to_reshaped_iota(add, mul, shape, iotas): def iota_2x32_shape_lowering(ctx, *, shape): def _add(x, y): - return mlir.mhlo.AddOp(x, y).result + return mlir.hlo.AddOp(x, y).result def _mul(x, y): x_const = mlir.ir_constant(np.array(x, np.dtype('uint64')), canonicalize_types=False) - x_bcast = mlir.mhlo.BroadcastOp(x_const, mlir.dense_int_elements(shape)) - return mlir.mhlo.MulOp(x_bcast, y).result + x_bcast = mlir.hlo.BroadcastOp(x_const, mlir.dense_int_elements(shape)) + return mlir.hlo.MulOp(x_bcast, y).result assert len(shape) > 0 aval_out, _ = ctx.avals_out aval_u64 = core.ShapedArray(shape, np.dtype('uint64')) - iotas = [mlir.mhlo.IotaOp(mlir.aval_to_ir_type(aval_u64), + iotas = [mlir.hlo.IotaOp(mlir.aval_to_ir_type(aval_u64), mlir.i64_attr(dimension)).result for dimension in range(len(shape))] counts = bcast_iotas_to_reshaped_iota(_add, _mul, shape, iotas) shift = mlir.ir_constant(np.array(32, np.dtype('uint64')), canonicalize_types=False) - shift = mlir.mhlo.BroadcastOp(shift, mlir.dense_int_elements(shape)).result - counts_shifted = mlir.mhlo.ShiftRightLogicalOp(counts, shift).result - counts_lo = mlir.mhlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts).result - counts_hi = mlir.mhlo.ConvertOp(mlir.aval_to_ir_type(aval_out), + shift = mlir.hlo.BroadcastOp(shift, mlir.dense_int_elements(shape)).result + counts_shifted = mlir.hlo.ShiftRightLogicalOp(counts, shift).result + counts_lo = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts).result + counts_hi = mlir.hlo.ConvertOp(mlir.aval_to_ir_type(aval_out), counts_shifted).result return counts_hi, counts_lo mlir.register_lowering(iota_2x32_shape_p, iota_2x32_shape_lowering) diff --git a/jax/experimental/custom_partitioning.py b/jax/experimental/custom_partitioning.py index 69dcd931d8d6..124583907c65 100644 --- a/jax/experimental/custom_partitioning.py +++ b/jax/experimental/custom_partitioning.py @@ -20,7 +20,7 @@ from jax import linear_util as lu from jax.experimental import pjit -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir import ir import jax.interpreters.pxla as pxla from jax.interpreters import mlir @@ -245,7 +245,7 @@ def to_mesh_pspec_sharding(op_sharding: xc.OpSharding): else: out_type = [ir.TupleType.get_tuple(mlir_shapes)] - out = mhlo.CustomCallOp( + out = hlo.CustomCallOp( out_type, list(values), call_target_name=ir.StringAttr.get(_CUSTOM_PARTITIONING_CALL_NAME), @@ -259,7 +259,7 @@ def to_mesh_pspec_sharding(op_sharding: 
xc.OpSharding): return [out.result] else: return [ - mhlo.GetTupleElementOp(out, mlir.i32_attr(i)).result + hlo.GetTupleElementOp(out, mlir.i32_attr(i)).result for i in range(len(mlir_shapes)) ] diff --git a/jax/experimental/host_callback.py b/jax/experimental/host_callback.py index 9bd46e6db6b0..4c1dde2d5847 100644 --- a/jax/experimental/host_callback.py +++ b/jax/experimental/host_callback.py @@ -524,7 +524,7 @@ def power3_with_cotangents(x): from jax._src.lib import xla_bridge as xb from jax._src.lib import xla_client from jax._src.lib import xla_extension -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo import numpy as np @@ -1143,14 +1143,14 @@ def _outside_call_lowering( assert has_token current_token = args[-2] current_itoken = args[-1] - assert current_token.type == mhlo.TokenType.get(), "The last two arguments must be tokens" - assert current_itoken.type == mhlo.TokenType.get(), "The last two arguments must be tokens" + assert current_token.type == hlo.TokenType.get(), "The last two arguments must be tokens" + assert current_itoken.type == hlo.TokenType.get(), "The last two arguments must be tokens" args_to_outfeed = args[:-2] # TODO(necula): this is a weak attempt to get the device. This works # inside pmap, but does not work when we just execute on a single device, # because in such executions we always get replica_id == 0. - replica_id = mhlo.ReplicaIdOp() + replica_id = hlo.ReplicaIdOp() callback_operands = [replica_id, *args_to_outfeed] callback_operand_avals = [ core.ShapedArray((), np.uint32), *ctx.avals_in[:-2]] diff --git a/jax/experimental/jax2tf/call_tf.py b/jax/experimental/jax2tf/call_tf.py index 280d98bdccfb..22c751b448e8 100644 --- a/jax/experimental/jax2tf/call_tf.py +++ b/jax/experimental/jax2tf/call_tf.py @@ -39,7 +39,7 @@ from jax.interpreters import xla from jax._src.lib.mlir import ir from jax._src.lib.mlir.dialects import func as func_dialect -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.lib import xla_client from jax.experimental.jax2tf import jax2tf as jax2tf_internal @@ -392,16 +392,16 @@ def code_gen(ctx: mlir.ModuleContext, args_op: Sequence[ir.Value] captured_ops = tuple(mlir.ir_constant(np.asarray(inp), canonicalize_types=False) for inp in captured_inputs) - submodule = mlir.xla_computation_to_mhlo_module(xla_comp) + submodule = mlir.xla_computation_to_mlir_module(xla_comp) symtab = ir.SymbolTable(submodule.operation) callee_result_types = symtab["main"].type.results - fn = mlir.merge_mhlo_modules(ctx.module, f"call_tf_{function_flat_tf.name}", + fn = mlir.merge_mlir_modules(ctx.module, f"call_tf_{function_flat_tf.name}", submodule) call = func_dialect.CallOp(callee_result_types, ir.FlatSymbolRefAttr.get(fn), tuple(args_op) + captured_ops) if result_shape.is_tuple(): - flat_results = [mhlo.GetTupleElementOp(call, mlir.i32_attr(i)).result + flat_results = [hlo.GetTupleElementOp(call, mlir.i32_attr(i)).result for i in range(len(result_shapes))] else: flat_results = call.results @@ -410,7 +410,7 @@ def code_gen(ctx: mlir.ModuleContext, args_op: Sequence[ir.Value] for op, res_aval, res_shape in zip(flat_results, result_avals, result_shapes): if res_aval.dtype != res_shape.numpy_dtype(): - op = mhlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result + op = hlo.ConvertOp(mlir.aval_to_ir_type(res_aval), op).result outputs.append(op) return outputs diff --git a/jax/experimental/jax2tf/jax2tf.py b/jax/experimental/jax2tf/jax2tf.py index 1323ad568481..8ec8b34f6f3f 
100644 --- a/jax/experimental/jax2tf/jax2tf.py +++ b/jax/experimental/jax2tf/jax2tf.py @@ -589,7 +589,7 @@ def _lower_native_and_run(fun_jax: Callable, Work-in-progress. - Uses JAX native lowering to MHLO, and then wraps the result in a + Uses JAX native lowering to MLIR, and then wraps the result in a XlaCallModule TF op. This op does not have backward-compatibility yet. Special care must be taken in presence of shape polymorphism. @@ -634,13 +634,13 @@ def _lower_native_and_run(fun_jax: Callable, fun_jax_lower = fun_jax.lower lowered = fun_jax_lower(*arg_specs_jax)._lowering if config.jax2tf_use_stablehlo: - mhlo_module = lowered.stablehlo() + mlir_module = lowered.stablehlo() xla_call_module_version = 2 else: - mhlo_module = lowered.mhlo() + mlir_module = lowered.mhlo() xla_call_module_version = 1 - mhlo_serialized_module = mlir.module_to_bytecode(mhlo_module) + mlir_serialized_module = mlir.module_to_bytecode(mlir_module) # Figure out the result types and shapes if "global_out_avals" in lowered.compile_args: # This is currently the case for pjit @@ -719,7 +719,7 @@ def _out_type(jax_type): args_tf = [atf for i, atf in enumerate(args_tf) if i in module_kept_var_idx] # Apply the shardings on arguments and results for pjit. This is redundant - # because the mhlo_module_text will already contain the shardings, but it + # because the mlir_module_text will already contain the shardings, but it # makes it easier for tools like the TPU inference converter to see the # sharding without digging into the `module` attribute of the `XlaCallModule` # op, in the same way as it is done for the legacy jax2tf conversion. @@ -728,14 +728,14 @@ def _out_type(jax_type): map(_shard_value, args_tf, args_avals, lowered.compile_args["in_shardings"])) if logging.vlog_is_on(3): - mhlo_module_text = mlir.module_to_string(mhlo_module) + mlir_module_text = mlir.module_to_string(mlir_module) logging.vlog(3, "XlaCallModule (version=%d, dim_args_spec=%s)\n%s", xla_call_module_version, ", ".join(dim_args_spec), - mhlo_module_text) + mlir_module_text) res = tfxla.call_module( args_tf, version=xla_call_module_version, - module=mhlo_serialized_module, + module=mlir_serialized_module, Tout=out_types, Sout=out_shapes, dim_args_spec=dim_args_spec) diff --git a/jax/experimental/jax2tf/tests/jax2tf_test.py b/jax/experimental/jax2tf/tests/jax2tf_test.py index b5f1525b5d37..48bfd0a57009 100644 --- a/jax/experimental/jax2tf/tests/jax2tf_test.py +++ b/jax/experimental/jax2tf/tests/jax2tf_test.py @@ -1479,7 +1479,7 @@ def f_tf(x1_tf, x2_tf): self.assertAllClose(tf.nest.map_structure(lambda t: t.numpy(), res), jax_res) - @unittest.skip("TODO(necula): 'mhlo.dynamic_iota' op can't be translated to XLA HLO") + @unittest.skip("TODO(necula): 'dynamic_iota' op can't be translated to XLA HLO") def test_shape_poly_arange(self): if not config.jax_dynamic_shapes: raise unittest.SkipTest("jax_dynamic_shapes must be enabled") diff --git a/jax/experimental/pjit.py b/jax/experimental/pjit.py index 260d45ac7f6c..799e3fa2db27 100644 --- a/jax/experimental/pjit.py +++ b/jax/experimental/pjit.py @@ -1176,7 +1176,7 @@ def _pjit_lowering(ctx, *args, name, jaxpr, in_shardings, name_stack=xla.extend_name_stack(ctx.module_context.name_stack, wrap_name(name, "pjit"))) # TODO(b/228598865): inlined calls cannot have shardings set directly on the - # inputs or outputs because they are lost during MHLO->HLO conversion. + # inputs or outputs because they are lost during MLIR->HLO conversion. 
# using_sharding_annotation=False means we add an identity operation instead. func = mlir.lower_jaxpr_to_fun(sub_ctx, f"pjit_{name}", jaxpr, (), arg_shardings=arg_shardings, @@ -1544,8 +1544,8 @@ def _sharding_constraint_impl(x, sharding, resource_env, unconstrained_dims): ad.deflinear2(sharding_constraint_p, lambda ct, _, **params: (sharding_constraint_p.bind(ct, **params),)) -def _sharding_constraint_mhlo_lowering(ctx, x_node, *, sharding, - resource_env, unconstrained_dims): +def _sharding_constraint_hlo_lowering(ctx, x_node, *, sharding, + resource_env, unconstrained_dims): aval, = ctx.avals_in axis_ctx = ctx.module_context.axis_context # axis_ctx and manual_axes is *only used with xmap* and xmap only works with @@ -1564,7 +1564,7 @@ def _sharding_constraint_mhlo_lowering(ctx, x_node, *, sharding, unspecified_dims=unconstrained_dims) ] mlir.register_lowering(sharding_constraint_p, - _sharding_constraint_mhlo_lowering) + _sharding_constraint_hlo_lowering) def _sharding_constraint_batcher(insert_axis, spmd_axis_name, axis_size, diff --git a/jax/experimental/sparse/bcoo.py b/jax/experimental/sparse/bcoo.py index 029ef75d1171..13dd4c68a9ad 100644 --- a/jax/experimental/sparse/bcoo.py +++ b/jax/experimental/sparse/bcoo.py @@ -48,7 +48,7 @@ from jax._src.lib.mlir import ir from jax._src.lib import xla_bridge from jax._src.lib import gpu_sparse -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.numpy.setops import _unique from jax._src.typing import Array, ArrayLike, DType, DTypeLike from jax._src.util import canonicalize_axis @@ -728,13 +728,13 @@ def _bcoo_dot_general_abstract_eval(lhs_data, lhs_indices, rhs, *, dimension_num _bcoo_dot_general_default_lowering = mlir.lower_fun( _bcoo_dot_general_impl, multiple_results=False) -def _collapse_mhlo(x, start, end): +def _collapse_hlo(x, start, end): x_type = ir.RankedTensorType(x.type) shape = x_type.shape shape = (shape[:start] + [functools.reduce(operator.mul, shape[start:end + 1])] + shape[end + 1:]) - return mhlo.ReshapeOp( + return hlo.ReshapeOp( ir.RankedTensorType.get(shape, x_type.element_type), x).result def _bcoo_dot_general_cuda_lowering( @@ -766,7 +766,7 @@ def _bcoo_dot_general_cuda_lowering( elif rhs_ndim == 2: bcoo_dot_general_fn = coo_matmat_lowering if rhs_contract[0] == 1: - rhs = mhlo.TransposeOp( + rhs = hlo.TransposeOp( rhs, permutation=mlir.dense_int_elements([1, 0])).result else: raise ValueError(f"rhs has to be 1d or 2d; get {rhs_ndim}d.") @@ -776,7 +776,7 @@ def _bcoo_dot_general_cuda_lowering( lhs_transpose = False if props.n_sparse == 1: # Converts lhs to a row vector. - col = _collapse_mhlo(lhs_indices, start=0, end=1) + col = _collapse_hlo(lhs_indices, start=0, end=1) row = mlir.full_like_aval( ctx, 0, core.ShapedArray(ir.RankedTensorType(col.type).shape, np.dtype(np.int32))) @@ -788,23 +788,23 @@ def _bcoo_dot_general_cuda_lowering( if rhs_ndim == 1: # Transforms a single-element array to a scalar. 
- return [mhlo.ReshapeOp( + return [hlo.ReshapeOp( ir.RankedTensorType.get( [], ir.RankedTensorType(dot_product.type).element_type), dot_product).result] else: - return [_collapse_mhlo(dot_product, start=0, end=1)] + return [_collapse_hlo(dot_product, start=0, end=1)] elif props.n_sparse == 2: lhs_indices_shape = ir.RankedTensorType(lhs_indices.type).shape - row = _collapse_mhlo( - mhlo.SliceOp( + row = _collapse_hlo( + hlo.SliceOp( lhs_indices, start_indices=mlir.dense_int_elements([0, 0]), limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 1]), strides=mlir.dense_int_elements([1, 1])).result, start=0, end=1) - col = _collapse_mhlo( - mhlo.SliceOp( + col = _collapse_hlo( + hlo.SliceOp( lhs_indices, start_indices=mlir.dense_int_elements([0, 1]), limit_indices=mlir.dense_int_elements([lhs_indices_shape[0], 2]), @@ -833,28 +833,28 @@ def _bcoo_dot_general_cuda_lowering( lhs_indices_shape[-1]) lhs_data_1d_shape = (np.prod(np.array(lhs_data_shape)), ) - lhs_indices_2d = mhlo.ReshapeOp( + lhs_indices_2d = hlo.ReshapeOp( ir.RankedTensorType.get( lhs_indices_2d_shape, ir.RankedTensorType(lhs_indices.type).element_type), lhs_indices).result - lhs_data_1d = mhlo.ReshapeOp( + lhs_data_1d = hlo.ReshapeOp( ir.RankedTensorType.get( lhs_data_1d_shape, ir.RankedTensorType(lhs_data.type).element_type), lhs_data).result - row = _collapse_mhlo( - mhlo.SliceOp( + row = _collapse_hlo( + hlo.SliceOp( lhs_indices_2d, start_indices=mlir.dense_int_elements([0, 0]), limit_indices=mlir.dense_int_elements([lhs_indices_2d_shape[0], 1]), strides=mlir.dense_int_elements([1, 1])).result, start=0, end=1) - col = _collapse_mhlo( - mhlo.SliceOp( + col = _collapse_hlo( + hlo.SliceOp( lhs_indices_2d, start_indices=mlir.dense_int_elements([0, 1]), limit_indices=mlir.dense_int_elements([lhs_indices_2d_shape[0], 2]), @@ -867,13 +867,13 @@ def _bcoo_dot_general_cuda_lowering( # The issue (https://github.com/NVIDIA/CUDALibrarySamples/issues/81#issuecomment-1205562643) # in cusparse library does not allow batch_stride = 0 for a non-batched rhs. 
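    # For example, a rhs of shape (k, n) is broadcast to (batch_count, k, n)
    # below and then flattened to (batch_count * k, n) before the cusparse call.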
batched_rhs_shape = (batch_count,) + tuple(rhs_shape) - batched_rhs = mhlo.BroadcastInDimOp( + batched_rhs = hlo.BroadcastInDimOp( ir.RankedTensorType.get(batched_rhs_shape, ir.RankedTensorType(rhs.type).element_type), rhs, broadcast_dimensions=mlir.dense_int_elements([1, 2])).result batched_rhs_2d_shape = (np.prod(np.array(batched_rhs_shape)[:-1]), batched_rhs_shape[-1]) - batched_rhs_2d = mhlo.ReshapeOp( + batched_rhs_2d = hlo.ReshapeOp( ir.RankedTensorType.get( batched_rhs_2d_shape, ir.RankedTensorType(batched_rhs.type).element_type), @@ -1404,12 +1404,12 @@ def _bcoo_sort_indices_jvp(primals, tangents, *, spinfo): data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot, perm) return (data_out, indices_out), (data_dot_out, indices_dot_out) -_bcoo_sort_indices_mhlo = mlir.lower_fun( +_bcoo_sort_indices_hlo = mlir.lower_fun( _bcoo_sort_indices_impl, multiple_results=True) ad.primitive_jvps[bcoo_sort_indices_p] = _bcoo_sort_indices_jvp batching.primitive_batchers[bcoo_sort_indices_p] = _bcoo_sort_indices_batching_rule -mlir.register_lowering(bcoo_sort_indices_p, _bcoo_sort_indices_mhlo) +mlir.register_lowering(bcoo_sort_indices_p, _bcoo_sort_indices_hlo) #---------------------------------------------------------------------- @@ -1560,12 +1560,12 @@ def _bcoo_sum_duplicates_jvp(primals, tangents, *, spinfo, nse): data_dot_out = ad.Zero.from_value(data_out) if type(data_dot) is ad.Zero else permute(data_dot_out, mapping, data_dot) return (data_out, indices_out), (data_dot_out, indices_dot_out) -_bcoo_sum_duplicates_mhlo = mlir.lower_fun( +_bcoo_sum_duplicates_hlo = mlir.lower_fun( _bcoo_sum_duplicates_impl, multiple_results=True) ad.primitive_jvps[bcoo_sum_duplicates_p] = _bcoo_sum_duplicates_jvp batching.primitive_batchers[bcoo_sum_duplicates_p] = _bcoo_sum_duplicates_batching_rule -mlir.register_lowering(bcoo_sum_duplicates_p, _bcoo_sum_duplicates_mhlo) +mlir.register_lowering(bcoo_sum_duplicates_p, _bcoo_sum_duplicates_hlo) #---------------------------------------------------------------------- # BCOO functions that maybe should be primitives? 
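# --- Illustrative sketch, not part of the patch: a NumPy analogue of the
# `_collapse_hlo` helper renamed above. Collapsing dimensions `start..end`
# is simply a reshape whose new dimension is the product of the collapsed
# ones; the HLO version computes the same target shape and emits a single
# ReshapeOp. The name `collapse` below is mine, chosen for the example only.
import functools
import operator
import numpy as np

def collapse(x, start, end):
  shape = list(x.shape)
  shape = (shape[:start]
           + [functools.reduce(operator.mul, shape[start:end + 1])]
           + shape[end + 1:])
  return x.reshape(shape)

# e.g. collapse(np.zeros((2, 3, 4, 5)), start=1, end=2).shape == (2, 12, 5)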
diff --git a/jax/experimental/sparse/coo.py b/jax/experimental/sparse/coo.py index 6778cd1f9454..70d04d583f6b 100644 --- a/jax/experimental/sparse/coo.py +++ b/jax/experimental/sparse/coo.py @@ -30,7 +30,7 @@ from jax.experimental.sparse.util import _coo_extract, CuSparseEfficiencyWarning from jax import tree_util from jax._src.lax.lax import _const -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.lib import gpu_sparse from jax._src.numpy.lax_numpy import _promote_dtypes from jax._src.typing import Array, ArrayLike, DTypeLike @@ -199,7 +199,7 @@ def _coo_todense_abstract_eval(data, row, col, *, spinfo): _coo_todense_lowering = mlir.lower_fun( _coo_todense_impl, multiple_results=False) -def _coo_todense_gpu_lowering(coo_todense_mhlo, ctx, data, row, col, *, spinfo): +def _coo_todense_gpu_lowering(coo_todense_hlo, ctx, data, row, col, *, spinfo): data_aval, row_aval, _ = ctx.avals_in dtype = data_aval.dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): @@ -220,10 +220,10 @@ def _coo_todense_gpu_lowering(coo_todense_mhlo, ctx, data, row, col, *, spinfo): "back to the default implementation.", CuSparseEfficiencyWarning) return _coo_todense_lowering(ctx, data, row, col, spinfo=spinfo) - result = coo_todense_mhlo( + result = coo_todense_hlo( data, row, col, shape=shape, data_dtype=dtype, index_dtype=row_aval.dtype) return ( - [mhlo.TransposeOp(result, mlir.dense_int_elements([1, 0])).result] + [hlo.TransposeOp(result, mlir.dense_int_elements([1, 0])).result] if transpose else [result]) @@ -318,14 +318,14 @@ def _coo_fromdense_abstract_eval(mat, *, nse, index_dtype): _coo_fromdense_lowering = mlir.lower_fun( _coo_fromdense_impl, multiple_results=True) -def _coo_fromdense_gpu_lowering(coo_fromdense_mhlo, ctx, mat, *, nse, +def _coo_fromdense_gpu_lowering(coo_fromdense_hlo, ctx, mat, *, nse, index_dtype): dtype = ctx.avals_in[0].dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"coo_fromdense cusparse/hipsparse lowering not available for {dtype=}. 
" "Falling back to default implementation.", CuSparseEfficiencyWarning) return _coo_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype) - data, row, col = coo_fromdense_mhlo( + data, row, col = coo_fromdense_hlo( mat, nnz=nse, data_dtype=dtype, index_dtype=np.dtype(index_dtype), @@ -438,7 +438,7 @@ def _coo_matvec_abstract_eval(data, row, col, v, *, spinfo, transpose): _coo_matvec_lowering = mlir.lower_fun( _coo_matvec_impl, multiple_results=False) -def _coo_matvec_gpu_lowering(coo_matvec_mhlo, ctx, data, row, col, v, *, spinfo, +def _coo_matvec_gpu_lowering(coo_matvec_hlo, ctx, data, row, col, v, *, spinfo, transpose): data_aval, row_aval, _, x_aval = ctx.avals_in dtype = data_aval.dtype @@ -461,7 +461,7 @@ def _coo_matvec_gpu_lowering(coo_matvec_mhlo, ctx, data, row, col, v, *, spinfo, return _coo_matvec_lowering(ctx, data, row, col, v, spinfo=spinfo, transpose=transpose) - return [coo_matvec_mhlo( + return [coo_matvec_hlo( data, row, col, v, shape=shape, transpose=transpose, index_dtype=row_aval.dtype, data_dtype=dtype, x_dtype=x_aval.dtype)] @@ -561,7 +561,7 @@ def _coo_matmat_abstract_eval(data, row, col, B, *, spinfo, transpose): _coo_matmat_lowering = mlir.lower_fun(_coo_matmat_impl, multiple_results=False) -def _coo_matmat_gpu_lowering(coo_matmat_mhlo, ctx, data, row, col, B, *, spinfo, +def _coo_matmat_gpu_lowering(coo_matmat_hlo, ctx, data, row, col, B, *, spinfo, transpose): data_aval, row_aval, _, B_aval = ctx.avals_in dtype = data_aval.dtype @@ -583,7 +583,7 @@ def _coo_matmat_gpu_lowering(coo_matmat_mhlo, ctx, data, row, col, B, *, spinfo, return _coo_matmat_lowering(ctx, data, row, col, B, spinfo=spinfo, transpose=transpose) - return [coo_matmat_mhlo(data, row, col, B, shape=shape, + return [coo_matmat_hlo(data, row, col, B, shape=shape, transpose=transpose, x_dtype=B_aval.dtype, data_dtype=data_aval.dtype, index_dtype=row_aval.dtype)] diff --git a/jax/experimental/sparse/csr.py b/jax/experimental/sparse/csr.py index 3ea81b5532bd..3200c62e868e 100644 --- a/jax/experimental/sparse/csr.py +++ b/jax/experimental/sparse/csr.py @@ -227,7 +227,7 @@ def _csr_todense_abstract_eval(data, indices, indptr, *, shape): _csr_todense_lowering = mlir.lower_fun( _csr_todense_impl, multiple_results=False) -def _csr_todense_gpu_lowering(csr_todense_mhlo, ctx, data, indices, indptr, *, +def _csr_todense_gpu_lowering(csr_todense_hlo, ctx, data, indices, indptr, *, shape): data_aval, indices_aval, _ = ctx.avals_in dtype = data_aval.dtype @@ -235,7 +235,7 @@ def _csr_todense_gpu_lowering(csr_todense_mhlo, ctx, data, indices, indptr, *, warnings.warn(f"csr_todense cusparse/hipsparse lowering not available for {dtype=}. " "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_todense_lowering(ctx, data, indices, indptr, shape=shape) - return [csr_todense_mhlo( + return [csr_todense_hlo( data, indices, indptr, shape=shape, data_dtype=dtype, index_dtype=indices_aval.dtype)] @@ -319,13 +319,13 @@ def _csr_fromdense_abstract_eval(mat, *, nse, index_dtype): _csr_fromdense_lowering = mlir.lower_fun(_csr_fromdense_impl, multiple_results=True) -def _csr_fromdense_gpu_lowering(csr_fromdense_mhlo, ctx, mat, *, nse, index_dtype): +def _csr_fromdense_gpu_lowering(csr_fromdense_hlo, ctx, mat, *, nse, index_dtype): dtype = ctx.avals_in[0].dtype if not (np.issubdtype(dtype, np.floating) or np.issubdtype(dtype, np.complexfloating)): warnings.warn(f"csr_fromdense cusparse/hipsparse lowering not available for {dtype=}. 
" "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_fromdense_lowering(ctx, mat, nse=nse, index_dtype=index_dtype) - data, indices, indptr = csr_fromdense_mhlo( + data, indices, indptr = csr_fromdense_hlo( mat, nnz=nse, index_dtype=np.dtype(index_dtype), data_dtype=dtype, index_type=mlir.dtype_to_ir_type(np.dtype(index_dtype))) return [data, indices, indptr] @@ -412,7 +412,7 @@ def _csr_matvec_abstract_eval(data, indices, indptr, v, *, shape, transpose): _csr_matvec_lowering = mlir.lower_fun(_csr_matvec_impl, multiple_results=False) -def _csr_matvec_gpu_lowering(csr_matvec_mhlo, ctx, data, indices, indptr, v, *, +def _csr_matvec_gpu_lowering(csr_matvec_hlo, ctx, data, indices, indptr, v, *, shape, transpose): data_aval, indices_aval, _, v_aval = ctx.avals_in dtype = data_aval.dtype @@ -421,7 +421,7 @@ def _csr_matvec_gpu_lowering(csr_matvec_mhlo, ctx, data, indices, indptr, v, *, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_matvec_lowering(ctx, data, indices, indptr, v, shape=shape, transpose=transpose) - return [csr_matvec_mhlo( + return [csr_matvec_hlo( data, indices, indptr, v, shape=shape, transpose=transpose, data_dtype=dtype, index_dtype=indices_aval.dtype, x_dtype=v_aval.dtype)] @@ -504,7 +504,7 @@ def _csr_matmat_abstract_eval(data, indices, indptr, B, *, shape, transpose): _csr_matmat_lowering = mlir.lower_fun(_csr_matmat_impl, multiple_results=False) -def _csr_matmat_gpu_lowering(csr_matmat_mhlo, ctx, data, indices, indptr, B, *, +def _csr_matmat_gpu_lowering(csr_matmat_hlo, ctx, data, indices, indptr, B, *, shape, transpose): data_aval, indices_aval, _, B_aval = ctx.avals_in dtype = data_aval.dtype @@ -513,7 +513,7 @@ def _csr_matmat_gpu_lowering(csr_matmat_mhlo, ctx, data, indices, indptr, B, *, "Falling back to default implementation.", CuSparseEfficiencyWarning) return _csr_matmat_lowering(ctx, data, indices, indptr, B, shape=shape, transpose=transpose) - return [csr_matmat_mhlo( + return [csr_matmat_hlo( data, indices, indptr, B, shape=shape, transpose=transpose, index_dtype=indices_aval.dtype, data_dtype=data_aval.dtype, B_dtype=B_aval.dtype)] diff --git a/jax/interpreters/mlir.py b/jax/interpreters/mlir.py index a9992609641d..2a5551e842ff 100644 --- a/jax/interpreters/mlir.py +++ b/jax/interpreters/mlir.py @@ -12,8 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Lowering and execution path that converts jaxprs into the MLIR MHLO/CHLO -# dialects. +# Lowering and execution path that converts jaxprs into MLIR. 
from __future__ import annotations import collections @@ -35,8 +34,7 @@ from jax._src import dtypes from jax._src.lib import mlir_api_version, xla_extension_version from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import chlo -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.lib.mlir.dialects import func as func_dialect from jax._src.lib import can_execute_with_token from jax._src.lib import xla_bridge as xb @@ -86,12 +84,12 @@ def lower_dim(d): if type(d) is int: return ir_constant(np.array([d], np.int32)) else: - return mhlo.ReshapeOp(int1d, mhlo.ConvertOp(aval_to_ir_type(core.ShapedArray((), np.int32)), d)) + return hlo.ReshapeOp(int1d, hlo.ConvertOp(aval_to_ir_type(core.ShapedArray((), np.int32)), d)) d, *ds = map(lower_dim, sizes) if not ds: return d else: - return mhlo.ConcatenateOp([d, *ds], i64_attr(0)).result + return hlo.ConcatenateOp([d, *ds], i64_attr(0)).result def delegate_lowering(ctx, lowering_fun, *args, **ctx_override_kwargs): @@ -162,7 +160,7 @@ def aval_to_ir_types(aval: core.AbstractValue) -> Sequence[ir.Type]: ir_type_handlers[core.ShapedArray] = _array_ir_types ir_type_handlers[core.ConcreteArray] = _array_ir_types -ir_type_handlers[core.AbstractToken] = lambda _: [mhlo.TokenType.get()] +ir_type_handlers[core.AbstractToken] = lambda _: [hlo.TokenType.get()] ir_type_handlers[core.DShapedArray] = _dynamic_array_ir_types def aval_to_ir_type(aval: core.AbstractValue) -> ir.Type: @@ -239,7 +237,7 @@ def _numpy_array_constant(x: np.ndarray, canonicalize_types x = x.view(np.uint16) x = np.ascontiguousarray(x) attr = ir.DenseElementsAttr.get(x, type=element_type, shape=shape) - return (mhlo.ConstantOp(attr).result,) + return (hlo.ConstantOp(attr).result,) @@ -272,7 +270,7 @@ def _ndarray_constant_handler(val: np.ndarray, canonicalize_types if canonicalize_types: collapsed_val = np.asarray( collapsed_val, dtypes.canonicalize_dtype(collapsed_val.dtype)) - out = mhlo.BroadcastInDimOp( + out = hlo.BroadcastInDimOp( ir.RankedTensorType.get( val.shape, dtype_to_ir_type(collapsed_val.dtype)), _numpy_array_constant(collapsed_val, canonicalize_types=False)[0], @@ -304,9 +302,9 @@ def _device_array_constant_handler(val, canonicalize_types): def _token_constant_handler(val, canonicalize_types): if mlir_api_version < 40: - return [mhlo.CreateTokenOp(mhlo.TokenType.get()).result] + return [hlo.CreateTokenOp(hlo.TokenType.get()).result] else: - return [mhlo.CreateTokenOp().result] + return [hlo.CreateTokenOp().result] register_constant_handler(core.Token, _token_constant_handler) # Source locations @@ -331,12 +329,12 @@ def _source_info_to_location( # Translation rules def make_ir_context() -> ir.Context: """Creates an MLIR context suitable for JAX IR.""" + from jax._src.lib.mlir import dialects context = ir.Context() - mhlo.register_mhlo_dialect(context) - chlo.register_dialect(context) + dialects.mhlo.register_mhlo_dialect(context) + dialects.chlo.register_dialect(context) if mlir_api_version >= 37: - from jax._src.lib.mlir.dialects import stablehlo - stablehlo.register_dialect(context) + dialects.stablehlo.register_dialect(context) return context @@ -581,18 +579,18 @@ def __init__(self, value: ir.Value): def __add__(self, other: Union[np.int32, DimPolyEvaluator]): if not isinstance(other, DimPolyEvaluator): other = DimPolyEvaluator(ir_constant(other)) - return DimPolyEvaluator(mhlo.AddOp(self.value, other.value).result) + return DimPolyEvaluator(hlo.AddOp(self.value, other.value).result) def __radd__(self, other: 
np.int32): - return DimPolyEvaluator(mhlo.AddOp(ir_constant(other), self.value).result) + return DimPolyEvaluator(hlo.AddOp(ir_constant(other), self.value).result) def __mul__(self, other: Union[np.int32, DimPolyEvaluator]): if not isinstance(other, DimPolyEvaluator): other = DimPolyEvaluator(ir_constant(other)) - return DimPolyEvaluator(mhlo.MulOp(self.value, other.value).result) + return DimPolyEvaluator(hlo.MulOp(self.value, other.value).result) def __rmul__(self, other: np.int32): - return DimPolyEvaluator(mhlo.MulOp(ir_constant(other), self.value).result) + return DimPolyEvaluator(hlo.MulOp(ir_constant(other), self.value).result) def eval_dynamic_shape(ctx: LoweringRuleContext, @@ -640,7 +638,7 @@ def lower_jaxpr_to_module( arg_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None, result_shardings: Optional[Sequence[Optional[xc.OpSharding]]] = None ) -> LoweringResult: - """Lowers a top-level jaxpr to an MHLO module. + """Lowers a top-level jaxpr to an MLIR module. Handles the quirks of the argument/return value passing conventions of the runtime. @@ -678,7 +676,7 @@ def lower_jaxpr_to_module( msg = f"Donation is not implemented for {platform}.\n{msg}" warnings.warn(f"Some donated buffers were not usable: {', '.join(unused_donations)}.\n{msg}") - # MHLO channels need to start at 1 + # HLO channels need to start at 1 channel_iter = itertools.count(1) # Create a keepalives list that will be mutated during the lowering. keepalives: List[Any] = [] @@ -761,20 +759,20 @@ def _set_up_aliases(avals_in, avals_out, donated_args): Token = Sequence[ir.Value] def token_type() -> Sequence[ir.Type]: - return [mhlo.TokenType.get()] + return [hlo.TokenType.get()] def create_token() -> Token: if mlir_api_version < 40: return wrap_singleton_ir_values( - mhlo.CreateTokenOp(mhlo.TokenType.get()).result) + hlo.CreateTokenOp(hlo.TokenType.get()).result) else: - return wrap_singleton_ir_values(mhlo.CreateTokenOp().result) + return wrap_singleton_ir_values(hlo.CreateTokenOp().result) class TokenSet: """An immutable container of tokens to be used to lower effectful jaxprs. When lowering - effectful jaxprs, we need to thread MHLO tokens to sequence them. Each effect + effectful jaxprs, we need to thread HLO tokens to sequence them. Each effect will need its own token that will be threaded in and out of the effectful - primitives. A `TokenSet` encapsulates a set of MHLO tokens that will be + primitives. A `TokenSet` encapsulates a set of HLO tokens that will be used by the lowering rules. """ _tokens: typing.OrderedDict[core.Effect, Token] @@ -850,18 +848,18 @@ def lower_jaxpr_to_fun( jaxpr: the jaxpr to lower. effects: a sequence of `core.Effect`s corresponding to an ordering of tokens that will be created in or used by the lowered function. - create_tokens: if true, the MHLO will create tokens and ignore dummy input tokens. + create_tokens: if true, the HLO will create tokens and ignore dummy input tokens. public: if true, the function's visibility is set to "public". replace_tokens_with_dummy: if true, token arguments/return values are replaced with bool arrays of size [0]. replicated_args: if present, annotates arguments as replicated. arg_shardings: sharding annotations for each argument (optional). result_shardings: sharding annotations for each argument (optional). - use_sharding_annotations: if True, use mhlo.sharding annotations on + use_sharding_annotations: if True, use "mhlo.sharding" annotations on parameters and return values to express sharding. 
If False, use - mhlo.custom_call operators with sharding annotations. - TODO(b/228598865): remove this option when mhlo.sharding annotations are - propagated on non-entry functions during MHLO->HLO conversion. + hlo.custom_call operators with sharding annotations. + TODO(b/228598865): remove this option when "mhlo.sharding" annotations are + propagated on non-entry functions during MLIR->HLO conversion. input_output_aliases: optional sequence that maps argument numbers to the corresponding output that should alias them. Returns the name of the function. @@ -988,9 +986,9 @@ def aval_to_types(aval): for aval, arg in zip(jaxpr.in_avals, unflattened_args): if replace_tokens_with_dummy and aval is core.abstract_token: if mlir_api_version < 40: - args.append(mhlo.CreateTokenOp(mhlo.TokenType.get()).results) + args.append(hlo.CreateTokenOp(hlo.TokenType.get()).results) else: - args.append(mhlo.CreateTokenOp().results) + args.append(hlo.CreateTokenOp().results) else: args.append(arg) callee_name_stack = xla.extend_name_stack(ctx.name_stack, @@ -1061,7 +1059,7 @@ def jaxpr_subcomp(ctx: ModuleContext, jaxpr: core.Jaxpr, *args: Sequence[ir.Value], dim_var_values: Sequence[ir.Value] ) -> Tuple[Sequence[Sequence[ir.Value]], TokenSet]: - """Lowers a jaxpr into mHLO, inlined into an existing function. + """Lowers a jaxpr into MLIR, inlined into an existing function. Assumes that an MLIR context, location, and insertion point are set. @@ -1269,13 +1267,13 @@ def broadcast_in_dim(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue, broadcast_dimensions=broadcast_dimensions) if not core.is_constant_shape(aval_out.shape): # type: ignore shape = eval_dynamic_shape(ctx, aval_out.shape) # type: ignore - return mhlo.DynamicBroadcastInDimOp( + return hlo.DynamicBroadcastInDimOp( aval_to_ir_type(aval_out), op, shape_tensor(shape), dense_int_elements(broadcast_dimensions), ).result else: - return mhlo.BroadcastInDimOp( + return hlo.BroadcastInDimOp( aval_to_ir_type(aval_out), op, dense_int_elements(broadcast_dimensions)).result @@ -1302,12 +1300,12 @@ def reshape(ctx: LoweringRuleContext, op, aval_out: core.AbstractValue) -> ir.Va aval_out, = aval_out.dtype._rules.physical_avals(aval_out) # type: ignore if not core.is_constant_shape(aval_out.shape): # type: ignore shape = eval_dynamic_shape(ctx, aval_out.shape) # type: ignore - return mhlo.DynamicReshapeOp( + return hlo.DynamicReshapeOp( aval_to_ir_type(aval_out), op, shape_tensor(shape), ).result else: - return mhlo.ReshapeOp(aval_to_ir_type(aval_out), op).result + return hlo.ReshapeOp(aval_to_ir_type(aval_out), op).result def slice_op(ctx: LoweringRuleContext, x, aval_out, *, start_indices, limit_indices, strides) -> ir.Value: @@ -1319,16 +1317,16 @@ def slice_op(ctx: LoweringRuleContext, x, aval_out, *, start_indices = eval_dynamic_shape(ctx, start_indices) limit_indices = eval_dynamic_shape(ctx, limit_indices) strides = eval_dynamic_shape(ctx, strides) - return mhlo.RealDynamicSliceOp(aval_to_ir_type(aval_out), - x, - shape_tensor(start_indices), - shape_tensor(limit_indices), - shape_tensor(strides)).result + return hlo.RealDynamicSliceOp(aval_to_ir_type(aval_out), + x, + shape_tensor(start_indices), + shape_tensor(limit_indices), + shape_tensor(strides)).result else: - return mhlo.SliceOp(x, - dense_int_elements(start_indices), - dense_int_elements(limit_indices), - dense_int_elements(strides)).result + return hlo.SliceOp(x, + dense_int_elements(start_indices), + dense_int_elements(limit_indices), + dense_int_elements(strides)).result def 
dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *, start_indices) -> ir.Value: @@ -1338,15 +1336,15 @@ def dynamic_slice(ctx: LoweringRuleContext, aval_out, x, *, slice_sizes = aval_out.shape if not core.is_constant_shape(slice_sizes): slice_sizes = eval_dynamic_shape(ctx, slice_sizes) - return mhlo.RealDynamicSliceOp( + return hlo.RealDynamicSliceOp( aval_to_ir_type(aval_out), x, shape_tensor(start_indices), shape_tensor(slice_sizes), shape_tensor([1] * len(slice_sizes)) ).result else: - return mhlo.DynamicSliceOp(x, start_indices, - dense_int_elements(slice_sizes)).result + return hlo.DynamicSliceOp(x, start_indices, + dense_int_elements(slice_sizes)).result def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *, start_indices) -> ir.Value: @@ -1356,10 +1354,10 @@ def dynamic_update_slice(ctx: LoweringRuleContext, aval_out, x, update, *, # TODO(necula): handle dynamic shapes if mlir_api_version < 40: - return mhlo.DynamicUpdateSliceOp( + return hlo.DynamicUpdateSliceOp( aval_to_ir_type(aval_out), x, update, start_indices).result else: - return mhlo.DynamicUpdateSliceOp(x, update, start_indices).result + return hlo.DynamicUpdateSliceOp(x, update, start_indices).result def full_like_aval(ctx: LoweringRuleContext, value, aval: core.ShapedArray) -> ir.Value: """Returns an IR constant shaped full of `value` shaped like `aval`.""" @@ -1373,14 +1371,14 @@ def zeros_like_lowering(ctx, x): register_lowering(ad_util.zeros_like_p, zeros_like_lowering) def add_jaxvals_lowering(ctx, x, y): - return mhlo.AddOp(x, y).results + return hlo.AddOp(x, y).results register_lowering(ad_util.add_jaxvals_p, add_jaxvals_lowering) register_lowering(ad_util.stop_gradient_p, lambda ctx, x: [x]) -def compare_mhlo(x, y, direction: str, comparison_type: Optional[str] = None): - """Creates mhlo.CompareOp.""" +def compare_hlo(x, y, direction: str, comparison_type: Optional[str] = None): + """Creates CompareOp.""" if comparison_type is None: elem_type = ir.RankedTensorType(x.type).element_type if ir.IntegerType.isinstance(elem_type): @@ -1389,34 +1387,34 @@ def compare_mhlo(x, y, direction: str, comparison_type: Optional[str] = None): else: comparison_type = "FLOAT" - return mhlo.CompareOp( + return hlo.CompareOp( x, y, - mhlo.ComparisonDirectionAttr.get(direction), - compare_type=mhlo.ComparisonTypeAttr.get(comparison_type)) + hlo.ComparisonDirectionAttr.get(direction), + compare_type=hlo.ComparisonTypeAttr.get(comparison_type)) -def _minmax_mhlo(op, cmp, x, y): +def _minmax_hlo(op, cmp, x, y): """Min/max that compares complex values lexicographically as pairs.""" tensor_type = ir.RankedTensorType(x.type) if ir.ComplexType.isinstance(tensor_type.element_type): - rx = mhlo.RealOp(x).result - ry = mhlo.RealOp(y).result - real_eq = compare_mhlo(rx, ry, "EQ", "FLOAT") - real_cmp = compare_mhlo(rx, ry, cmp, "FLOAT") - imag_cmp = compare_mhlo( - mhlo.ImagOp(x).result, - mhlo.ImagOp(y).result, cmp, "FLOAT") - which = mhlo.SelectOp(real_eq, imag_cmp, real_cmp).result - return mhlo.SelectOp(which, x, y) + rx = hlo.RealOp(x).result + ry = hlo.RealOp(y).result + real_eq = compare_hlo(rx, ry, "EQ", "FLOAT") + real_cmp = compare_hlo(rx, ry, cmp, "FLOAT") + imag_cmp = compare_hlo( + hlo.ImagOp(x).result, + hlo.ImagOp(y).result, cmp, "FLOAT") + which = hlo.SelectOp(real_eq, imag_cmp, real_cmp).result + return hlo.SelectOp(which, x, y) else: return op(x, y) -min_mhlo = partial(_minmax_mhlo, mhlo.MinOp, "LT") -max_mhlo = partial(_minmax_mhlo, mhlo.MaxOp, "GT") +min_hlo = partial(_minmax_hlo, hlo.MinOp, "LT") 
+max_hlo = partial(_minmax_hlo, hlo.MaxOp, "GT") -def convert_mhlo(ctx: LoweringRuleContext, x, aval_in, aval_out): - """Variant of convert that has XLA HLO semantics. +def convert_hlo(ctx: LoweringRuleContext, x, aval_in, aval_out): + """Variant of convert that has HLO semantics. In particular, treat casts to boolean as x != 0, rather than truncating integer values (b/209440332).""" @@ -1428,9 +1426,9 @@ def convert_mhlo(ctx: LoweringRuleContext, x, aval_in, aval_out): compare_type = "SIGNED" else: compare_type = "UNSIGNED" - return compare_mhlo(x, full_like_aval(ctx, 0, aval_in), "NE", - compare_type).result - return mhlo.ConvertOp(aval_to_ir_type(aval_out), x).result + return compare_hlo(x, full_like_aval(ctx, 0, aval_in), "NE", + compare_type).result + return hlo.ConvertOp(aval_to_ir_type(aval_out), x).result def _wrap_with_spmd_op(name: str, result_type: ir.Type, @@ -1444,14 +1442,14 @@ def _wrap_with_spmd_op(name: str, [str(i) for i in sorted(unspecified_dims)]) + "]" else: backend_config = "" - op = mhlo.CustomCallOp([result_type], [x], - call_target_name=ir.StringAttr.get(name), - has_side_effect=ir.BoolAttr.get(False), - backend_config=ir.StringAttr.get(backend_config), - api_version=i32_attr(1), - called_computations=ir.ArrayAttr.get([]), - operand_layouts=None, - result_layouts=None) + op = hlo.CustomCallOp([result_type], [x], + call_target_name=ir.StringAttr.get(name), + has_side_effect=ir.BoolAttr.get(False), + backend_config=ir.StringAttr.get(backend_config), + api_version=i32_attr(1), + called_computations=ir.ArrayAttr.get([]), + operand_layouts=None, + result_layouts=None) op.attributes["mhlo.sharding"] = ir.StringAttr.get( sharding_proto.SerializeToString()) return op.result @@ -1489,7 +1487,7 @@ def cached_lowering(ctx, *args, **params): except TypeError: # If the parameters aren't hashable, give up on caching. # TODO(phawkins): switch to requiring hashability, when XLA fallback - # computations have been ported to MHLO. + # computations have been ported to MLIR. 
return f(ctx, *args, **params) if func is None: func = _emit_lowering_rule_as_fun(partial(f, **params), ctx) @@ -1506,12 +1504,12 @@ def cached_lowering(ctx, *args, **params): -def xla_computation_to_mhlo_module(xla_computation: xc.XlaComputation +def xla_computation_to_mlir_module(xla_computation: xc.XlaComputation ) -> ir.Module: module_str = xc._xla.mlir.xla_computation_to_mlir_module(xla_computation) return ir.Module.parse(module_str) -def merge_mhlo_modules(dst_module: ir.Module, +def merge_mlir_modules(dst_module: ir.Module, sym_name: str, src_module: ir.Module) -> str: """Returns the name of src_module's main() function, after renaming.""" @@ -1563,8 +1561,8 @@ def fallback(ctx: LoweringRuleContext, *args, **params): xla_computation = xla.primitive_subcomputation( module_ctx.platform, axis_env, prim, ctx.avals_in, ctx.avals_out, **params) - xla_module = xla_computation_to_mhlo_module(xla_computation) - callee_name = merge_mhlo_modules( + xla_module = xla_computation_to_mlir_module(xla_computation) + callee_name = merge_mlir_modules( module_ctx.module, f"xla_fallback_{prim.name}", xla_module) output_types = map(aval_to_ir_types, ctx.avals_out) flat_output_types = util.flatten(output_types) @@ -1576,7 +1574,7 @@ def fallback(ctx: LoweringRuleContext, *args, **params): flatten_lowering_ir_args(args)).result if not prim.multiple_results: return [call] - flat_results = [mhlo.GetTupleElementOp(call, i32_attr(i)).result + flat_results = [hlo.GetTupleElementOp(call, i32_attr(i)).result for i in range(len(flat_output_types))] return util.unflatten(flat_results, map(len, output_types)) @@ -1611,16 +1609,16 @@ def _dtype_to_xla_type_string(dtype: np.dtype) -> str: raise NotImplementedError(dtype) return _dtype_to_xla_type_string_map[dtype] -def send_to_host(channel: int, token: mhlo.TokenType, operand: Any, +def send_to_host(channel: int, token: hlo.TokenType, operand: Any, aval: core.ShapedArray, name: str, *, sharding: Optional[xc.OpSharding] = None) -> ir.Value: - channel_handle = mhlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE) + channel_handle = hlo.ChannelHandle.get(channel, SEND_TO_HOST_TYPE) if mlir_api_version < 40: - send_op = mhlo.SendOp(mhlo.TokenType.get(), [operand], token, channel_handle, - is_host_transfer=ir.BoolAttr.get(True)) + send_op = hlo.SendOp(hlo.TokenType.get(), [operand], token, channel_handle, + is_host_transfer=ir.BoolAttr.get(True)) else: - send_op = mhlo.SendOp([operand], token, channel_handle, - is_host_transfer=ir.BoolAttr.get(True)) + send_op = hlo.SendOp([operand], token, channel_handle, + is_host_transfer=ir.BoolAttr.get(True)) dtype_str = _dtype_to_xla_type_string(aval.dtype) if dtype_str in {"f64", "s64", "u64", "c64", "c128"}: raise NotImplementedError("64-bit types not supported.") @@ -1634,13 +1632,13 @@ def send_to_host(channel: int, token: mhlo.TokenType, operand: Any, return send_op.result -def receive_from_host(channel: int, token: mhlo.TokenType, +def receive_from_host(channel: int, token: hlo.TokenType, out_aval: core.ShapedArray, name: str, *, sharding: Optional[xc.OpSharding] = None) -> ir.Value: - channel_handle = mhlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE) - recv_op = mhlo.RecvOp([aval_to_ir_type(out_aval), - mhlo.TokenType.get()], token, channel_handle, - is_host_transfer=ir.BoolAttr.get(True)) + channel_handle = hlo.ChannelHandle.get(channel, RECV_FROM_HOST_TYPE) + recv_op = hlo.RecvOp([aval_to_ir_type(out_aval), + hlo.TokenType.get()], token, channel_handle, + is_host_transfer=ir.BoolAttr.get(True)) dtype_str = 
_dtype_to_xla_type_string(out_aval.dtype) if dtype_str in {"f64", "s64", "u64", "c64", "c128"}: raise NotImplementedError("64-bit types not supported.") @@ -1670,9 +1668,9 @@ def _emit_tpu_python_callback( sharding: Optional[xc.OpSharding] = None ) -> Tuple[List[ir.Value], Any, Any]: if mlir_api_version < 40: - token = token or mhlo.CreateTokenOp(mhlo.TokenType.get()).result + token = token or hlo.CreateTokenOp(hlo.TokenType.get()).result else: - token = token or mhlo.CreateTokenOp().result + token = token or hlo.CreateTokenOp().result _wrapped_callback = callback send_channels = [] @@ -1680,7 +1678,7 @@ def _emit_tpu_python_callback( # If there are no operands to the callback, we need to insert a dummy send # op or the callback will never be triggered! # TODO(sharadmv,chky): Enable this fix in the runtime as opposed to in - # MHLO builder. + # MLIR builder. callback_without_args = _wrapped_callback def _wrapped_callback(*args): # pylint: disable=function-redefined del args @@ -1761,7 +1759,7 @@ def emit_python_callback( operand_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None, result_layouts: Optional[Sequence[Optional[Sequence[int]]]] = None, ) -> Tuple[List[ir.Value], Any, Any]: - """Emits MHLO that calls back to a provided Python function.""" + """Emits MLIR that calls back to a provided Python function.""" platform = ctx.module_context.platform if platform not in {"cpu", "cuda", "rocm", "tpu"}: raise ValueError( @@ -1842,7 +1840,7 @@ def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function result_type = ir.TupleType.get_tuple(result_types) call_target_name = ("xla_python_gpu_callback" if platform in {"cuda", "rocm"} else "xla_python_cpu_callback") - result = mhlo.CustomCallOp( + result = hlo.CustomCallOp( [result_type], callback_operands, call_target_name=ir.StringAttr.get(call_target_name), @@ -1859,7 +1857,7 @@ def _wrapped_callback(token, *args): # type: ignore # pylint: disable=function if sharding is not None: set_sharding(result, sharding) results = [ - mhlo.GetTupleElementOp(result, i32_attr(i)).result + hlo.GetTupleElementOp(result, i32_attr(i)).result for i in range(len(result_types)) ] if token: diff --git a/jax/interpreters/pxla.py b/jax/interpreters/pxla.py index 4fbab1e37d5f..32de2bce87a9 100644 --- a/jax/interpreters/pxla.py +++ b/jax/interpreters/pxla.py @@ -77,7 +77,7 @@ from jax._src.lib import xla_client as xc from jax._src.lib import pmap_lib from jax._src.lib.mlir import ir -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src.util import (unzip3, prod, safe_map, safe_zip, partition_list, new_name_stack, wrap_name, assert_unreachable, tuple_insert, tuple_delete, distributed_debug_log, @@ -2211,14 +2211,14 @@ def _pmap_dce_rule(used_outputs, eqn): ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p) -def _unravel_index_mhlo(axis_env): +def _unravel_index_hlo(axis_env): div = mlir.ir_constant( np.array(axis_env.nreps // util.prod(axis_env.sizes), np.uint32)) mod = mlir.ir_constant(np.array(axis_env.sizes[-1], np.uint32)) - return mhlo.RemOp( - mhlo.DivOp(mhlo.ReplicaIdOp().result, div).result, mod).result + return hlo.RemOp( + hlo.DivOp(hlo.ReplicaIdOp().result, div).result, mod).result -def _mhlo_shard(aval, axis_env, xs, in_axis): +def _hlo_shard(aval, axis_env, xs, in_axis): if aval is core.abstract_token: return xs elif isinstance(aval, core.ShapedArray): @@ -2226,20 +2226,20 @@ def _mhlo_shard(aval, axis_env, xs, in_axis): dims = list(aval.shape) zero = 
mlir.ir_constant(np.zeros((), dtype=np.uint32)) idxs = [zero] * len(dims) - idxs.insert(in_axis, _unravel_index_mhlo(axis_env)) + idxs.insert(in_axis, _unravel_index_hlo(axis_env)) dims_unsqueezed = dims.copy() dims_unsqueezed.insert(in_axis, 1) - dynamic_slice_result = mhlo.DynamicSliceOp( + dynamic_slice_result = hlo.DynamicSliceOp( x, idxs, mlir.dense_int_elements(dims_unsqueezed)).result return [ - mhlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result + hlo.ReshapeOp(mlir.aval_to_ir_type(aval), dynamic_slice_result).result ] else: raise TypeError(aval) # TODO(b/110096942): more efficient gather -def _mhlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform): +def _hlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, platform): if aval is core.abstract_token: return xs elif isinstance(aval, core.ShapedArray): @@ -2249,23 +2249,23 @@ def _mhlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, p and platform in ('cpu', 'gpu')) if convert_bool: aval = aval.update(dtype=np.dtype(np.float32)) - x = mhlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result + x = hlo.ConvertOp(mlir.aval_to_ir_type(aval), x).result dims = list(aval.shape) padded_aval = aval.update(shape=[axis_env.sizes[-1]] + dims) padded = mlir.full_like_aval(ctx, 0, padded_aval) zero = mlir.ir_constant(np.zeros((), dtype=np.uint32)) - idxs = [_unravel_index_mhlo(axis_env)] + [zero] * len(dims) - broadcast_result = mhlo.BroadcastOp( + idxs = [_unravel_index_hlo(axis_env)] + [zero] * len(dims) + broadcast_result = hlo.BroadcastOp( x, mlir.dense_int_elements([1])).result if xc.mlir_api_version < 40: - padded = mhlo.DynamicUpdateSliceOp(padded.type, padded, broadcast_result, - idxs).result + padded = hlo.DynamicUpdateSliceOp(padded.type, padded, broadcast_result, + idxs).result else: - padded = mhlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result + padded = hlo.DynamicUpdateSliceOp(padded, broadcast_result, idxs).result replica_groups = mlir.dense_int_elements( xla.axis_groups(axis_env, axis_env.names[-1])) - out = mhlo.CrossReplicaSumOp(padded, replica_groups).result + out = hlo.CrossReplicaSumOp(padded, replica_groups).result if out_axis != 0: # TODO(apaszke,mattjj): Change the indices to DynamicUpdateSlice instead perm = list(range(1, len(dims))) @@ -2273,16 +2273,16 @@ def _mhlo_unshard(ctx: mlir.LoweringRuleContext, aval, axis_env, out_axis, xs, p transposed_dims = list(dims) transposed_dims.insert(out_axis, axis_env.sizes[-1]) aval = aval.update(shape=transposed_dims) - out = mhlo.TransposeOp(out, mlir.dense_int_elements(perm)).result + out = hlo.TransposeOp(out, mlir.dense_int_elements(perm)).result # TODO(mattjj): remove this logic when AllReduce PRED supported on CPU / GPU if convert_bool: float_zero = mlir.full_like_aval(ctx, 0, padded_aval) - out = mhlo.CompareOp( + out = hlo.CompareOp( out, float_zero, - mhlo.ComparisonDirectionAttr.get("NE"), - compare_type=mhlo.ComparisonTypeAttr.get("FLOAT")).result + hlo.ComparisonDirectionAttr.get("NE"), + compare_type=hlo.ComparisonTypeAttr.get("FLOAT")).result return out else: raise TypeError(aval) @@ -2305,7 +2305,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, # Shard the in_nodes that are mapped in_avals = [v.aval for v in call_jaxpr.invars] in_nodes_sharded = ( - _mhlo_shard(aval, new_env, mlir.wrap_singleton_ir_values(in_node), in_axis) + _hlo_shard(aval, new_env, mlir.wrap_singleton_ir_values(in_node), in_axis) if in_axis is not None else mlir.wrap_singleton_ir_values(in_node) for 
aval, in_node, in_axis in zip(in_avals, in_nodes, in_axes)) @@ -2318,7 +2318,7 @@ def _pmap_lowering(ctx, *in_nodes, axis_name, *in_nodes_sharded, dim_var_values=ctx.dim_var_values) out_avals = [v.aval for v in call_jaxpr.outvars] - outs = [_mhlo_unshard(ctx, aval, new_env, out_axis, shard, + outs = [_hlo_unshard(ctx, aval, new_env, out_axis, shard, platform=ctx.module_context.platform) for aval, out_axis, shard in zip(out_avals, out_axes, sharded_outs)] return outs diff --git a/jax/interpreters/xla.py b/jax/interpreters/xla.py index d619ca9285ea..a9e3a0103e3b 100644 --- a/jax/interpreters/xla.py +++ b/jax/interpreters/xla.py @@ -586,6 +586,6 @@ def lower_fun(fun: Callable, *, multiple_results: bool, backend=None, def f(*args, **kw): raise RuntimeError("XLA translation rules are deprecated and " "jax.interpreters.xla.lower_fun is no longer supported. " - "Add an MLIR (MHLO) lowering via jax.interpreters.mlir " + "Add an MLIR lowering via jax.interpreters.mlir " "instead.") return f diff --git a/jaxlib/BUILD b/jaxlib/BUILD index 3bbc6f8ce883..764583e35d51 100644 --- a/jaxlib/BUILD +++ b/jaxlib/BUILD @@ -34,9 +34,9 @@ py_library( "gpu_rnn.py", "gpu_solver.py", "gpu_sparse.py", + "hlo_helpers.py", "init.py", "lapack.py", - "mhlo_helpers.py", ":version", ":xla_client", ], diff --git a/jaxlib/ducc_fft.py b/jaxlib/ducc_fft.py index c4e5477bf545..fd1c63d9803d 100644 --- a/jaxlib/ducc_fft.py +++ b/jaxlib/ducc_fft.py @@ -15,10 +15,10 @@ from typing import List import jaxlib.mlir.ir as ir -import jaxlib.mlir.dialects.mhlo as mhlo +import jaxlib.mlir.dialects.mhlo as hlo -from .mhlo_helpers import custom_call +from .hlo_helpers import custom_call from .cpu import _ducc_fft import numpy as np @@ -107,7 +107,11 @@ def _ducc_fft_descriptor(shape: List[int], dtype, fft_type: FftType, return descriptor, out_dtype, out_shape +# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41. 
def ducc_fft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]): + return ducc_fft_hlo(a, dtype, fft_type=fft_type, fft_lengths=fft_lengths) + +def ducc_fft_hlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]): """DUCC FFT kernel for CPU.""" a_type = ir.RankedTensorType(a.type) n = len(a_type.shape) @@ -128,15 +132,15 @@ def ducc_fft_mhlo(a, dtype, *, fft_type: FftType, fft_lengths: List[int]): raise ValueError(f"Unknown output type {out_dtype}") if 0 in a_type.shape or 0 in out_shape: - zero = mhlo.ConstantOp( + zero = hlo.ConstantOp( ir.DenseElementsAttr.get( np.array(0, dtype=out_dtype), type=out_type)) - return mhlo.BroadcastOp( + return hlo.BroadcastOp( zero, ir.DenseElementsAttr.get(np.asarray(out_shape, np.int64))).result u8_type = ir.IntegerType.get_unsigned(8) - descriptor = mhlo.ConstantOp( + descriptor = hlo.ConstantOp( ir.DenseElementsAttr.get( np.frombuffer(descriptor_bytes, dtype=np.uint8), type=u8_type)) layout = tuple(range(n - 1, -1, -1)) diff --git a/jaxlib/gpu_linalg.py b/jaxlib/gpu_linalg.py index 1a0755c008e7..575a808524f0 100644 --- a/jaxlib/gpu_linalg.py +++ b/jaxlib/gpu_linalg.py @@ -18,7 +18,7 @@ import jaxlib.mlir.ir as ir -from .mhlo_helpers import custom_call +from .hlo_helpers import custom_call from jaxlib import xla_client @@ -39,7 +39,7 @@ _prod = lambda xs: functools.reduce(operator.mul, xs, 1) -def _lu_pivots_to_permutation_mhlo(platform, gpu_linalg, pivots, *, permutation_size): +def _lu_pivots_to_permutation_hlo(platform, gpu_linalg, pivots, *, permutation_size): """Kernel for the transformation of pivots to permutations on GPU.""" typ = ir.RankedTensorType(pivots.type) dims = typ.shape @@ -65,7 +65,7 @@ def _lu_pivots_to_permutation_mhlo(platform, gpu_linalg, pivots, *, permutation_ operand_layouts=[pivots_layout], result_layouts=[permutations_layout]) -cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_mhlo, "cu", +cuda_lu_pivots_to_permutation = partial(_lu_pivots_to_permutation_hlo, "cu", _cuda_linalg) hip_lu_pivots_to_permutation = partial( - _lu_pivots_to_permutation_mhlo, "hip", _hip_linalg) + _lu_pivots_to_permutation_hlo, "hip", _hip_linalg) diff --git a/jaxlib/gpu_prng.py b/jaxlib/gpu_prng.py index 9f9135315ae3..0696a6a7a771 100644 --- a/jaxlib/gpu_prng.py +++ b/jaxlib/gpu_prng.py @@ -22,7 +22,7 @@ from jaxlib import xla_client -from .mhlo_helpers import custom_call +from .hlo_helpers import custom_call try: from .cuda import _prng as _cuda_prng diff --git a/jaxlib/gpu_rnn.py b/jaxlib/gpu_rnn.py index dba1a686e9b2..40654dd1e44a 100644 --- a/jaxlib/gpu_rnn.py +++ b/jaxlib/gpu_rnn.py @@ -13,7 +13,7 @@ # limitations under the License. 
import jaxlib.mlir.ir as ir -import jaxlib.mlir.dialects.mhlo as mhlo +import jaxlib.mlir.dialects.mhlo as hlo import numpy as np @@ -61,7 +61,7 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *, i32_type = ir.IntegerType.get_signless(32) - out = mhlo.CustomCallOp( + out = hlo.CustomCallOp( [ ir.TupleType.get_tuple([ output_type, h_0.type, c_0.type, workspace_type, @@ -76,13 +76,13 @@ def cudnn_rnn_lowering(ctx, input, h_0, c_0, weights, seq_lengths, *, called_computations=ir.ArrayAttr.get([]), ) return [ - mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result + hlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result for i in range(5) ] -def _mhlo_zeros_f32(shape): - return mhlo.ConstantOp( +def _hlo_zeros_f32(shape): + return hlo.ConstantOp( ir.DenseElementsAttr.get( np.zeros(shape, dtype=np.float32), type=ir.F32Type.get())).result @@ -102,8 +102,8 @@ def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y, workspace, reserve_space_shape[0]) i32_type = ir.IntegerType.get_signless(32) - zeroed_dw = _mhlo_zeros_f32(ctx.avals_out[3].shape) - out = mhlo.CustomCallOp( + zeroed_dw = _hlo_zeros_f32(ctx.avals_out[3].shape) + out = hlo.CustomCallOp( [ir.TupleType.get_tuple([x.type, h0.type, c0.type, w.type])], [ dy, dhn, dcn, x, h0, c0, w, y, workspace, reserve_space, zeroed_dw, seq_lengths @@ -114,12 +114,12 @@ def cudnn_rnn_bwd_lowering(ctx, dy, dhn, dcn, x, h0, c0, w, y, workspace, api_version=ir.IntegerAttr.get(i32_type, 2), called_computations=ir.ArrayAttr.get([]), output_operand_aliases=ir.ArrayAttr.get([ - mhlo.OutputOperandAlias.get( + hlo.OutputOperandAlias.get( output_tuple_indices=[3], operand_index=10, operand_tuple_indices=[]) ])) return [ - mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result + hlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result for i in range(4) ] diff --git a/jaxlib/gpu_solver.py b/jaxlib/gpu_solver.py index 464e8d9a09b1..7f6ca0414e80 100644 --- a/jaxlib/gpu_solver.py +++ b/jaxlib/gpu_solver.py @@ -18,13 +18,13 @@ import operator import jaxlib.mlir.ir as ir -import jaxlib.mlir.dialects.mhlo as mhlo +import jaxlib.mlir.dialects.mhlo as hlo import numpy as np from jaxlib import xla_client -from .mhlo_helpers import custom_call +from .hlo_helpers import custom_call try: from .cuda import _blas as _cublas @@ -63,7 +63,7 @@ def _real_type(dtype): _prod = lambda xs: functools.reduce(operator.mul, xs, 1) -def _getrf_mhlo(platform, gpu_blas, gpu_solver, dtype, a): +def _getrf_hlo(platform, gpu_blas, gpu_solver, dtype, a): """LU decomposition.""" a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -106,11 +106,11 @@ def _getrf_mhlo(platform, gpu_blas, gpu_solver, dtype, a): operand_output_aliases={0: 0}) return out[:3] -cuda_getrf = partial(_getrf_mhlo, "cu", _cublas, _cusolver) -rocm_getrf = partial(_getrf_mhlo, "hip", _hipblas, _hipsolver) +cuda_getrf = partial(_getrf_hlo, "cu", _cublas, _cusolver) +rocm_getrf = partial(_getrf_hlo, "hip", _hipblas, _hipsolver) -def _geqrf_mhlo(platform, gpu_solver, dtype, a): +def _geqrf_hlo(platform, gpu_solver, dtype, a): """QR decomposition.""" a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -145,10 +145,10 @@ def _geqrf_mhlo(platform, gpu_solver, dtype, a): operand_output_aliases={0: 0}) return out[:3] -cuda_geqrf = partial(_geqrf_mhlo, "cu", _cusolver) -rocm_geqrf = partial(_geqrf_mhlo, "hip", _hipsolver) +cuda_geqrf = partial(_geqrf_hlo, "cu", _cusolver) +rocm_geqrf = partial(_geqrf_hlo, "hip", _hipsolver) -def 
_geqrf_batched_mhlo(platform, gpu_blas, dtype, a): +def _geqrf_batched_hlo(platform, gpu_blas, dtype, a): """Batched QR decomposition.""" a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -183,12 +183,12 @@ def _geqrf_batched_mhlo(platform, gpu_blas, dtype, a): ) return out[:2] -cuda_geqrf_batched = partial(_geqrf_batched_mhlo, "cu", _cublas) -rocm_geqrf_batched = partial(_geqrf_batched_mhlo, "hip", _hipblas) +cuda_geqrf_batched = partial(_geqrf_batched_hlo, "cu", _cublas) +rocm_geqrf_batched = partial(_geqrf_batched_hlo, "hip", _hipblas) -def _csrlsvqr_mhlo(platform, gpu_solver, dtype, data, - indices, indptr, b, tol, reorder): +def _csrlsvqr_hlo(platform, gpu_solver, dtype, data, + indices, indptr, b, tol, reorder): """Sparse solver via QR decomposition. CUDA only.""" b_type = ir.RankedTensorType(b.type) data_type = ir.RankedTensorType(data.type) @@ -209,10 +209,10 @@ def _csrlsvqr_mhlo(platform, gpu_solver, dtype, data, ) return [out] -cuda_csrlsvqr = partial(_csrlsvqr_mhlo, "cu", _cusolver) +cuda_csrlsvqr = partial(_csrlsvqr_hlo, "cu", _cusolver) -def _orgqr_mhlo(platform, gpu_solver, dtype, a, tau): +def _orgqr_hlo(platform, gpu_solver, dtype, a, tau): """Product of elementary Householder reflections.""" a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -252,12 +252,12 @@ def _orgqr_mhlo(platform, gpu_solver, dtype, a, tau): operand_output_aliases={0: 0}) return out[:2] -cuda_orgqr = partial(_orgqr_mhlo, "cu", _cusolver) -rocm_orgqr = partial(_orgqr_mhlo, "hip", _hipsolver) +cuda_orgqr = partial(_orgqr_hlo, "cu", _cusolver) +rocm_orgqr = partial(_orgqr_hlo, "hip", _hipsolver) -def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a, - lower=False): +def _syevd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, + lower=False): """Symmetric (Hermitian) eigendecomposition.""" a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -304,12 +304,12 @@ def _syevd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a, operand_output_aliases={0: 0}) return out[:3] -cuda_syevd = partial(_syevd_mhlo, "cu", _cusolver, True) -rocm_syevd = partial(_syevd_mhlo, "hip", _hipsolver, True) +cuda_syevd = partial(_syevd_hlo, "cu", _cusolver, True) +rocm_syevd = partial(_syevd_hlo, "hip", _hipsolver, True) -def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a, - full_matrices=True, compute_uv=True): +def _gesvd_hlo(platform, gpu_solver, have_jacobi_solver, dtype, a, + full_matrices=True, compute_uv=True): """Singular value decomposition.""" a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -358,18 +358,18 @@ def _gesvd_mhlo(platform, gpu_solver, have_jacobi_solver, dtype, a, [0], ], operand_output_aliases={0: 0}) - vt = mhlo.TransposeOp( + vt = hlo.TransposeOp( v, ir.DenseIntElementsAttr.get(np.array(tuple(range(num_bd)) + (num_bd + 1, num_bd)))).result if np.issubdtype(dtype, np.complexfloating): - vt = mhlo.ComplexOp(mhlo.RealOp(vt), mhlo.NegOp(mhlo.ImagOp(vt))).result + vt = hlo.ComplexOp(hlo.RealOp(vt), hlo.NegOp(hlo.ImagOp(vt))).result if not full_matrices and not econ: - u = mhlo.SliceOp( + u = hlo.SliceOp( u, ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)), ir.DenseIntElementsAttr.get(np.array(batch_dims + (m, min(m, n)))), ir.DenseIntElementsAttr.get(np.ones([len(dims)], np.int64))).result - vt = mhlo.SliceOp( + vt = hlo.SliceOp( vt, ir.DenseIntElementsAttr.get(np.zeros([len(dims)], np.int64)), ir.DenseIntElementsAttr.get(np.array(batch_dims + (min(m, n), n))), @@ -430,11 +430,11 @@ def _gesvd_mhlo(platform, gpu_solver, 
have_jacobi_solver, dtype, a, operand_output_aliases={0: 0}) return s, u, vt, info -cuda_gesvd = partial(_gesvd_mhlo, "cu", _cusolver, True) -rocm_gesvd = partial(_gesvd_mhlo, "hip", _hipsolver, False) +cuda_gesvd = partial(_gesvd_hlo, "cu", _cusolver, True) +rocm_gesvd = partial(_gesvd_hlo, "hip", _hipsolver, False) -def _sytrd_mhlo(platform, gpu_solver, dtype, a, *, lower): +def _sytrd_hlo(platform, gpu_solver, dtype, a, *, lower): """sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form.""" a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -490,20 +490,20 @@ def _sytrd_mhlo(platform, gpu_solver, dtype, a, *, lower): if not lower and platform == "cu" and m > 1: start = (0,) * len(batch_dims) + (0,) end = batch_dims + (1,) - s = mhlo.SliceOp(e, intattr(start), intattr(end), intattr([1] * len(start))) + s = hlo.SliceOp(e, intattr(start), intattr(end), intattr([1] * len(start))) s_type = ir.RankedTensorType.get(batch_dims + (1, 1), diag_type) - s = mhlo.BroadcastInDimOp(s_type, s, intattr(range(len(dims) - 1))) + s = hlo.BroadcastInDimOp(s_type, s, intattr(range(len(dims) - 1))) # The diagonals are always real; convert to complex if needed. - s = mhlo.ConvertOp( + s = hlo.ConvertOp( ir.RankedTensorType.get(s_type.shape, a_type.element_type), s) - offsets = tuple(mhlo.ConstantOp(intattr(i)) + offsets = tuple(hlo.ConstantOp(intattr(i)) for i in ((0,) * len(batch_dims) + (0, 1))) if xla_client.mlir_api_version < 40: - a = mhlo.DynamicUpdateSliceOp(a.type, a, s, offsets).result + a = hlo.DynamicUpdateSliceOp(a.type, a, s, offsets).result else: - a = mhlo.DynamicUpdateSliceOp(a, s, offsets).result + a = hlo.DynamicUpdateSliceOp(a, s, offsets).result return a, d, e, taus, info -cuda_sytrd = partial(_sytrd_mhlo, "cu", _cusolver) -rocm_sytrd = partial(_sytrd_mhlo, "hip", _hipsolver) +cuda_sytrd = partial(_sytrd_hlo, "cu", _cusolver) +rocm_sytrd = partial(_sytrd_hlo, "hip", _hipsolver) diff --git a/jaxlib/gpu_sparse.py b/jaxlib/gpu_sparse.py index dd374bb8ae32..f75a514451cf 100644 --- a/jaxlib/gpu_sparse.py +++ b/jaxlib/gpu_sparse.py @@ -23,7 +23,7 @@ from jaxlib import xla_client -from .mhlo_helpers import custom_call +from .hlo_helpers import custom_call try: from .cuda import _sparse as _cusparse @@ -46,7 +46,7 @@ rocm_is_supported : bool = _hipsparse and _hipsparse.sparse_supported -def _validate_csr_mhlo(data, indices, indptr, shape): +def _validate_csr_hlo(data, indices, indptr, shape): data_type = ir.RankedTensorType(data.type) indices_type = ir.RankedTensorType(indices.type) indptr_type = ir.RankedTensorType(indptr.type) @@ -57,7 +57,7 @@ def _validate_csr_mhlo(data, indices, indptr, shape): assert indptr_type.shape == [shape[0] + 1] return data_type.element_type, indices_type.element_type, nnz -def _validate_coo_mhlo(data, row, col): +def _validate_coo_hlo(data, row, col): data_type = ir.RankedTensorType(data.type) row_type = ir.RankedTensorType(row.type) col_type = ir.RankedTensorType(col.type) @@ -69,10 +69,10 @@ def _validate_coo_mhlo(data, row, col): return data_type.element_type, row_type.element_type, nnz -def _csr_todense_mhlo(platform, gpu_sparse, data, indices, indptr, *, shape, - data_dtype, index_dtype): +def _csr_todense_hlo(platform, gpu_sparse, data, indices, indptr, *, shape, + data_dtype, index_dtype): """CSR to dense matrix.""" - data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape) + data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) rows, cols = shape buffer_size, opaque = 
gpu_sparse.build_csr_todense_descriptor( @@ -91,12 +91,12 @@ def _csr_todense_mhlo(platform, gpu_sparse, data, indices, indptr, *, shape, result_layouts=[[1, 0], [0]]) return out[0] -cuda_csr_todense = partial(_csr_todense_mhlo, "cu", _cusparse) -rocm_csr_todense = partial(_csr_todense_mhlo, "hip", _hipsparse) +cuda_csr_todense = partial(_csr_todense_hlo, "cu", _cusparse) +rocm_csr_todense = partial(_csr_todense_hlo, "hip", _hipsparse) -def _csr_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, index_dtype, - data_dtype, index_type): +def _csr_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, index_dtype, + data_dtype, index_type): """CSR from dense matrix.""" mat_type = ir.RankedTensorType(mat.type) rows, cols = mat_type.shape @@ -119,15 +119,15 @@ def _csr_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, index_dtype, result_layouts=[[0]] * 4) return out[:3] -cuda_csr_fromdense = partial(_csr_fromdense_mhlo, "cu", _cusparse) -rocm_csr_fromdense = partial(_csr_fromdense_mhlo, "hip", _hipsparse) +cuda_csr_fromdense = partial(_csr_fromdense_hlo, "cu", _cusparse) +rocm_csr_fromdense = partial(_csr_fromdense_hlo, "hip", _hipsparse) -def _csr_matvec_mhlo(platform, gpu_sparse, data, indices, indptr, x, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - data_dtype, index_dtype, x_dtype): +def _csr_matvec_hlo(platform, gpu_sparse, data, indices, indptr, x, *, shape, + transpose=False, compute_dtype=None, compute_type=None, + data_dtype, index_dtype, x_dtype): """CSR matrix/vector multiply.""" - data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape) + data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) rows, cols = shape if compute_dtype is None: @@ -152,15 +152,15 @@ def _csr_matvec_mhlo(platform, gpu_sparse, data, indices, indptr, x, *, shape, result_layouts=[[0]] * 2) return out[0] -cuda_csr_matvec = partial(_csr_matvec_mhlo, "cu", _cusparse) -rocm_csr_matvec = partial(_csr_matvec_mhlo, "hip", _hipsparse) +cuda_csr_matvec = partial(_csr_matvec_hlo, "cu", _cusparse) +rocm_csr_matvec = partial(_csr_matvec_hlo, "hip", _hipsparse) -def _csr_matmat_mhlo(platform, gpu_sparse, data, indices, indptr, B, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - index_dtype, data_dtype, B_dtype): +def _csr_matmat_hlo(platform, gpu_sparse, data, indices, indptr, B, *, shape, + transpose=False, compute_dtype=None, compute_type=None, + index_dtype, data_dtype, B_dtype): """CSR from dense matrix.""" - data_type, index_type, nnz = _validate_csr_mhlo(data, indices, indptr, shape) + data_type, index_type, nnz = _validate_csr_hlo(data, indices, indptr, shape) rows, cols = shape B_shape = ir.RankedTensorType(B.type).shape _, Ccols = B_shape @@ -187,14 +187,14 @@ def _csr_matmat_mhlo(platform, gpu_sparse, data, indices, indptr, B, *, shape, result_layouts=[[1, 0], [0]]) return out[0] -cuda_csr_matmat = partial(_csr_matmat_mhlo, "cu", _cusparse) -rocm_csr_matmat = partial(_csr_matmat_mhlo, "hip", _hipsparse) +cuda_csr_matmat = partial(_csr_matmat_hlo, "cu", _cusparse) +rocm_csr_matmat = partial(_csr_matmat_hlo, "hip", _hipsparse) -def _coo_todense_mhlo(platform, gpu_sparse, data, row, col, *, shape, - data_dtype, index_dtype): +def _coo_todense_hlo(platform, gpu_sparse, data, row, col, *, shape, + data_dtype, index_dtype): """COO to dense matrix.""" - data_type, _, nnz = _validate_coo_mhlo(data, row, col) + data_type, _, nnz = _validate_coo_hlo(data, row, col) rows, cols = shape buffer_size, opaque = 
gpu_sparse.build_coo_todense_descriptor( @@ -213,12 +213,12 @@ def _coo_todense_mhlo(platform, gpu_sparse, data, row, col, *, shape, result_layouts=[[1, 0], [0]]) return out[0] -cuda_coo_todense = partial(_coo_todense_mhlo, "cu", _cusparse) -rocm_coo_todense = partial(_coo_todense_mhlo, "hip", _hipsparse) +cuda_coo_todense = partial(_coo_todense_hlo, "cu", _cusparse) +rocm_coo_todense = partial(_coo_todense_hlo, "hip", _hipsparse) -def _coo_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, data_dtype, - index_dtype, index_type): +def _coo_fromdense_hlo(platform, gpu_sparse, mat, *, nnz, data_dtype, + index_dtype, index_type): """COO from dense matrix.""" mat_type = ir.RankedTensorType(mat.type) rows, cols = mat_type.shape @@ -241,15 +241,15 @@ def _coo_fromdense_mhlo(platform, gpu_sparse, mat, *, nnz, data_dtype, result_layouts=[[0]] * 4) return out[:3] -cuda_coo_fromdense = partial(_coo_fromdense_mhlo, "cu", _cusparse) -rocm_coo_fromdense = partial(_coo_fromdense_mhlo, "hip", _hipsparse) +cuda_coo_fromdense = partial(_coo_fromdense_hlo, "cu", _cusparse) +rocm_coo_fromdense = partial(_coo_fromdense_hlo, "hip", _hipsparse) -def _coo_matvec_mhlo(platform, gpu_sparse, data, row, col, x, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - index_dtype, data_dtype, x_dtype): +def _coo_matvec_hlo(platform, gpu_sparse, data, row, col, x, *, shape, + transpose=False, compute_dtype=None, compute_type=None, + index_dtype, data_dtype, x_dtype): """COO matrix/vector multiply.""" - data_type, _, nnz = _validate_coo_mhlo(data, row, col) + data_type, _, nnz = _validate_coo_hlo(data, row, col) rows, cols = shape if compute_dtype is None: @@ -274,15 +274,15 @@ def _coo_matvec_mhlo(platform, gpu_sparse, data, row, col, x, *, shape, result_layouts=[[0]] * 2) return out[0] -cuda_coo_matvec = partial(_coo_matvec_mhlo, "cu", _cusparse) -rocm_coo_matvec = partial(_coo_matvec_mhlo, "hip", _hipsparse) +cuda_coo_matvec = partial(_coo_matvec_hlo, "cu", _cusparse) +rocm_coo_matvec = partial(_coo_matvec_hlo, "hip", _hipsparse) -def _coo_matmat_mhlo(platform, gpu_sparse, data, row, col, B, *, shape, - transpose=False, compute_dtype=None, compute_type=None, - x_dtype, data_dtype, index_dtype): +def _coo_matmat_hlo(platform, gpu_sparse, data, row, col, B, *, shape, + transpose=False, compute_dtype=None, compute_type=None, + x_dtype, data_dtype, index_dtype): """COO from dense matrix.""" - data_type, _, nnz = _validate_coo_mhlo(data, row, col) + data_type, _, nnz = _validate_coo_hlo(data, row, col) is_batched_matmat = False batch_count = 1 if len(shape) == 2: @@ -334,11 +334,11 @@ def _coo_matmat_mhlo(platform, gpu_sparse, data, row, col, B, *, shape, result_layouts=[out_layout, [0]]) return out[0] -cuda_coo_matmat = partial(_coo_matmat_mhlo, "cu", _cusparse) -rocm_coo_matmat = partial(_coo_matmat_mhlo, "hip", _hipsparse) +cuda_coo_matmat = partial(_coo_matmat_hlo, "cu", _cusparse) +rocm_coo_matmat = partial(_coo_matmat_hlo, "hip", _hipsparse) -def _gtsv2_mhlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t): +def _gtsv2_hlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t): """Calls `cusparsegtsv2(dl, d, du, B, m, n, ldb)`.""" f32 = (t == np.float32) if f32: @@ -360,5 +360,5 @@ def _gtsv2_mhlo(platform, gpu_sparse, dl, d, du, B, *, m, n, ldb, t): operand_output_aliases={3: 0}) return out[0] -cuda_gtsv2 = partial(_gtsv2_mhlo, "cu", _cusparse) -rocm_gtsv2 = partial(_gtsv2_mhlo, "hip", _hipsparse) +cuda_gtsv2 = partial(_gtsv2_hlo, "cu", _cusparse) +rocm_gtsv2 = partial(_gtsv2_hlo, "hip", _hipsparse) 
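The jaxlib hunks above and below keep thin *_mhlo wrappers (for example ducc_fft_mhlo and getrf_mhlo) that simply forward to the renamed *_hlo functions, so callers built against the old names keep working until the TODOs in this patch say the shims can be removed (mlir_api_version >= 41). A minimal sketch, not part of this diff, of how a downstream caller could stay compatible across jaxlib versions; it assumes only the ducc_fft_mhlo/ducc_fft_hlo names introduced in the jaxlib/ducc_fft.py hunk:

    import jaxlib.ducc_fft as ducc_fft

    # Prefer the post-rename entry point when this jaxlib provides it, and only
    # fall back to the pre-rename name, which newer jaxlibs keep as a forwarding
    # shim. The `or` keeps the fallback attribute from being touched on jaxlibs
    # that have already dropped the shim.
    ducc_fft_lowering = (getattr(ducc_fft, "ducc_fft_hlo", None)
                         or ducc_fft.ducc_fft_mhlo)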
diff --git a/jaxlib/mhlo_helpers.py b/jaxlib/hlo_helpers.py similarity index 90% rename from jaxlib/mhlo_helpers.py rename to jaxlib/hlo_helpers.py index 904d9bba3a6d..e029df9b8251 100644 --- a/jaxlib/mhlo_helpers.py +++ b/jaxlib/hlo_helpers.py @@ -12,11 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Helpers for building MHLO operators +# Helpers for building MLIR operators from typing import Dict, Optional, Sequence, Union import jaxlib.mlir.ir as ir -import jaxlib.mlir.dialects.mhlo as mhlo +import jaxlib.mlir.dialects.mhlo as hlo import numpy as np @@ -31,7 +31,7 @@ def custom_call( api_version: int = 2, operand_output_aliases: Dict[int, int] = {}, ) -> Union[ir.Value, Sequence[ir.Value]]: - """Less-verbose helper for building an MHLO custom call op. + """Less-verbose helper for building a CustomCallOp. Once https://github.com/llvm/llvm-project/issues/54932 is fixed, this helper may be able to go away. @@ -42,7 +42,7 @@ def custom_call( that must alias. """ i32_type = ir.IntegerType.get_signless(32) - out = mhlo.CustomCallOp( + out = hlo.CustomCallOp( (out_types if len(out_types) == 1 else [ir.TupleType.get_tuple(out_types)]), operands, @@ -63,7 +63,7 @@ def custom_call( type=ir.IndexType.get()) for l in result_layouts ]), output_operand_aliases=ir.ArrayAttr.get([ - mhlo.OutputOperandAlias.get( + hlo.OutputOperandAlias.get( output_tuple_indices=[] if len(out_types) == 1 else [output], operand_index=input, operand_tuple_indices=[]) @@ -73,6 +73,6 @@ def custom_call( return out.result else: return [ - mhlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result + hlo.GetTupleElementOp(out, ir.IntegerAttr.get(i32_type, i)).result for i in range(len(out_types)) ] diff --git a/jaxlib/lapack.py b/jaxlib/lapack.py index 03694600b1ef..f728cff0bbf7 100644 --- a/jaxlib/lapack.py +++ b/jaxlib/lapack.py @@ -16,12 +16,12 @@ # via CustomCallWithLayout. import jaxlib.mlir.ir as ir -import jaxlib.mlir.dialects.mhlo as mhlo +import jaxlib.mlir.dialects.mhlo as hlo import numpy as np from jaxlib import xla_client -from .mhlo_helpers import custom_call +from .hlo_helpers import custom_call from .cpu import _lapack for _name, _value in _lapack.registrations().items(): @@ -32,24 +32,30 @@ _initialize = _lapack.initialize -def _mhlo_u8(x): - return mhlo.ConstantOp( +def _hlo_u8(x): + return hlo.ConstantOp( ir.DenseElementsAttr.get( np.array(x, dtype=np.uint8), type=ir.IntegerType.get_unsigned(8))).result -def _mhlo_s32(x): - return mhlo.ConstantOp( +def _hlo_s32(x): + return hlo.ConstantOp( ir.DenseElementsAttr.get( np.array(x, dtype=np.int32), type=ir.IntegerType.get_signless(32))).result # TODO(phawkins): it would be nice to avoid duplicating code for each type. +# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41. 
+def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False, + conj_a=False, diag=False): + return trsm_hlo(dtype, alpha, a, b, left_side=left_side, lower=lower, + trans_a=trans_a, conj_a=conj_a, diag=diag) + # ?trsm(left_side, lower, trans_a, diag, m, n, alpha, a, b): # triangular solve -def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False, - conj_a=False, diag=False): +def trsm_hlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False, + conj_a=False, diag=False): _initialize() a_type = ir.RankedTensorType(a.type) b_type = ir.RankedTensorType(b.type) @@ -87,9 +93,9 @@ def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False, return custom_call( fn, [b.type], - [_mhlo_s32(int(left_side)), _mhlo_s32(int(lower)), - _mhlo_s32((2 if conj_a else 1) if trans_a else 0), _mhlo_s32(int(diag)), - _mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(num_b), + [_hlo_s32(int(left_side)), _hlo_s32(int(lower)), + _hlo_s32((2 if conj_a else 1) if trans_a else 0), _hlo_s32(int(diag)), + _hlo_s32(m), _hlo_s32(n), _hlo_s32(num_b), alpha, a, b], operand_layouts=[scalar_layout] * 8 + [layout] * 2, result_layouts=[layout], @@ -99,7 +105,11 @@ def trsm_mhlo(dtype, alpha, a, b, left_side=False, lower=False, trans_a=False, # # ?getrf: LU decomposition +# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41. def getrf_mhlo(dtype, a): + return getrf_hlo(dtype, a) + +def getrf_hlo(dtype, a): _initialize() dims = ir.RankedTensorType(a.type).shape assert len(dims) >= 2 @@ -131,7 +141,7 @@ def getrf_mhlo(dtype, a): ir.RankedTensorType.get(batch_dims + (min(m, n),), i32_type), ir.RankedTensorType.get(batch_dims, i32_type), ], - [_mhlo_s32(int(b)), _mhlo_s32(m), _mhlo_s32(n), a], + [_hlo_s32(int(b)), _hlo_s32(m), _hlo_s32(n), a], operand_layouts=[scalar_layout] * 3 + [layout], result_layouts=[ layout, @@ -144,7 +154,11 @@ def getrf_mhlo(dtype, a): # # ?geqrf: QR decomposition +# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41. def geqrf_mhlo(dtype, a): + return geqrf_hlo(dtype, a) + +def geqrf_hlo(dtype, a): _initialize() a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -182,7 +196,7 @@ def geqrf_mhlo(dtype, a): ir.RankedTensorType.get(batch_dims, i32_type), ir.RankedTensorType.get([lwork], a_type.element_type), ], - [_mhlo_s32(int(b)), _mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(lwork), a], + [_hlo_s32(int(b)), _hlo_s32(m), _hlo_s32(n), _hlo_s32(lwork), a], operand_layouts=[scalar_layout] * 4 + [layout], result_layouts=[ layout, @@ -197,7 +211,11 @@ def geqrf_mhlo(dtype, a): # # ?orgqr: product of elementary Householder reflectors: +# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41. def orgqr_mhlo(dtype, a, tau): + return orgqr_hlo(dtype, a, tau) + +def orgqr_hlo(dtype, a, tau): _initialize() a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -238,8 +256,8 @@ def orgqr_mhlo(dtype, a, tau): ir.RankedTensorType.get(batch_dims, i32_type), ir.RankedTensorType.get([lwork], a_type.element_type), ], - [_mhlo_s32(int(b)), _mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(k), - _mhlo_s32(lwork), a, tau], + [_hlo_s32(int(b)), _hlo_s32(m), _hlo_s32(n), _hlo_s32(k), + _hlo_s32(lwork), a, tau], operand_layouts=[scalar_layout] * 5 + [ layout, tuple(range(num_bd, -1, -1)), @@ -256,7 +274,11 @@ def orgqr_mhlo(dtype, a, tau): # ?potrf: Cholesky decomposition +# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41. 
def potrf_mhlo(dtype, a, lower=False): + return potrf_hlo(dtype, a, lower=lower) + +def potrf_hlo(dtype, a, lower=False): _initialize() a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -286,7 +308,7 @@ def potrf_mhlo(dtype, a, lower=False): fn, [a.type, ir.RankedTensorType.get(batch_dims, ir.IntegerType.get_signless(32))], - [_mhlo_s32(int(lower)), _mhlo_s32(b), _mhlo_s32(n), a], + [_hlo_s32(int(lower)), _hlo_s32(b), _hlo_s32(n), a], operand_layouts=[scalar_layout] * 3 + [layout], result_layouts=[layout, info_layout], operand_output_aliases={3: 0}, @@ -297,7 +319,11 @@ def potrf_mhlo(dtype, a, lower=False): # # ?gesdd: Singular value decomposition +# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41. def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True): + return gesdd_hlo(dtype, a, full_matrices=full_matrices, compute_uv=compute_uv) + +def gesdd_hlo(dtype, a, full_matrices=True, compute_uv=True): _initialize() a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -370,8 +396,8 @@ def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True): a_type.element_type), ir.RankedTensorType.get(batch_dims, i32_type), ] + workspace, - [_mhlo_s32(int(full_matrices)), _mhlo_s32(int(compute_uv)), _mhlo_s32(b), - _mhlo_s32(m), _mhlo_s32(n), _mhlo_s32(lwork), a], + [_hlo_s32(int(full_matrices)), _hlo_s32(int(compute_uv)), _hlo_s32(b), + _hlo_s32(m), _hlo_s32(n), _hlo_s32(lwork), a], operand_layouts=[scalar_layout] * 6 + [layout], result_layouts=[ layout, @@ -387,7 +413,11 @@ def gesdd_mhlo(dtype, a, full_matrices=True, compute_uv=True): # # syevd: Symmetric eigendecomposition +# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41. def syevd_mhlo(dtype, a, lower=False): + return syevd_hlo(dtype, a, lower=lower) + +def syevd_hlo(dtype, a, lower=False): _initialize() a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -452,7 +482,7 @@ def syevd_mhlo(dtype, a, lower=False): ir.RankedTensorType.get(batch_dims + (n,), eigvals_type), ir.RankedTensorType.get(batch_dims, i32_type), ] + workspace, - [_mhlo_s32(1 if lower else 0), _mhlo_s32(b), _mhlo_s32(n), a], + [_hlo_s32(1 if lower else 0), _hlo_s32(b), _hlo_s32(n), a], operand_layouts=[scalar_layout] * 3 + [layout], result_layouts=[ layout, @@ -466,7 +496,11 @@ def syevd_mhlo(dtype, a, lower=False): # # geev: Nonsymmetric eigendecomposition +# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41. def geev_mhlo(dtype, a, jobvl=True, jobvr=True): + return geev_hlo(dtype, a, jobvl=jobvl, jobvr=jobvr) + +def geev_hlo(dtype, a, jobvl=True, jobvr=True): _initialize() dims = ir.RankedTensorType(a.type).shape assert len(dims) >= 2 @@ -539,19 +573,23 @@ def geev_mhlo(dtype, a, jobvl=True, jobvr=True): ir.RankedTensorType.get(dims, eigvecs_type), ir.RankedTensorType.get(batch_dims, i32_type), ], - [_mhlo_s32(b), _mhlo_s32(n), _mhlo_u8(jobvl_c), _mhlo_u8(jobvr_c), a], + [_hlo_s32(b), _hlo_s32(n), _hlo_u8(jobvl_c), _hlo_u8(jobvr_c), a], operand_layouts=[scalar_layout] * 4 + [layout], result_layouts=(workspace_layouts + eigvals_layouts + [layout] * 2 + [info_layout]) ) if real: - return (mhlo.ComplexOp(out[3], out[4]).result, out[5], out[6], out[7]) + return (hlo.ComplexOp(out[3], out[4]).result, out[5], out[6], out[7]) else: return out[2:6] # # gees : Schur factorization +# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41. 
def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None): + return gees_hlo(dtype, a, jobvs=jobvs, sort=sort, select=select) + +def gees_hlo(dtype, a, jobvs=True, sort=False, select=None): _initialize() a_type = ir.RankedTensorType(a.type) etype = a_type.element_type @@ -609,10 +647,10 @@ def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None): ir.RankedTensorType.get(batch_dims, i32_type), ], [ - _mhlo_s32(b), - _mhlo_s32(n), - _mhlo_u8(np.uint8(jobvs)), - _mhlo_u8(np.uint8(sort)), + _hlo_s32(b), + _hlo_s32(n), + _hlo_u8(np.uint8(jobvs)), + _hlo_u8(np.uint8(sort)), # TODO: figure out how to put the callable select function here a ], @@ -630,8 +668,12 @@ def gees_mhlo(dtype, a, jobvs=True, sort=False, select=None): return (out[0], out[3], out[5]) -# gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form. +# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41. def gehrd_mhlo(dtype, a): + return gehrd_hlo(dtype, a) + +# gehrd: Reduction of a non-symmetric square matrix to upper Hessenberg form. +def gehrd_hlo(dtype, a): _initialize() a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -669,8 +711,8 @@ def gehrd_mhlo(dtype, a): ir.RankedTensorType.get(batch_dims, i32_type), ir.RankedTensorType.get([lwork], a_type.element_type), ], - [_mhlo_s32(n), _mhlo_s32(1), _mhlo_s32(n), _mhlo_s32(n), _mhlo_s32(b), - _mhlo_s32(lwork), a], + [_hlo_s32(n), _hlo_s32(1), _hlo_s32(n), _hlo_s32(n), _hlo_s32(b), + _hlo_s32(lwork), a], operand_layouts=[[]] * 6 + [layout], result_layouts=[ layout, @@ -683,8 +725,12 @@ def gehrd_mhlo(dtype, a): return out[:3] -# sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form. +# TODO(burmako): Remove this compatibility shim when mlir_api_version >= 41. def sytrd_mhlo(dtype, a, *, lower): + return sytrd_hlo(dtype, a, lower=lower) + +# sytrd: Reduction of a symmetric (Hermitian) matrix to tridiagonal form. +def sytrd_hlo(dtype, a, *, lower): _initialize() a_type = ir.RankedTensorType(a.type) dims = a_type.shape @@ -728,8 +774,8 @@ def sytrd_mhlo(dtype, a, *, lower): ir.RankedTensorType.get(batch_dims, i32_type), ir.RankedTensorType.get([lwork], a_type.element_type), ], - [_mhlo_s32(n), _mhlo_s32(1 if lower else 0), _mhlo_s32(max(1, n)), - _mhlo_s32(b), _mhlo_s32(lwork), a], + [_hlo_s32(n), _hlo_s32(1 if lower else 0), _hlo_s32(max(1, n)), + _hlo_s32(b), _hlo_s32(lwork), a], operand_layouts=[[]] * 5 + [layout], result_layouts=[ layout, diff --git a/tests/custom_object_test.py b/tests/custom_object_test.py index c3a131adbea0..1f324fae8920 100644 --- a/tests/custom_object_test.py +++ b/tests/custom_object_test.py @@ -155,10 +155,10 @@ def _sp_indices_abstract_eval(mat): # Note: cannot use lower_fun to define attribute access primitives # because it leads to infinite recursion. -def _sp_indices_mhlo_lowering(ctx, data_and_indices): +def _sp_indices_hlo_lowering(ctx, data_and_indices): return [data_and_indices[1]] -mlir.register_lowering(sp_indices_p, _sp_indices_mhlo_lowering) +mlir.register_lowering(sp_indices_p, _sp_indices_hlo_lowering) sp_data_p = core.Primitive('sp_data') @@ -173,10 +173,10 @@ def _sp_data_abstract_eval(mat): # Note: cannot use lower_fun to define attribute access primitives # because it leads to infinite recursion. 
-def _sp_data_mhlo_lowering(ctx, data_and_indices): +def _sp_data_hlo_lowering(ctx, data_and_indices): return [data_and_indices[0]] -mlir.register_lowering(sp_data_p, _sp_data_mhlo_lowering) +mlir.register_lowering(sp_data_p, _sp_data_hlo_lowering) def identity(x): return identity_p.bind(x) diff --git a/tests/dynamic_api_test.py b/tests/dynamic_api_test.py index 1a677a2d33b9..6d0a1b40700e 100644 --- a/tests/dynamic_api_test.py +++ b/tests/dynamic_api_test.py @@ -767,7 +767,7 @@ def test_slicing_basic(self): self.assertAllClose(ans, expected, check_dtypes=True) # TODO(mattjj,dougalm,phawkins): debug iree failure, "failed to legalize - # operation 'mhlo.while' that was explicitly marked illegal" + # operation 'while' that was explicitly marked illegal" @unittest.skip("revising slicing logic") def test_scan_basic(self): def cumsum(x): @@ -1275,8 +1275,8 @@ def f(x): return x.sum() f_lowered = f.lower(np.arange(3, dtype='int32')) - mhlo = f_lowered.compiler_ir('mhlo') - self.assertIn('tensor', str(mhlo)) + mlir_str = f_lowered.compiler_ir() + self.assertIn('tensor', str(mlir_str)) def test_lower_abstracted_axes_shapedtypestruct(self): @partial(jax.jit, abstracted_axes=('n',)) @@ -1284,8 +1284,8 @@ def f(x): return x.sum() f_lowered = f.lower(jax.ShapeDtypeStruct((3,), np.int32)) - mhlo = f_lowered.compiler_ir('mhlo') - self.assertIn('tensor', str(mhlo)) + mlir_str = f_lowered.compiler_ir() + self.assertIn('tensor', str(mlir_str)) def test_vmap_abstracted_axis(self): def foo(x, y): diff --git a/tests/filecheck/README.md b/tests/filecheck/README.md index caf01d505c43..55766bc9041f 100644 --- a/tests/filecheck/README.md +++ b/tests/filecheck/README.md @@ -1,6 +1,6 @@ This directory contains LLVM [FileCheck](https://llvm.org/docs/CommandGuide/FileCheck.html) tests that verify -that JAX primitives can be lowered to MHLO. +that JAX primitives can be lowered to MLIR. These tests are intended to be a quick and easy-to-understand way to catch regressions from changes due the MLIR Python bindings and from changes to the diff --git a/tests/filecheck/array.filecheck.py b/tests/filecheck/array.filecheck.py index ba2e3392ecbe..4305a450dd61 100644 --- a/tests/filecheck/array.filecheck.py +++ b/tests/filecheck/array.filecheck.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Tests for lowering of array origami ops into MHLO. +# Tests for lowering of array origami ops into MLIR. 
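# Why the CHECK patterns below drop the "m" prefix: FileCheck matches a
# CHECK pattern anywhere within a line, so "hlo.add" accepts the old
# "mhlo.add" spelling as well as any other dialect whose op names end the
# same way (presumably the motivation for this rename). A small
# illustration of the same idea, with example op strings only; it parallels
# the assertEqual-to-assertIn changes in the unit tests further down.
for printed_op in ("mhlo.add", "stablehlo.add"):
  assert "hlo.add" in printed_op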
# RUN: %PYTHON %s | FileCheck %s @@ -32,62 +32,62 @@ def main(_): # CHECK-LABEL: TEST: concatenate bool[2,7] bool[2,5] - # CHECK: mhlo.concatenate + # CHECK: hlo.concatenate # CHECK-SAME: tensor<2x12xi1> print_ir([np.empty([2, 7], np.bool_), np.empty([2, 5], np.bool_)])( partial(lax.concatenate, dimension=1)) # CHECK-LABEL: TEST: broadcast_in_dim bool[2,7] - # CHECK: mhlo.broadcast_in_dim + # CHECK: hlo.broadcast_in_dim # CHECK-SAME: tensor<3x2x5x7x2xi1> print_ir(np.empty([2, 7], np.bool_))( partial(lax.broadcast_in_dim, shape=(3, 2, 5, 7, 2), broadcast_dimensions=(1, 3))) # CHECK-LABEL: TEST: iota - # CHECK: mhlo.iota + # CHECK: hlo.iota # CHECK-SAME: tensor<10xf32> print_ir()(partial(lax.iota, dtype=np.float32, size=10)) # CHECK-LABEL: TEST: pad int32[2,7] - # CHECK: mhlo.pad + # CHECK: hlo.pad # CHECK-SAME: tensor<11x52xi32> print_ir(np.empty([2, 7], np.int32))( partial(lax.pad, padding_value=np.int32(7), padding_config=((2, 3, 4), (4, 5, 6)))) # CHECK-LABEL: TEST: _reduce_sum int32[2,3,7] - # CHECK: mhlo.reduce - # CHECK: mhlo.add + # CHECK: hlo.reduce + # CHECK: hlo.add # CHECK: tensor<3xi32> print_ir(np.empty([2, 3, 7], np.int32))( partial(lax_internal._reduce_sum, axes=(0, 2))) # CHECK-LABEL: TEST: reshape int32[2,3,7] - # CHECK: mhlo.reshape + # CHECK: hlo.reshape # CHECK-SAME: tensor<42xi32> print_ir(np.empty([2, 3, 7], np.int32))( partial(lax.reshape, new_sizes=(42,))) # CHECK-LABEL: TEST: rev int32[2,7] - # CHECK: mhlo.rev + # CHECK: hlo.rev # CHECK-SAME: tensor<2x7xi32> print_ir(np.empty([2, 7], np.int32))( partial(lax.rev, dimensions=(0, 1))) # CHECK-LABEL: TEST: select bool[2,7] int32[2,7] int32[2,7] - # CHECK: mhlo.select + # CHECK: hlo.select # CHECK-SAME: tensor<2x7xi1>, tensor<2x7xi32> print_ir(np.empty([2, 7], np.bool_), np.empty([2, 7], np.int32), np.empty([2, 7], np.int32))(lax.select) # CHECK-LABEL: TEST: sort int32[2,7] - # CHECK: mhlo.sort + # CHECK: hlo.sort # CHECK: tensor<2x7xi32> print_ir(np.empty([2, 7], np.int32))(lax.sort) # CHECK-LABEL: TEST: squeeze int32[2,1,7] - # CHECK: mhlo.reshape + # CHECK: hlo.reshape # CHECK-SAME: tensor<2x7xi32> print_ir(np.empty([2, 1, 7], np.int32))( partial(lax.squeeze, dimensions=(1,))) @@ -98,7 +98,7 @@ def main(_): print_ir(np.empty([2, 7], np.int32))(partial(lax.top_k, k=7)) # CHECK-LABEL: TEST: transpose int32[2,7] - # CHECK: mhlo.transpose + # CHECK: hlo.transpose # CHECK-SAME: tensor<7x2xi32> print_ir(np.empty([2, 7], np.int32))( partial(lax.transpose, permutation=(1, 0))) diff --git a/tests/filecheck/jax_filecheck_helpers.py b/tests/filecheck/jax_filecheck_helpers.py index f5837d980eef..8a044e1e1dea 100644 --- a/tests/filecheck/jax_filecheck_helpers.py +++ b/tests/filecheck/jax_filecheck_helpers.py @@ -20,7 +20,7 @@ def print_ir(*prototypes): def lower(f): - """Prints the MHLO IR that results from lowering `f`. + """Prints the MLIR IR that results from lowering `f`. 
The arguments to `f` are taken to be arrays shaped like `prototypes`.""" inputs = tree_util.tree_map(np.array, prototypes) @@ -29,5 +29,5 @@ def lower(f): for x in flat_inputs]) name = f.func.__name__ if hasattr(f, "func") else f.__name__ print(f"\nTEST: {name} {shape_strs}") - print(jax.jit(f).lower(*inputs).compiler_ir(dialect="mhlo")) + print(jax.jit(f).lower(*inputs).compiler_ir()) return lower diff --git a/tests/filecheck/math.filecheck.py b/tests/filecheck/math.filecheck.py index 86778448d10a..f63507a04c68 100644 --- a/tests/filecheck/math.filecheck.py +++ b/tests/filecheck/math.filecheck.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Tests for lowerings of elementwise ops to MHLO. +# Tests for lowerings of elementwise ops to MLIR. # RUN: %PYTHON %s | FileCheck %s @@ -31,17 +31,17 @@ def main(_): # CHECK-LABEL: TEST: abs int32[] - # CHECK: mhlo.abs + # CHECK: hlo.abs # CHECK-SAME: tensor print_ir(np.int32(0))(lax.abs) # CHECK-LABEL: TEST: add float32[] float32[] - # CHECK: mhlo.add + # CHECK: hlo.add # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.add) # CHECK-LABEL: TEST: acos float32[] - # CHECK: mhlo.atan2 + # CHECK: hlo.atan2 # CHECK-SAME: tensor print_ir(np.float32(1))(lax.acos) @@ -71,7 +71,7 @@ def main(_): print_ir(np.float32(0))(lax.atanh) # CHECK-LABEL: TEST: atan2 float64[] float64[] - # CHECK: mhlo.atan2 + # CHECK: hlo.atan2 # CHECK-SAME: tensor print_ir(np.float64(1), np.float64(2))(lax.atan2) @@ -91,93 +91,93 @@ def main(_): print_ir(np.float32(0), np.float32(0), np.float32(0))(lax.betainc) # CHECK-LABEL: TEST: bitcast_convert_type uint32[7] - # CHECK: mhlo.bitcast_convert + # CHECK: hlo.bitcast_convert # CHECK-SAME: tensor<7xui32> # CHECK-SAME: tensor<7xf32> print_ir(np.empty((7,), np.uint32))( partial(lax.bitcast_convert_type, new_dtype=np.float32)) # CHECK-LABEL: TEST: bitwise_and int32[] int32[] - # CHECK: mhlo.and + # CHECK: hlo.and # CHECK-SAME: tensor print_ir(np.int32(1), np.int32(2))(lax.bitwise_and) # CHECK-LABEL: TEST: bitwise_and bool[] bool[] - # CHECK: mhlo.and + # CHECK: hlo.and # CHECK-SAME: tensor print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_and) # CHECK-LABEL: TEST: bitwise_or int32[] int32[] - # CHECK: mhlo.or + # CHECK: hlo.or # CHECK-SAME: tensor print_ir(np.int32(1), np.int32(2))(lax.bitwise_or) # CHECK-LABEL: TEST: bitwise_or bool[] bool[] - # CHECK: mhlo.or + # CHECK: hlo.or # CHECK-SAME: tensor print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_or) # CHECK-LABEL: TEST: bitwise_xor int32[] int32[] - # CHECK: mhlo.xor + # CHECK: hlo.xor # CHECK-SAME: tensor print_ir(np.int32(1), np.int32(2))(lax.bitwise_xor) # CHECK-LABEL: TEST: bitwise_xor bool[] bool[] - # CHECK: mhlo.xor + # CHECK: hlo.xor # CHECK-SAME: tensor print_ir(np.bool_(0), np.bool_(0))(lax.bitwise_xor) # CHECK-LABEL: TEST: cbrt bfloat16[] - # CHECK: mhlo.cbrt + # CHECK: hlo.cbrt # CHECK-SAME: tensor print_ir(jnp.bfloat16(0))(lax.cbrt) # CHECK-LABEL: TEST: clamp bfloat16[] bfloat16[] bfloat16[] - # CHECK: mhlo.clamp + # CHECK: hlo.clamp # CHECK-SAME: tensor print_ir(jnp.bfloat16(0), jnp.bfloat16(0), jnp.bfloat16(0))(lax.clamp) # CHECK-LABEL: TEST: ceil float16[7] - # CHECK: mhlo.ceil + # CHECK: hlo.ceil # CHECK-SAME: tensor<7xf16> print_ir(np.empty((7,), np.float16))(lax.ceil) # CHECK-LABEL: TEST: convert_element_type float16[7] - # CHECK: mhlo.convert + # CHECK: hlo.convert # CHECK-SAME: tensor<7xf16> # CHECK-SAME: tensor<7xf32> print_ir(np.empty((7,), np.float16))( 
partial(lax.convert_element_type, new_dtype=np.float32)) # CHECK-LABEL: TEST: convert_element_type complex64[7] - # CHECK: mhlo.real + # CHECK: hlo.real # CHECK-SAME: tensor<7xcomplex> # CHECK-SAME: tensor<7xf32> print_ir(np.empty((7,), np.complex64))( partial(lax.convert_element_type, new_dtype=np.float32)) # CHECK-LABEL: TEST: convert_element_type float32[7] - # CHECK: mhlo.compare + # CHECK: hlo.compare # CHECK-SAME: tensor<7xf32> # CHECK-SAME: tensor<7xi1> print_ir(np.empty((7,), np.float32))( partial(lax.convert_element_type, new_dtype=np.bool_)) # CHECK-LABEL: TEST: clz uint32[] - # CHECK: mhlo.count_leading_zeros + # CHECK: hlo.count_leading_zeros # CHECK-SAME: tensor print_ir(np.uint32(0))(lax.clz) # CHECK-LABEL: TEST: conj complex64[] - # CHECK-DAG: mhlo.real - # CHECK-DAG: mhlo.imag - # CHECK-DAG: mhlo.neg - # CHECK-DAG: mhlo.complex + # CHECK-DAG: hlo.real + # CHECK-DAG: hlo.imag + # CHECK-DAG: hlo.neg + # CHECK-DAG: hlo.complex # CHECK-SAME: tensor> print_ir(np.complex64(0))(lax.conj) # CHECK-LABEL: TEST: cos float32[] - # CHECK: mhlo.cos + # CHECK: hlo.cos # CHECK-SAME: tensor print_ir(np.float32(0))(lax.cos) @@ -192,30 +192,30 @@ def main(_): print_ir(np.float32(0))(lax.digamma) # CHECK-LABEL: TEST: div float32[] float32[] - # CHECK: mhlo.div + # CHECK: hlo.div # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.div) # CHECK-LABEL: TEST: eq float32[] float32[] - # CHECK: mhlo.compare EQ + # CHECK: hlo.compare EQ # CHECK-SAME: FLOAT # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.eq) # CHECK-LABEL: TEST: eq complex128[] complex128[] - # CHECK: mhlo.compare EQ + # CHECK: hlo.compare EQ # CHECK-SAME: FLOAT # CHECK-SAME: tensor> print_ir(np.complex128(1), np.complex128(2))(lax.eq) # CHECK-LABEL: TEST: eq int64[] int64[] - # CHECK: mhlo.compare EQ + # CHECK: hlo.compare EQ # CHECK-SAME: SIGNED # CHECK-SAME: tensor print_ir(np.int64(1), np.int64(2))(lax.eq) # CHECK-LABEL: TEST: eq uint16[] uint16[] - # CHECK: mhlo.compare EQ + # CHECK: hlo.compare EQ # CHECK-SAME: UNSIGNED # CHECK-SAME: tensor print_ir(np.uint16(1), np.uint16(2))(lax.eq) @@ -236,28 +236,28 @@ def main(_): print_ir(np.float32(0))(lax.erf_inv) # CHECK-LABEL: TEST: exp float16[] - # CHECK: mhlo.exp + # CHECK: hlo.exp # CHECK-SAME: tensor print_ir(np.float16(0))(lax.exp) # CHECK-LABEL: TEST: expm1 bfloat16[] - # CHECK: mhlo.exponential_minus_one + # CHECK: hlo.exponential_minus_one # CHECK-SAME: tensor print_ir(jnp.bfloat16(0))(lax.expm1) # CHECK-LABEL: TEST: floor bfloat16[2,3] - # CHECK: mhlo.floor + # CHECK: hlo.floor # CHECK-SAME: tensor<2x3xbf16> print_ir(np.empty((2, 3), jnp.bfloat16))(lax.floor) # CHECK-LABEL: TEST: ge float32[] float32[] - # CHECK: mhlo.compare GE + # CHECK: hlo.compare GE # CHECK-SAME: FLOAT # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.ge) # CHECK-LABEL: TEST: gt float32[] float32[] - # CHECK: mhlo.compare GT + # CHECK: hlo.compare GT # CHECK-SAME: FLOAT # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.gt) @@ -278,23 +278,23 @@ def main(_): print_ir(np.float32(0), np.float32(0))(lax.igamma_grad_a) # CHECK-LABEL: TEST: imag complex64[] - # CHECK: mhlo.imag + # CHECK: hlo.imag # CHECK-SAME: tensor> print_ir(np.complex64(0))(lax.imag) # CHECK-LABEL: TEST: integer_pow float32[] - # CHECK-DAG: mhlo.mul + # CHECK-DAG: hlo.mul # CHECK-SAME: tensor @print_ir(np.float32(1)) def integer_pow(x): return lax.integer_pow(x, 3) # CHECK-LABEL: TEST: is_finite float64[] - # CHECK: mhlo.is_finite + # CHECK: hlo.is_finite # CHECK-SAME: tensor 
print_ir(np.float64(0))(lax.is_finite) # CHECK-LABEL: TEST: le float32[] float32[] - # CHECK: mhlo.compare LE + # CHECK: hlo.compare LE # CHECK-SAME: FLOAT # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.le) @@ -305,44 +305,44 @@ def integer_pow(x): return lax.integer_pow(x, 3) print_ir(np.float32(0))(lax.lgamma) # CHECK-LABEL: TEST: log float32[] - # CHECK: mhlo.log + # CHECK: hlo.log # CHECK-SAME: tensor print_ir(np.float32(0))(lax.log) # CHECK-LABEL: TEST: log1p float32[] - # CHECK: mhlo.log_plus_one + # CHECK: hlo.log_plus_one # CHECK-SAME: tensor print_ir(np.float32(0))(lax.log1p) # CHECK-LABEL: TEST: lt float32[] float32[] - # CHECK: mhlo.compare LT + # CHECK: hlo.compare LT # CHECK-SAME: FLOAT # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.lt) # CHECK-LABEL: TEST: max float32[] float32[] - # CHECK: mhlo.max + # CHECK: hlo.max # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.max) # CHECK-LABEL: TEST: min float32[] float32[] - # CHECK: mhlo.min + # CHECK: hlo.min # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.min) # CHECK-LABEL: TEST: mul float32[] float32[] - # CHECK: mhlo.mul + # CHECK: hlo.mul # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.mul) # CHECK-LABEL: TEST: ne float32[] float32[] - # CHECK: mhlo.compare NE + # CHECK: hlo.compare NE # CHECK-SAME: FLOAT # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.ne) # CHECK-LABEL: TEST: neg int64[] - # CHECK: mhlo.negate + # CHECK: hlo.negate # CHECK-SAME: tensor print_ir(np.int64(0))(lax.neg) @@ -352,22 +352,22 @@ def integer_pow(x): return lax.integer_pow(x, 3) print_ir(np.float32(0), np.float32(0))(lax.nextafter) # CHECK-LABEL: TEST: bitwise_not int64[] - # CHECK: mhlo.not + # CHECK: hlo.not # CHECK-SAME: tensor print_ir(np.int64(0))(lax.bitwise_not) # CHECK-LABEL: TEST: bitwise_not bool[] - # CHECK: mhlo.not + # CHECK: hlo.not # CHECK-SAME: tensor print_ir(np.bool_(0))(lax.bitwise_not) # CHECK-LABEL: TEST: population_count uint32[] - # CHECK: mhlo.popcnt + # CHECK: hlo.popcnt # CHECK-SAME: tensor print_ir(np.uint32(0))(lax.population_count) # CHECK-LABEL: TEST: pow float32[] float32[] - # CHECK: mhlo.power + # CHECK: hlo.power # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.pow) @@ -377,59 +377,59 @@ def integer_pow(x): return lax.integer_pow(x, 3) print_ir(np.float32(0), np.float32(0))(lax.random_gamma_grad) # CHECK-LABEL: TEST: real complex128[] - # CHECK: mhlo.real + # CHECK: hlo.real # CHECK-SAME: tensor> print_ir(np.complex128(0))(lax.real) # CHECK-LABEL: TEST: reduce_precision bfloat16[] - # CHECK: mhlo.reduce_precision + # CHECK: hlo.reduce_precision # CHECK-SAME: tensor print_ir(jnp.bfloat16(0))( partial(lax.reduce_precision, exponent_bits=2, mantissa_bits=2)) # CHECK-LABEL: TEST: rem float32[] float32[] - # CHECK: mhlo.rem + # CHECK: hlo.rem # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.rem) # CHECK-LABEL: TEST: round float64[7,1] - # CHECK: mhlo.round + # CHECK: hlo.round # CHECK-SAME: tensor<7x1xf64> print_ir(np.empty((7,1), np.float64))( partial(lax.round, rounding_method=lax.RoundingMethod.AWAY_FROM_ZERO)) # CHECK-LABEL: TEST: rsqrt complex64[] - # CHECK: mhlo.rsqrt + # CHECK: hlo.rsqrt # CHECK-SAME: tensor> print_ir(jnp.complex64(0))(lax.rsqrt) # CHECK-LABEL: TEST: shift_left uint32[] uint32[] - # CHECK: mhlo.shift_left + # CHECK: hlo.shift_left # CHECK-SAME: tensor print_ir(np.uint32(0), np.uint32(0))(lax.shift_left) # CHECK-LABEL: TEST: shift_right_arithmetic uint8[] uint8[] - # CHECK: 
mhlo.shift_right_arithmetic + # CHECK: hlo.shift_right_arithmetic # CHECK-SAME: tensor print_ir(np.uint8(0), np.uint8(0))(lax.shift_right_arithmetic) # CHECK-LABEL: TEST: shift_right_logical uint16[] uint16[] - # CHECK: mhlo.shift_right_logical + # CHECK: hlo.shift_right_logical # CHECK-SAME: tensor print_ir(np.uint16(0), np.uint16(0))(lax.shift_right_logical) # CHECK-LABEL: TEST: sign int64[] - # CHECK: mhlo.sign + # CHECK: hlo.sign # CHECK-SAME: tensor print_ir(np.int64(0))(lax.sign) # CHECK-LABEL: TEST: sign uint32[] - # CHECK: mhlo.compare + # CHECK: hlo.compare # CHECK-SAME: tensor print_ir(np.uint32(0))(lax.sign) # CHECK-LABEL: TEST: sin float32[] - # CHECK: mhlo.sin + # CHECK: hlo.sin # CHECK-SAME: tensor print_ir(np.float32(0))(lax.sin) @@ -439,12 +439,12 @@ def integer_pow(x): return lax.integer_pow(x, 3) print_ir(np.float32(0))(lax.sinh) # CHECK-LABEL: TEST: sub float32[] float32[] - # CHECK: mhlo.sub + # CHECK: hlo.sub # CHECK-SAME: tensor print_ir(np.float32(1), np.float32(2))(lax.sub) # CHECK-LABEL: TEST: sqrt bfloat16[] - # CHECK: mhlo.sqrt + # CHECK: hlo.sqrt # CHECK-SAME: tensor print_ir(jnp.bfloat16(0))(lax.sqrt) @@ -454,7 +454,7 @@ def integer_pow(x): return lax.integer_pow(x, 3) print_ir(np.float16(0))(lax.tan) # CHECK-LABEL: TEST: tanh float32[] - # CHECK: mhlo.tanh + # CHECK: hlo.tanh # CHECK-SAME: tensor print_ir(np.float32(0))(lax.tanh) diff --git a/tests/filecheck/shapes.filecheck.py b/tests/filecheck/shapes.filecheck.py index be610e8e7c40..5cd8feb864e1 100644 --- a/tests/filecheck/shapes.filecheck.py +++ b/tests/filecheck/shapes.filecheck.py @@ -30,77 +30,77 @@ def main(_): # CHECK-LABEL: TEST: bitwise_not bool[7] - # CHECK: mhlo.not + # CHECK: hlo.not # CHECK-SAME: tensor<7xi1> print_ir(np.empty([7], np.bool_))(lax.bitwise_not) # CHECK-LABEL: TEST: neg int8[] - # CHECK: mhlo.negate + # CHECK: hlo.negate # CHECK-SAME: tensor print_ir(np.int8(0))(lax.neg) # CHECK-LABEL: TEST: neg int16[0] - # CHECK: mhlo.negate + # CHECK: hlo.negate # CHECK-SAME: tensor<0xi16> print_ir(np.empty([0], np.int16))(lax.neg) # CHECK-LABEL: TEST: neg int32[2,3] - # CHECK: mhlo.negate + # CHECK: hlo.negate # CHECK-SAME: tensor<2x3xi32> print_ir(np.empty([2, 3], np.int32))(lax.neg) # CHECK-LABEL: TEST: neg int64[2,3,4] - # CHECK: mhlo.negate + # CHECK: hlo.negate # CHECK-SAME: tensor<2x3x4xi64> print_ir(np.empty([2,3,4], np.int64))(lax.neg) # CHECK-LABEL: TEST: add uint8[4,0,1] uint8[4,0,1] - # CHECK: mhlo.add + # CHECK: hlo.add # CHECK-SAME: tensor<4x0x1xui8> print_ir(np.empty([4,0,1], np.uint8), np.empty([4,0,1], np.uint8))(lax.add) # CHECK-LABEL: TEST: add uint16[] uint16[] - # CHECK: mhlo.add + # CHECK: hlo.add # CHECK-SAME: tensor print_ir(np.uint16(0), np.uint16(0))(lax.add) # CHECK-LABEL: TEST: add uint32[] uint32[] - # CHECK: mhlo.add + # CHECK: hlo.add # CHECK-SAME: tensor print_ir(np.uint32(0), np.uint32(0))(lax.add) # CHECK-LABEL: TEST: add uint64[] uint64[] - # CHECK: mhlo.add + # CHECK: hlo.add # CHECK-SAME: tensor print_ir(np.uint64(0), np.uint64(0))(lax.add) # CHECK-LABEL: TEST: sin float16[] - # CHECK: mhlo.sine + # CHECK: hlo.sine # CHECK-SAME: tensor print_ir(np.float16(0))(lax.sin) # CHECK-LABEL: TEST: sin bfloat16[] - # CHECK: mhlo.sine + # CHECK: hlo.sine # CHECK-SAME: tensor print_ir(jnp.bfloat16(0))(lax.sin) # CHECK-LABEL: TEST: sin float32[] - # CHECK: mhlo.sine + # CHECK: hlo.sine # CHECK-SAME: tensor print_ir(np.float32(0))(lax.sin) # CHECK-LABEL: TEST: sin float64[] - # CHECK: mhlo.sine + # CHECK: hlo.sine # CHECK-SAME: tensor print_ir(np.float64(0))(lax.sin) # 
CHECK-LABEL: TEST: cos complex64[] - # CHECK: mhlo.cosine + # CHECK: hlo.cosine # CHECK-SAME: tensor> print_ir(np.complex64(0))(lax.cos) # CHECK-LABEL: TEST: cos complex128[] - # CHECK: mhlo.cosine + # CHECK: hlo.cosine # CHECK-SAME: tensor> print_ir(np.complex128(0))(lax.cos) diff --git a/tests/filecheck/subcomputations.filecheck.py b/tests/filecheck/subcomputations.filecheck.py index 2ddc8d371cdc..155c75f2482f 100644 --- a/tests/filecheck/subcomputations.filecheck.py +++ b/tests/filecheck/subcomputations.filecheck.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -# Tests for lowering of array origami ops into MHLO. +# Tests for lowering of array origami ops into MLIR. # RUN: %PYTHON %s | FileCheck %s @@ -65,7 +65,7 @@ def g(x): with m1.context: # Reparse m2 in m1's context. m2_copy = ir.Module.parse(m2) - mlir.merge_mhlo_modules(m1, "m2_main_renamed", m2_copy) + mlir.merge_mlir_modules(m1, "m2_main_renamed", m2_copy) print("\nTEST: merge_modules") print(str(m1)) diff --git a/tests/jaxpr_effects_test.py b/tests/jaxpr_effects_test.py index f7475dd161f8..23d4585a2d43 100644 --- a/tests/jaxpr_effects_test.py +++ b/tests/jaxpr_effects_test.py @@ -370,43 +370,43 @@ def effect_lowering(ctx, *, effect): def f(x): effect_p.bind(effect='foo') return x + 1. - mhlo = f.lower(2.).compiler_ir() - main = mhlo.body.operations[0] + module = f.lower(2.).compiler_ir() + main = module.body.operations[0] first_op = main.body.blocks[0].operations[0] - self.assertEqual(first_op.operation.name, "mhlo.create_token") + self.assertIn('hlo.create_token', first_op.operation.name) @jax.jit def f(x): effect_p.bind(effect='foo') effect_p.bind(effect='foo2') return x + 1. - mhlo = f.lower(2.).compiler_ir() - main = mhlo.body.operations[0] + module = f.lower(2.).compiler_ir() + main = module.body.operations[0] first_op = main.body.blocks[0].operations[0] - self.assertEqual(first_op.operation.name, "mhlo.create_token") + self.assertIn('hlo.create_token', first_op.operation.name) second_op = main.body.blocks[0].operations[1] - self.assertEqual(second_op.operation.name, "mhlo.create_token") + self.assertIn('hlo.create_token', second_op.operation.name) @jax.jit def f(x): effect_p.bind(effect='foo') return x + 1. - mhlo = f.lower(2.).compiler_ir() - main = mhlo.body.operations[0] + module = f.lower(2.).compiler_ir() + main = module.body.operations[0] first_op = main.body.blocks[0].operations[0] - self.assertEqual(first_op.operation.name, "mhlo.create_token") + self.assertIn('hlo.create_token', first_op.operation.name) @jax.jit def f(x): effect_p.bind(effect='foo') effect_p.bind(effect='foo2') return x + 1. - mhlo = f.lower(2.).compiler_ir() - main = mhlo.body.operations[0] + module = f.lower(2.).compiler_ir() + main = module.body.operations[0] first_op = main.body.blocks[0].operations[0] - self.assertEqual(first_op.operation.name, "mhlo.create_token") + self.assertIn('hlo.create_token', first_op.operation.name) second_op = main.body.blocks[0].operations[1] - self.assertEqual(second_op.operation.name, "mhlo.create_token") + self.assertIn('hlo.create_token', second_op.operation.name) def test_nontrivial_lowering_with_ordered_effect_should_consume_token(self): @@ -416,19 +416,18 @@ def test_nontrivial_lowering_with_ordered_effect_should_consume_token(self): def f(x): effect_p.bind(effect='foo') return x + 1. 
- - mhlo = f.lower(2.).compiler_ir() - main = mhlo.body.operations[0] + module = f.lower(2.).compiler_ir() + main = module.body.operations[0] first_op = main.body.blocks[0].operations[0] - self.assertEqual(first_op.operation.name, "mhlo.create_token") + self.assertIn('hlo.create_token', first_op.operation.name) second_op = main.body.blocks[0].operations[1] self.assertEqual(second_op.operation.name, "func.call") self.assertEqual(str(second_op.attributes["callee"]), "@effect") self.assertEqual(second_op.operands[0].owner, first_op) - func = mhlo.body.operations[1] + func = module.body.operations[1] self.assertEqual(func.name.value, "effect") - self.assertEqual(str(func.type.inputs[0]), "!mhlo.token") - self.assertEqual(str(func.type.results[0]), "!mhlo.token") + self.assertIn('hlo.token', str(func.type.inputs[0])) + self.assertIn('hlo.token', str(func.type.results[0])) def test_nontrivial_lowering_with_unordered_effect_should_consume_token(self): @@ -438,14 +437,13 @@ def test_nontrivial_lowering_with_unordered_effect_should_consume_token(self): def f(x): effect_p.bind(effect='bar') return x + 1. - - mhlo = f.lower(2.).compiler_ir() - main = mhlo.body.operations[0] + module = f.lower(2.).compiler_ir() + main = module.body.operations[0] first_op = main.body.blocks[0].operations[0] self.assertEqual(first_op.operation.name, "func.call") self.assertEqual(str(first_op.attributes["callee"]), "@effect") self.assertLen(list(first_op.operands), 0) - func = mhlo.body.operations[1] + func = module.body.operations[1] self.assertEqual(func.name.value, "effect") self.assertLen(list(func.type.inputs), 0) self.assertLen(list(func.type.results), 0) @@ -455,13 +453,13 @@ def test_lowered_jaxpr_without_ordered_effects_takes_no_dummy_inputs(self): def f(x): effect_p.bind(effect='bar') return x + 1. - mhlo = f.lower(1.).compiler_ir(dialect='mhlo') - input_types = mhlo.body.operations[0].type.inputs + module = f.lower(1.).compiler_ir() + input_types = module.body.operations[0].type.inputs self.assertLen(list(input_types), 1) self.assertEqual(str(input_types[0]), 'tensor') # First output should be output token - result_types = mhlo.body.operations[0].type.results + result_types = module.body.operations[0].type.results if not can_execute_with_token: self.assertLen(list(result_types), 2) self.assertEqual(str(result_types[0]), 'tensor<0xi1>') @@ -476,14 +474,14 @@ def test_lowered_jaxpr_with_ordered_effects_takes_in_dummy_inputs(self): def f(x): effect_p.bind(effect='foo') return x + 1. - mhlo = f.lower(1.).compiler_ir(dialect='mhlo') - input_types = mhlo.body.operations[0].type.inputs + module = f.lower(1.).compiler_ir() + input_types = module.body.operations[0].type.inputs # First argument should be dummy token self.assertLen(list(input_types), 2) self.assertEqual(str(input_types[0]), 'tensor<0xi1>') # First output should be dummy token - result_types = mhlo.body.operations[0].type.results + result_types = module.body.operations[0].type.results self.assertLen(list(result_types), 2) self.assertEqual(str(result_types[0]), 'tensor<0xi1>') @@ -493,15 +491,15 @@ def f(x): effect_p.bind(effect='foo') effect_p.bind(effect='foo2') return x + 1. 
- mhlo = f.lower(1.).compiler_ir(dialect='mhlo') - input_types = mhlo.body.operations[0].type.inputs + module = f.lower(1.).compiler_ir() + input_types = module.body.operations[0].type.inputs # First two arguments should be dummy values self.assertLen(list(input_types), 3) self.assertEqual(str(input_types[0]), 'tensor<0xi1>') self.assertEqual(str(input_types[1]), 'tensor<0xi1>') # First two outputs should be dummy values - result_types = mhlo.body.operations[0].type.results + result_types = module.body.operations[0].type.results self.assertLen(list(result_types), 3) self.assertEqual(str(result_types[0]), 'tensor<0xi1>') self.assertEqual(str(result_types[1]), 'tensor<0xi1>') diff --git a/tests/lax_test.py b/tests/lax_test.py index 5b62a508c2f1..e95aafd9c0e8 100644 --- a/tests/lax_test.py +++ b/tests/lax_test.py @@ -40,7 +40,7 @@ from jax.interpreters import batching from jax.interpreters import pxla from jax._src import array -from jax._src.lib.mlir.dialects import mhlo +from jax._src.lib.mlir.dialects import hlo from jax._src import dispatch from jax._src import dtypes from jax._src import test_util as jtu @@ -2867,17 +2867,17 @@ def slice_mlir(ctx, aval_out, x, start_indices, limit_indices, strides): start_indices = (*start_indices, 0) limit_indices = (*limit_indices, 2) strides = (*strides, 1) - return mhlo.SliceOp(x, - mlir.dense_int_elements(start_indices), - mlir.dense_int_elements(limit_indices), - mlir.dense_int_elements(strides)).result + return hlo.SliceOp(x, + mlir.dense_int_elements(start_indices), + mlir.dense_int_elements(limit_indices), + mlir.dense_int_elements(strides)).result @staticmethod def dynamic_slice_mlir(ctx, aval_out, x, start_indices): dtype = dtypes.canonicalize_dtype(np.dtype('int64')) start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype))) slice_sizes_ = mlir.dense_int_elements((*aval_out.shape, 2)) - return mhlo.DynamicSliceOp(x, start_indices, slice_sizes_).result + return hlo.DynamicSliceOp(x, start_indices, slice_sizes_).result @staticmethod def dynamic_update_slice_mlir(ctx, aval_out, x, update, *start_indices): @@ -2885,22 +2885,22 @@ def dynamic_update_slice_mlir(ctx, aval_out, x, update, *start_indices): dtype = dtypes.canonicalize_dtype(np.dtype('int64')) start_indices = (*start_indices, mlir.ir_constant(np.array(0, dtype=dtype))) if xc.mlir_api_version < 40: - return mhlo.DynamicUpdateSliceOp( + return hlo.DynamicUpdateSliceOp( mlir.aval_to_ir_type(aval_out), x, update, start_indices).result else: - return mhlo.DynamicUpdateSliceOp(x, update, start_indices).result + return hlo.DynamicUpdateSliceOp(x, update, start_indices).result @staticmethod def broadcast_in_dim_mlir(ctx, aval_out, x, broadcast_dimensions): broadcast_dimensions = [*broadcast_dimensions, aval_out.ndim] - return mhlo.BroadcastInDimOp( + return hlo.BroadcastInDimOp( mlir.aval_to_ir_type(aval_out), x, mlir.dense_int_elements(broadcast_dimensions)).result @staticmethod def transpose_mlir(ctx, aval_out, x, *, permutation): perm = [*permutation, len(permutation)] - return mhlo.TransposeOp(x, mlir.dense_int_elements(perm)).result + return hlo.TransposeOp(x, mlir.dense_int_elements(perm)).result @staticmethod def gather_mlir(ctx, avals_in, aval_out, x, indices, *, diff --git a/tests/pjit_test.py b/tests/pjit_test.py index b9363693f82a..8da789eac40c 100644 --- a/tests/pjit_test.py +++ b/tests/pjit_test.py @@ -501,9 +501,9 @@ def f(x): self.assertAllClose(actual, expected, check_dtypes=False) self.assertLen(actual[0]['a'].device_buffers, 4) - mhlo_str = 
str(f.lower(x).compiler_ir(dialect="mhlo")) - self.assertIn("unspecified_dims=[0]", mhlo_str) - self.assertIn("unspecified_dims=[1]", mhlo_str) + mlir_str = str(f.lower(x).compiler_ir()) + self.assertIn("unspecified_dims=[0]", mlir_str) + self.assertIn("unspecified_dims=[1]", mlir_str) @jtu.with_mesh([('x', 2), ('y', 2)]) def testShardingConstraintPyTreeVmapWithUnconstrainedDims(self): @@ -521,9 +521,9 @@ def f(x): v = np.arange(prod(shape)).reshape(shape) x = [{'a': v, 'b': v * 2}, v * 3] - mhlo_str = str(f.lower(x).compiler_ir(dialect="mhlo")) - self.assertIn("unspecified_dims=[0,1]", mhlo_str) - self.assertIn("unspecified_dims=[0,2]", mhlo_str) + mlir_str = str(f.lower(x).compiler_ir()) + self.assertIn("unspecified_dims=[0,1]", mlir_str) + self.assertIn("unspecified_dims=[0,2]", mlir_str) def testCaching(self): def f(x):
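# The dialect-agnostic inspection pattern the updated tests above rely on:
# lower a jitted function and examine the textual IR without passing a
# dialect name. A minimal sketch; the toy function is illustrative only.
import jax
import numpy as np

@jax.jit
def f(x):
  return x.sum()

module = f.lower(np.arange(3, dtype='int32')).compiler_ir()  # MLIR module
assert 'tensor' in str(module)      # string checks stay dialect-neutral
main = module.body.operations[0]    # the lowered entry function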