(NFC) Prepare for migration from producing MHLO to producing StableHLO
This CL renames occurrences of "mhlo" in: 1) names, 2) tests, 3) prose in order
to prepare for the upcoming migration.

Unchanged occurrences:
  1) Public API that contains "mhlo", e.g. XlaLowering.mhlo and the "mhlo"
     argument value in Lowering.as_text and Lowering.compiler_ir.
  2) Documentation (changelog, JEPs, IR examples, etc).
  3) One rare situation where prose says "StableHLO" and "MHLO" in one sentence,
     so both are necessary to disambiguate.
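
For example, the public API in item 1 keeps accepting "mhlo" after this CL. A minimal sketch (the function and input here are hypothetical, for illustration only):

import jax
import jax.numpy as jnp

f = jax.jit(lambda x: x + 1)        # hypothetical example function
lowered = f.lower(jnp.float32(0.))  # a jax.stages.Lowered object
print(lowered.as_text(dialect="mhlo"))    # "mhlo" is still a valid argument
ir = lowered.compiler_ir(dialect="mhlo")  # likewise here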

PiperOrigin-RevId: 495771153
Eugene Burmako authored and jax authors committed Dec 16, 2022
1 parent 523c6f7 commit b8ae8e3
Showing 49 changed files with 991 additions and 882 deletions.
build/build_wheel.py (1 addition & 1 deletion)

@@ -169,7 +169,7 @@ def prepare_wheel(sources_path):
   copy_to_jaxlib("__main__/jaxlib/init.py", dst_filename="__init__.py")
   copy_to_jaxlib(f"__main__/jaxlib/cpu_feature_guard.{pyext}")
   copy_to_jaxlib("__main__/jaxlib/lapack.py")
-  copy_to_jaxlib("__main__/jaxlib/mhlo_helpers.py")
+  copy_to_jaxlib("__main__/jaxlib/hlo_helpers.py")
   copy_to_jaxlib("__main__/jaxlib/ducc_fft.py")
   copy_to_jaxlib("__main__/jaxlib/gpu_prng.py")
   copy_to_jaxlib("__main__/jaxlib/gpu_linalg.py")
jax/_src/ad_checkpoint.py (3 additions & 3 deletions)

@@ -34,7 +34,7 @@
 from jax._src.lax import lax as lax_internal
 from jax._src.lax import convolution as lax_convolution
 from jax._src.lib import xla_client as xc
-from jax._src.lib.mlir.dialects import mhlo
+from jax._src.lib.mlir.dialects import hlo
 from jax._src.traceback_util import api_boundary
 from jax._src.util import (unzip2, wraps, split_list, partition_list, safe_map,
                            safe_zip, merge_lists, weakref_lru_cache)
@@ -623,9 +623,9 @@ def _optimization_barrier_lowering_rule(ctx, *args):
   flat_barrier_types = util.flatten(barrier_types)
   flat_args = mlir.flatten_lowering_ir_args(args)
   if xc.mlir_api_version < 40:
-    barrier_op = mhlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
+    barrier_op = hlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
   else:
-    barrier_op = mhlo.OptimizationBarrierOp(flat_args)
+    barrier_op = hlo.OptimizationBarrierOp(flat_args)
   return util.unflatten(barrier_op.results, map(len, barrier_types))
 
 def _optimization_barrier(arg):
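
The mlir_api_version gate in the hunk above keeps one lowering rule working across jaxlib releases whose MLIR bindings differ. A condensed sketch of the pattern, using the same names as the diff (the helper function is hypothetical, and constructing the op still requires an active MLIR context):

from jax._src.lib import xla_client as xc
from jax._src.lib.mlir.dialects import hlo

def _make_barrier(flat_barrier_types, flat_args):  # hypothetical helper
  if xc.mlir_api_version < 40:
    # Older bindings take the result types explicitly.
    return hlo.OptimizationBarrierOp(flat_barrier_types, flat_args)
  # Newer bindings infer the result types from the operands.
  return hlo.OptimizationBarrierOp(flat_args)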
jax/_src/custom_derivatives.py (2 additions & 2 deletions)

@@ -1205,8 +1205,8 @@ def unreachable_impl(*_, out_avals, exc_type, message):
 
 # Translation raises an exception
 # TODO(frostig,mattjj): We have no good way to translate a function
-# that errs. Since MHLO lowering over-approximates concrete evaluation,
-# we err on MHLO lowering for the time being.
+# that errs. Since MLIR lowering over-approximates concrete evaluation,
+# we err on MLIR lowering for the time being.
 mlir.register_lowering(unreachable_p, unreachable_impl)
 
 # Abstract evaluation proceeds without issue, to allow for staging
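
For context on the register_lowering call above, a minimal sketch of how a primitive gets an MLIR lowering rule (the no-op primitive here is hypothetical):

from jax import core
from jax.interpreters import mlir

noop_p = core.Primitive("noop")  # hypothetical primitive
noop_p.def_impl(lambda x: x)
noop_p.def_abstract_eval(lambda x: x)

def _noop_lowering(ctx, x):
  # A lowering rule receives a context plus the operands as MLIR values
  # and returns the MLIR result values; here it just forwards its input.
  return [x]

mlir.register_lowering(noop_p, _noop_lowering)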
jax/_src/debugging.py (10 additions & 10 deletions)

@@ -38,7 +38,7 @@
 from jax._src.lax import control_flow as lcf
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
-from jax._src.lib.mlir.dialects import mhlo
+from jax._src.lib.mlir.dialects import hlo
 import jax.numpy as jnp
 
 import numpy as np
@@ -335,15 +335,15 @@ def _hlo_sharding_callback(hlo_sharding):
   # partitioner runs so we keep it alive by attaching it to the executable.
   ctx.module_context.add_keepalive(sharding_callback_info)
 
-  mhlo.CustomCallOp([value.type], [value],
-                    call_target_name=ir.StringAttr.get(
-                        _INSPECT_SHARDING_CALL_NAME),
-                    has_side_effect=ir.BoolAttr.get(True),
-                    api_version=mlir.i32_attr(1),
-                    called_computations=ir.ArrayAttr.get([]),
-                    backend_config=ir.StringAttr.get(key),
-                    operand_layouts=None,
-                    result_layouts=None)
+  hlo.CustomCallOp([value.type], [value],
+                   call_target_name=ir.StringAttr.get(
+                       _INSPECT_SHARDING_CALL_NAME),
+                   has_side_effect=ir.BoolAttr.get(True),
+                   api_version=mlir.i32_attr(1),
+                   called_computations=ir.ArrayAttr.get([]),
+                   backend_config=ir.StringAttr.get(key),
+                   operand_layouts=None,
+                   result_layouts=None)
   return []
 mlir.register_lowering(inspect_sharding_p, _inspect_sharding_lowering_rule)
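
A usage sketch of the public entry point whose lowering emits the custom call above, assuming jax.debug.inspect_array_sharding (the function and input here are hypothetical):

import jax
import jax.numpy as jnp

def f(x):
  # The callback receives the sharding the partitioner chose for x.
  jax.debug.inspect_array_sharding(x, callback=print)
  return x * 2

jax.jit(f)(jnp.arange(8.))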
jax/_src/dispatch.py (1 addition & 1 deletion)

@@ -67,7 +67,7 @@
 
 flags.DEFINE_string(
     'jax_dump_ir_to', os.getenv('JAX_DUMP_IR_TO', ''),
-    help="Path to which HLO/MHLO IR that is emitted by JAX as input to the "
+    help="Path to which the IR that is emitted by JAX as input to the "
          "compiler should be dumped as text files. Optional. If omitted, JAX "
          "will not dump IR.")
 
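
A usage sketch for this flag: the default is read from the environment at import time, so set JAX_DUMP_IR_TO before importing jax (the path below is hypothetical):

import os
os.environ["JAX_DUMP_IR_TO"] = "/tmp/jax_ir_dumps"  # hypothetical path

import jax  # IR emitted to the compiler is now dumped as text files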
jax/_src/lax/control_flow/conditionals.py (5 additions & 5 deletions)

@@ -41,7 +41,7 @@
 from jax._src.util import (safe_map, extend_name_stack, split_list,
                            partition_list)
 from jax._src.lib.mlir import ir
-from jax._src.lib.mlir.dialects import mhlo
+from jax._src.lib.mlir.dialects import hlo
 import numpy as np
 
 from jax._src.lax.control_flow.common import (
@@ -806,11 +806,11 @@ def _cond_lowering(ctx, index, *args, branches, linear):
       *output_token_types, *map(mlir.aval_to_ir_types, ctx.avals_out)]
   flat_output_types = util.flatten(output_types)
 
-  # mhlo.CaseOp takes a single argument 'index' and the corresponding blocks
+  # CaseOp takes a single argument 'index' and the corresponding blocks
   # have no arguments; the computation within the block uses implicit
   # captures.
-  case_op = mhlo.CaseOp(flat_output_types, index=index,
-                        num_branches=len(branches))
+  case_op = hlo.CaseOp(flat_output_types, index=index,
+                       num_branches=len(branches))
   name_stack = extend_name_stack(ctx.module_context.name_stack, 'cond')
   for i, jaxpr in enumerate(branches):
     branch = case_op.regions[i].blocks.append()
@@ -824,7 +824,7 @@ def _cond_lowering(ctx, index, *args, branches, linear):
         dim_var_values=ctx.dim_var_values)
     out_tokens = [tokens_out.get(eff) for eff in ordered_effects]
     out_vals = [*out_tokens, *out_vals]
-    mhlo.ReturnOp(util.flatten(out_vals))
+    hlo.ReturnOp(util.flatten(out_vals))
 
   tokens_and_outputs = util.unflatten(case_op.results, map(len, output_types))
   tokens, outputs = util.split_list(tokens_and_outputs, [num_tokens])
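
A minimal usage sketch of lax.switch, one public entry point that reaches this CaseOp lowering (branches and operand are hypothetical):

import jax
from jax import lax

def f(i, x):
  return lax.switch(i, [lambda x: x + 1., lambda x: x - 1.], x)

print(jax.jit(f)(0, 2.))  # 3.0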
jax/_src/lax/control_flow/loops.py (14 additions & 14 deletions)

@@ -43,7 +43,7 @@
 from jax._src.lax import windowed_reductions
 from jax._src.lib import xla_client as xc
 from jax._src.lib.mlir import ir
-from jax._src.lib.mlir.dialects import mhlo
+from jax._src.lib.mlir.dialects import hlo
 from jax._src.numpy.ufuncs import logaddexp
 from jax._src.traceback_util import api_boundary
 from jax._src.util import (
@@ -146,7 +146,7 @@ def scan(f, init, xs, length=None):
   output arrays. (None is actually an empty pytree.)
 
   Also unlike that Python version, :func:`~scan` is a JAX primitive and is
-  lowered to a single XLA While HLO. That makes it useful for reducing
+  lowered to a single WhileOp. That makes it useful for reducing
   compilation times for JIT-compiled functions, since native Python
   loop constructs in an :func:`~jax.jit` function are unrolled, leading to large
   XLA computations.
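
A minimal usage sketch of scan (a cumulative sum), which lowers to the single WhileOp mentioned in the docstring above:

import jax
import jax.numpy as jnp

def cumsum(xs):
  def step(carry, x):
    carry = carry + x
    return carry, carry        # (new carry, per-step output)
  _, ys = jax.lax.scan(step, 0., xs)
  return ys

print(cumsum(jnp.arange(4.)))  # [0. 1. 3. 6.]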
@@ -1041,7 +1041,7 @@ def while_loop(cond_fun, body_fun, init_val):
     return val
 
   Unlike that Python version, ``while_loop`` is a JAX primitive and is lowered
-  to a single XLA While HLO. That makes it useful for reducing compilation times
+  to a single WhileOp. That makes it useful for reducing compilation times
   for jit-compiled functions, since native Python loop constructs in an ``@jit``
   function are unrolled, leading to large XLA computations.
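
A minimal usage sketch of while_loop, likewise lowered to a single WhileOp:

import jax
from jax import lax

def count_to(n):
  return lax.while_loop(lambda i: i < n, lambda i: i + 1, 0)

print(jax.jit(count_to)(10))  # 10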
@@ -1420,7 +1420,7 @@ def _while_transpose_error(*_, **kwargs):
 #     break
 #   token, x = body(token, x)
 # ```
-# Unfortunately, with an MHLO while we can't (1) return multiple values
+# Unfortunately, with a WhileOp we can't (1) return multiple values
 # from a `cond` and (2) can't break a while loop. We thus adopt the
 # following rewrite strategy:
 # ```
@@ -1471,7 +1471,7 @@ def fun(*args):
   args = [*tokens, *args]
 
   flat_args = mlir.flatten_lowering_ir_args(args)
-  while_op = mhlo.WhileOp(flat_loop_carry_types, flat_args)
+  while_op = hlo.WhileOp(flat_loop_carry_types, flat_args)
 
   # Loop condition
   cond_block = while_op.regions[0].blocks.append(*flat_loop_carry_types)
@@ -1498,12 +1498,12 @@ def fun(*args):
         tokens_in=mlir.TokenSet(),
         tokens_out=None)
     pred, = lax._unary_reduce_lower(
-        mhlo.OrOp,
+        hlo.OrOp,
         lambda dtype: np.array(False, dtype),
         pred_ctx,
         pred,
         axes=tuple(range(len(pred_aval.shape))))
-    mhlo.ReturnOp([pred])
+    hlo.ReturnOp([pred])
 
   # Loop body
   body_block = while_op.regions[1].blocks.append(*flat_loop_carry_types)
@@ -1531,11 +1531,11 @@ def fun(*args):
         _map(mlir.ir_constants, cond_jaxpr.consts),
         *(x + z), dim_var_values=ctx.dim_var_values)
     new_z = _map(
-        partial(_pred_bcast_select_mhlo, ctx, pred_aval, body_pred), new_z, z,
+        partial(_pred_bcast_select_hlo, ctx, pred_aval, body_pred), new_z, z,
         body_jaxpr.out_avals)
 
-    mhlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x),
-                   *util.flatten(y), *util.flatten(new_z)])
+    hlo.ReturnOp([*util.flatten(out_tokens), *util.flatten(x),
+                  *util.flatten(y), *util.flatten(new_z)])
 
   outputs = util.unflatten(while_op.results, _map(len, loop_carry_types))
   tokens, _, _, z = util.split_list(outputs, [num_tokens, cond_nconsts, body_nconsts])
@@ -1566,16 +1566,16 @@ def _while_typecheck(*in_atoms, cond_jaxpr, body_jaxpr, cond_nconsts,
 core.custom_typechecks[while_p] = _while_typecheck
 
 
-def _pred_bcast_select_mhlo(ctx,
+def _pred_bcast_select_hlo(ctx,
     pred_aval: core.ShapedArray, pred: ir.Value, xs: Sequence[ir.Value],
     ys: Sequence[ir.Value], x_y_aval: core.AbstractValue) -> Sequence[ir.Value]:
   if x_y_aval is core.abstract_token:
     x, = xs
     y, = ys
     if xc.mlir_api_version < 40:
-      return [mhlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
+      return [hlo.AfterAllOp(mlir.aval_to_ir_type(x_y_aval), [x, y]).result]
     else:
-      return [mhlo.AfterAllOp([x, y]).result]
+      return [hlo.AfterAllOp([x, y]).result]
   else:
     assert isinstance(x_y_aval, core.ShapedArray), x_y_aval
     x, = xs
@@ -1589,7 +1589,7 @@ def _pred_bcast_select_mhlo(ctx,
     x_y_shape = x_y_aval.shape
     bcast_pred = mlir.broadcast_in_dim(ctx, pred, core.DShapedArray(x_y_shape, np.dtype(np.bool_)),
                                        broadcast_dimensions=list(range(len(pred_aval.shape))))
-    return mhlo.SelectOp(bcast_pred, x, y).results
+    return hlo.SelectOp(bcast_pred, x, y).results
 
 ### fori_loop
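
A rough array-level analogue of what _pred_bcast_select_hlo does: broadcast the loop predicate to the operand shape, then select elementwise (shapes here are hypothetical):

import jax.numpy as jnp

pred = jnp.array(True)  # scalar loop predicate
x = jnp.ones((2, 3))    # candidate value from the loop body
y = jnp.zeros((2, 3))   # value carried through unchanged
out = jnp.where(jnp.broadcast_to(pred, x.shape), x, y)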
jax/_src/lax/convolution.py (3 additions & 3 deletions)

@@ -26,7 +26,7 @@
 from jax.interpreters import batching
 from jax.interpreters import mlir
 from jax._src import util
-from jax._src.lib.mlir.dialects import mhlo
+from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib import xla_client
 
 _max = builtins.max
@@ -688,7 +688,7 @@ def _conv_general_dilated_lower(
     return complex_conv(ctx, lhs, rhs)
 
   lhs_spec, rhs_spec, out_spec = dimension_numbers
-  dnums = mhlo.ConvDimensionNumbers.get(
+  dnums = hlo.ConvDimensionNumbers.get(
       input_batch_dimension=lhs_spec[0],
       input_feature_dimension=lhs_spec[1],
      input_spatial_dimensions=list(lhs_spec[2:]),
@@ -703,7 +703,7 @@ def _conv_general_dilated_lower(
     padding = np.zeros((0, 2), dtype=np.int64)
   window_reversal = mlir.dense_bool_elements([False] * num_spatial_dims)
   return [
-      mhlo.ConvolutionOp(
+      hlo.ConvolutionOp(
           mlir.aval_to_ir_type(aval_out),
          lhs,
          rhs,
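
A usage sketch of the public entry point whose lowering builds the ConvDimensionNumbers above (shapes and dimension layout are hypothetical):

import jax
import jax.numpy as jnp

lhs = jnp.ones((1, 3, 8, 8))  # NCHW input
rhs = jnp.ones((4, 3, 3, 3))  # OIHW kernel
out = jax.lax.conv_general_dilated(
    lhs, rhs, window_strides=(1, 1), padding="SAME",
    dimension_numbers=("NCHW", "OIHW", "NCHW"))
print(out.shape)  # (1, 4, 8, 8)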
jax/_src/lax/fft.py (8 additions & 4 deletions)

@@ -26,7 +26,7 @@
 from jax import lax
 from jax.interpreters import ad
 from jax.interpreters import batching
-from jax._src.lib.mlir.dialects import mhlo
+from jax._src.lib.mlir.dialects import hlo
 from jax._src.lib import xla_client
 from jax._src.lib import ducc_fft
 from jax._src.numpy.util import _promote_dtypes_complex, _promote_dtypes_inexact
@@ -104,16 +104,20 @@ def fft_abstract_eval(x, fft_type, fft_lengths):
 
 def _fft_lowering(ctx, x, *, fft_type, fft_lengths):
   return [
-      mhlo.FftOp(x, mhlo.FftTypeAttr.get(fft_type.name),
-                 mlir.dense_int_elements(fft_lengths)).result
+      hlo.FftOp(x, hlo.FftTypeAttr.get(fft_type.name),
+                mlir.dense_int_elements(fft_lengths)).result
   ]
 
 
 def _fft_lowering_cpu(ctx, x, *, fft_type, fft_lengths):
   if any(not is_constant_shape(a.shape) for a in (ctx.avals_in + ctx.avals_out)):
     raise NotImplementedError("Shape polymorphism for custom call is not implemented (fft); b/261671778")
   x_aval, = ctx.avals_in
-  return [ducc_fft.ducc_fft_mhlo(x, x_aval.dtype, fft_type=fft_type,
+  if xla_client.mlir_api_version < 41:
+    return [ducc_fft.ducc_fft_mhlo(x, x_aval.dtype, fft_type=fft_type,
+                                   fft_lengths=fft_lengths)]
+  else:
+    return [ducc_fft.ducc_fft_hlo(x, x_aval.dtype, fft_type=fft_type,
                                  fft_lengths=fft_lengths)]
 
 def _naive_rfft(x, fft_lengths):
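
A usage sketch of the public API that reaches the FFT lowerings above; on CPU this goes through the ducc_fft custom call (the input values are hypothetical):

import jax
import jax.numpy as jnp

x = jnp.arange(8.)
print(jax.jit(jnp.fft.fft)(x))  # complex FFT of length 8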
[Diff truncated: the remaining 40 of the 49 changed files are not shown here.]
