Skip to content

Commit

Permalink
mhlo: migrate conversion to stablehlo (#1840)
Browse files Browse the repository at this point in the history
This patch replaces all MHLO operations with their StableHLO
counterparts and adds a validation pass to ensure that no MHLO operations
remain before translating all Stablehlo operations to the MHLO dialect
for further lowering to the Linalg dialect.

This patch also updates all lit tests so that they refer to the
`convert-torch-to-stablehlo` pass and so that they check for StableHLO
operations.
  • Loading branch information
ashay authored Feb 2, 2023
1 parent ed9d8d1 commit 711646d
Show file tree
Hide file tree
Showing 55 changed files with 1,190 additions and 1,136 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/buildAndTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ jobs:
-DLLVM_USE_HOST_TOOLS=ON \
-DLLVM_ENABLE_ZSTD=OFF \
-DMLIR_ENABLE_BINDINGS_PYTHON=ON \
-DTORCH_MLIR_ENABLE_MHLO=OFF \
-DTORCH_MLIR_ENABLE_STABLEHLO=OFF \
-DTORCH_MLIR_ENABLE_LTC=OFF \
-DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \
-DMACOSX_DEPLOYMENT_TARGET=12.0 \
Expand Down
10 changes: 5 additions & 5 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ macro(torch_mlir_add_llvm_external_project name identifier location)
set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE)
endmacro()

option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON)
if(TORCH_MLIR_ENABLE_MHLO)
add_definitions(-DTORCH_MLIR_ENABLE_MHLO)
option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON)
if(TORCH_MLIR_ENABLE_STABLEHLO)
add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO)
endif()

option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF)
Expand Down Expand Up @@ -128,8 +128,8 @@ else()
set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}")
endif()

if (TORCH_MLIR_ENABLE_MHLO)
set(MHLO_BUILD_EMBEDDED ON)
if (TORCH_MLIR_ENABLE_STABLEHLO)
set(STABLEHLO_BUILD_EMBEDDED ON)
add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo
${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo
EXCLUDE_FROM_ALL)
Expand Down
4 changes: 2 additions & 2 deletions build_tools/python_deploy/build_linux_packages.sh
Original file line number Diff line number Diff line change
Expand Up @@ -267,8 +267,8 @@ function test_in_tree() {
echo ":::: Run Linalg e2e integration tests"
python -m e2e_testing.main --config=linalg -v

echo ":::: Run MHLO e2e integration tests"
python -m e2e_testing.main --config=mhlo -v
echo ":::: Run StableHLO e2e integration tests"
python -m e2e_testing.main --config=stablehlo -v

echo ":::: Run TOSA e2e integration tests"
python -m e2e_testing.main --config=tosa -v
Expand Down
27 changes: 14 additions & 13 deletions docs/architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,14 @@ it to various target dialects of interest to the MLIR ecosystem (various

- Linalg-on-Tensors (+ `arith`, `tensor`, etc.)
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
- [MHLO](https://github.com/tensorflow/mlir-hlo)
- [StableHLO](https://github.com/openxla/stablehlo)

The terms "frontend" and "backend" are highly overloaded in any compiler
project, but frequently in Torch-MLIR this is the meaning that they have.
Sometimes "frontend" can mean something even further up the stack, such as
something in PyTorch itself. When there is ambiguity we will refer to this as
"at the PyTorch level". Similarly, "backend" can sometimes refer to something
sitting below Linalg-on-Tensors, TOSA, or MHLO.
sitting below Linalg-on-Tensors, TOSA, or StableHLO.

## The `torch` dialect

Expand Down Expand Up @@ -118,8 +118,8 @@ See [satisfiesBackendContract](https://github.com/llvm/torch-mlir/blob/114f48e96

The backend contract is a normalized form of the `torch` dialect with a set of
properties that make it easy to lower into various forms such as
Linalg-on-Tensors, TOSA, MHLO, or other forms that we don't provide out of the
box. The primary guarantees that we provide Torch-MLIR's backends are:
Linalg-on-Tensors, TOSA, StableHLO, or other forms that we don't provide out of
the box. The primary guarantees that we provide Torch-MLIR's backends are:

- All tensors have been converted to value semantics.
- All tensors have at least a known number of dimensions (i.e. rank), and
Expand Down Expand Up @@ -270,7 +270,7 @@ lower it to the requirements of each backend. The 3 backends are:
- [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+ `arith`,
`tensor`, etc.)
- [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/)
- [MHLO](https://github.com/tensorflow/mlir-hlo)
- [StableHLO](https://github.com/openxla/stablehlo)

### The Linalg Backend (Linalg-on-Tensors)

Expand All @@ -297,15 +297,15 @@ many users (especially "hardware" or "hardware-adjacent" folks). Some of its cha
- It is extremely solid with static shapes (and many of its users only care
about static shapes, so that's fine).

### The MHLO Backend
### The StableHLO Backend

Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToMhlo
Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToStablehlo

The MHLO backend was the third backend that we added, and it offers a reasonable
blend of the benefits of the other two.
The StableHLO backend was the third backend that we added, and it offers a
reasonable blend of the benefits of the other two.
- It is a coarse-grained named-op approach.
- It has a pretty clear spec for most of the ops (with a bit of mental
translation and hoping that MHLO is the same as HLO):
translation and hoping that StableHLO is the same as HLO):
https://www.tensorflow.org/xla/operation_semantics
- It functionally supports dynamic shapes (though not as coherent and consistent
as Linalg-on-Tensors, and the dynamic shape support falls outside the
Expand All @@ -317,7 +317,7 @@ blend of the benefits of the other two.
example, TOSA limits (for highly considered reasons) the number of dimensions
that certain operators can handle to 1D-4D, when from a purely algebraic
perspective there isn't a good reason to not be more general. Similarly, more
general forms of reduction and scatter also fall into MHLO nicely while
general forms of reduction and scatter also fall into StableHLO nicely while
TOSA's principles tend to bias it away from that.

### Backend Implementation
Expand Down Expand Up @@ -433,8 +433,9 @@ filling in some corners missing upstream and
to pull together upstream functionality into a working system.

The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes the
ops and lowers them to loops. Note that TOSA and MHLO support lowering to
Linalg-on-Tensors, so all our end-to-end testing bottoms out on RefBackend.
ops and lowers them to loops. Note that TOSA and StableHLO (via MHLO) support
lowering to Linalg-on-Tensors, so all our end-to-end testing bottoms out on
RefBackend.

The RefBackend is absolutely not suitable for any production use case. It leaks
memory, doesn't support any error handling, performs no optimizations, and
Expand Down
2 changes: 1 addition & 1 deletion docs/code_owners.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ and Clang's
- Eric Kunze (@eric-k256)
- Suraj Sudhir (@sjarus)

### TorchToMHLO
### TorchToStablehlo

- Tianyo Kwok (@tanyokwok)
- Ziheng Jiang (@ZihengJiang)
Expand Down
2 changes: 1 addition & 1 deletion docs/development.md
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ Ex:
module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch")
```

Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `MHLO`.
Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`.

## Jupyter

Expand Down
14 changes: 6 additions & 8 deletions docs/long_term_roadmap.md
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ the ecosystem are:

- The frontend work required to lower TorchScript to the backend contract.
- The irregular support surface area of the large number of PyTorch ops across
the Linalg, TOSA, and MHLO backends.
the Linalg, TOSA, and StableHLO backends.

Most of this document describes long-term ecosystem changes that will address
these, drastically improving Torch-MLIR's ability to meet its goals.
Expand Down Expand Up @@ -108,7 +108,7 @@ more advanced).
### Refactoring the backend

Today in Torch-MLIR, we support 3 backends out of the box: Linalg-on-Tensors,
TOSA, and MHLO. These backends take IR in the backend contract form (see
TOSA, and StableHLO. These backends take IR in the backend contract form (see
[architecture.md](architecture.md)) and lowers them to the respective dialects.
Today, each backend is implemented completely independently. This leads to
duplication and irregularity across the backends.
Expand All @@ -120,12 +120,10 @@ lowering of so many ops across backends. Additionally, there are 3
forward-looking efforts that intersect with this effort:

- [StableHLO](https://github.com/openxla/stablehlo) - this is a dialect
initially forked from MHLO which intends to create a stable support surface
area for what today is our "at head" dependency on MHLO. MHLO is a fairly
complete op set, so it is very attractive to have "almost all" models
bottleneck through a stable interface like StableHLO. StableHLO is currently
under relatively early development, but already delivers on many of the goals
of stability.
initially forked from MHLO. MHLO is a fairly complete op set, so it is very
attractive to have "almost all" models bottleneck through a stable interface
like StableHLO. StableHLO is currently under relatively early development,
but already delivers on many of the goals of stability.
- [TCP](https://github.com/llvm/torch-mlir/issues/1366) - this is a dialect
which could serve a role very similar to MHLO, while providing community
ownership. TCP is still in early planning phases, but there is strong
Expand Down
16 changes: 8 additions & 8 deletions e2e_testing/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,33 +16,33 @@
from torch_mlir_e2e_test.configs import (
LazyTensorCoreTestConfig,
LinalgOnTensorsBackendTestConfig,
MhloBackendTestConfig,
StablehloBackendTestConfig,
NativeTorchTestConfig,
TorchScriptTestConfig,
TosaBackendTestConfig,
TorchDynamoTestConfig,
)

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend

from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET
from .xfail_sets import LINALG_XFAIL_SET, STABLEHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET

# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()

def _get_argparse():
config_choices = ["native_torch", "torchscript", "linalg", "mhlo", "tosa", "lazy_tensor_core", "torchdynamo"]
config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"]
parser = argparse.ArgumentParser(description="Run torchscript e2e tests.")
parser.add_argument("-c", "--config",
choices=config_choices,
default="linalg",
help=f"""
Meaning of options:
"linalg": run through torch-mlir"s default Linalg-on-Tensors backend.
"mhlo": run through torch-mlir"s default MHLO backend.
"stablehlo": run through torch-mlir"s default StableHLO backend.
"tosa": run through torch-mlir"s default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
Expand Down Expand Up @@ -80,9 +80,9 @@ def main():
if args.config == "tosa":
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
xfail_set = all_test_unique_names - TOSA_PASS_SET
if args.config == "mhlo":
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
xfail_set = all_test_unique_names - MHLO_PASS_SET
if args.config == "stablehlo":
config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend())
xfail_set = all_test_unique_names - STABLEHLO_PASS_SET
elif args.config == "native_torch":
config = NativeTorchTestConfig()
xfail_set = {}
Expand Down
15 changes: 14 additions & 1 deletion e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,10 @@
"StdCorrectionKeepDimModule_basic",
}

MHLO_PASS_SET = {
STABLEHLO_PASS_SET = {
"AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic",
"AddSizeIntModule_basic",
"AddSizeIntNegDimModule_basic",
"ArangeDtypeFloatModule_basic",
"ArangeDtypeIntModule_basic",
"ArangeFalsePinMemoryModule_basic",
Expand All @@ -103,6 +105,7 @@
"ArangeStartStepFloatModule_basic",
"ArangeStartStepIntModule_basic",
"ArangeZeroElementOutputModule_basic",
"BatchMlpLayerModule_basic",
"BmmModule_basic",
"BroadcastToModule_basic",
"BroadcastToSameRankStaticModule_basic",
Expand All @@ -124,12 +127,15 @@
"ElementwiseClampMinModule_basic",
"ElementwiseClampMaxModule_basic",
"ElementwiseExpModule_basic",
"ElementwiseFlattenBroadcastModule_basic",
"ElementwiseLeakyReluModule_basic",
"ElementwiseLogModule_basic",
"ElementwiseNegModule_basic",
"ElementwiseRsqrtModule_basic",
"ElementwiseSigmoidModule_basic",
"ElementwiseSqrtModule_basic",
"ElementwiseUnaryModule_basic",
"ElementwiseUnsqueezeBroadcastModule_basic",
"ElementwiseUnsqueezeNegDimsModule_basic",
"ElementwiseToDtypeF32ToI64Module_basic",
"ElementwiseAddModule_basic",
Expand Down Expand Up @@ -198,6 +204,8 @@
"Gather2DInputModdule_basic",
"GatherRandomIndexModule_basic",
"GeluBackwardModule_basic",
"HardswishModule_basic",
"HardswishRandomModule_basic",
"HardTanhIntModule_basic",
"HardTanhModule_basic",
"HardsigmoidModule_basic",
Expand All @@ -220,6 +228,8 @@
"MeanDynamicSizesModule_basic",
"MeanLargeInputModule_basic",
"MeanModule_basic",
"Mlp1LayerModule_basic",
"Mlp2LayerModule_basic",
"MmTanhModule_basic",
"Mv_basic",
"NativeLayerNormModule4D_basic",
Expand Down Expand Up @@ -251,6 +261,8 @@
"LiftFreshCopyModule_basic",
"Mlp2LayerModuleNoBias_basic",
"NumelModule_basic",
"SiluModule_basic",
"SquareModule_basic",
"SqueezeModule_allUnitDim",
"SqueezeDimModule_unitDim",
"ViewCollapseOnesMiddleModule_basic",
Expand Down Expand Up @@ -420,6 +432,7 @@
"UnsafeViewDynamicExpandModule_basic",
"AtenRoundIntModule_basic",
"TestF16Return_basic",
"_LogSoftmaxModuleStable_basic",
}

# Write the TOSA set as a "passing" set as it is very early in development
Expand Down
14 changes: 0 additions & 14 deletions examples/torchscript_mhlo_backend_resnet.py

This file was deleted.

14 changes: 14 additions & 0 deletions examples/torchscript_stablehlo_backend_resnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import torch
import torchvision.models as models
import torch_mlir

model = models.resnet18(pretrained=True)
model.eval()
data = torch.randn(2,3,200,200)
out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir"

module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))

print(f"StableHLO IR of resent18 successfully written into {out_stablehlo_mlir_path}")
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ def forward(self, data):
model = BertTinyWrapper()
model.eval()
data = torch.randint(30522, (2, 128))
out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir"
out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir"

module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True)
with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf:
module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True)
with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf:
outf.write(str(module))

print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}")
print(f"StableHLO IR of tiny bert successfully written into {out_stablehlo_mlir_path}")
4 changes: 2 additions & 2 deletions include/torch-mlir/Conversion/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
set(LLVM_TARGET_DEFINITIONS Passes.td)
if(TORCH_MLIR_ENABLE_MHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO)
if(TORCH_MLIR_ENABLE_STABLEHLO)
mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO)
else()
mlir_tablegen(Passes.h.inc -gen-pass-decls)
endif()
Expand Down
10 changes: 5 additions & 5 deletions include/torch-mlir/Conversion/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -133,13 +133,13 @@ def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprog
let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()";
}

#ifdef TORCH_MLIR_ENABLE_MHLO
def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> {
let summary = "Convert Torch ops to MHLO ops";
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> {
let summary = "Convert Torch ops to Stablehlo ops";
let description = [{
Convert Torch ops to mhlo ops.
Convert Torch ops to Stablehlo ops.
}];
let constructor = "mlir::torch::createConvertTorchToMhloPass()";
let constructor = "mlir::torch::createConvertTorchToStablehloPass()";

// Specify any options.
let options = [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,20 @@
//
//===----------------------------------------------------------------------===//

#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/Pass.h"
#include <memory>

namespace mlir {
namespace torch {
std::unique_ptr<OperationPass<func::FuncOp>> createConvertTorchToMhloPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index);
createConvertTorchToStablehloPass();
std::unique_ptr<OperationPass<func::FuncOp>>
createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index);
} // namespace torch
} // namespace mlir

#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H
#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H
Loading

0 comments on commit 711646d

Please sign in to comment.