diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index 8a71057fac3d..07530445474c 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -113,7 +113,7 @@ jobs: -DLLVM_USE_HOST_TOOLS=ON \ -DLLVM_ENABLE_ZSTD=OFF \ -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DTORCH_MLIR_ENABLE_MHLO=OFF \ + -DTORCH_MLIR_ENABLE_STABLEHLO=OFF \ -DTORCH_MLIR_ENABLE_LTC=OFF \ -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \ -DMACOSX_DEPLOYMENT_TARGET=12.0 \ diff --git a/CMakeLists.txt b/CMakeLists.txt index c20627065d1f..a6d498013ef9 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,9 +36,9 @@ macro(torch_mlir_add_llvm_external_project name identifier location) set(LLVM_EXTERNAL_PROJECTS ${LLVM_EXTERNAL_PROJECTS} CACHE STRING "" FORCE) endmacro() -option(TORCH_MLIR_ENABLE_MHLO "Add mhlo dialect" ON) -if(TORCH_MLIR_ENABLE_MHLO) - add_definitions(-DTORCH_MLIR_ENABLE_MHLO) +option(TORCH_MLIR_ENABLE_STABLEHLO "Add stablehlo dialect" ON) +if(TORCH_MLIR_ENABLE_STABLEHLO) + add_definitions(-DTORCH_MLIR_ENABLE_STABLEHLO) endif() option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF) @@ -128,8 +128,8 @@ else() set(MLIR_INCLUDE_DIRS "${MLIR_INCLUDE_DIR};${MLIR_GENERATED_INCLUDE_DIR}") endif() -if (TORCH_MLIR_ENABLE_MHLO) - set(MHLO_BUILD_EMBEDDED ON) +if (TORCH_MLIR_ENABLE_STABLEHLO) + set(STABLEHLO_BUILD_EMBEDDED ON) add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/externals/mlir-hlo ${CMAKE_CURRENT_BINARY_DIR}/mlir-hlo EXCLUDE_FROM_ALL) diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 4a3bc375ec4e..32435cb8616d 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -267,8 +267,8 @@ function test_in_tree() { echo ":::: Run Linalg e2e integration tests" python -m e2e_testing.main --config=linalg -v - echo ":::: Run MHLO e2e integration tests" - python -m e2e_testing.main --config=mhlo -v + echo ":::: Run StableHLO e2e integration tests" + python -m e2e_testing.main --config=stablehlo -v echo ":::: Run TOSA e2e integration tests" python -m e2e_testing.main --config=tosa -v diff --git a/docs/architecture.md b/docs/architecture.md index 1619f81a883e..043bb74ec21e 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -30,14 +30,14 @@ it to various target dialects of interest to the MLIR ecosystem (various - Linalg-on-Tensors (+ `arith`, `tensor`, etc.) - [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/) -- [MHLO](https://github.com/tensorflow/mlir-hlo) +- [StableHLO](https://github.com/openxla/stablehlo) The terms "frontend" and "backend" are highly overloaded in any compiler project, but frequently in Torch-MLIR this is the meaning that they have. Sometimes "frontend" can mean something even further up the stack, such as something in PyTorch itself. When there is ambiguity we will refer to this as "at the PyTorch level". Similarly, "backend" can sometimes refer to something -sitting below Linalg-on-Tensors, TOSA, or MHLO. +sitting below Linalg-on-Tensors, TOSA, or StableHLO. ## The `torch` dialect @@ -118,8 +118,8 @@ See [satisfiesBackendContract](https://github.com/llvm/torch-mlir/blob/114f48e96 The backend contract is a normalized form of the `torch` dialect with a set of properties that make it easy to lower into various forms such as -Linalg-on-Tensors, TOSA, MHLO, or other forms that we don't provide out of the -box. The primary guarantees that we provide Torch-MLIR's backends are: +Linalg-on-Tensors, TOSA, StableHLO, or other forms that we don't provide out of +the box. The primary guarantees that we provide Torch-MLIR's backends are: - All tensors have been converted to value semantics. - All tensors have at least a known number of dimensions (i.e. rank), and @@ -270,7 +270,7 @@ lower it to the requirements of each backend. The 3 backends are: - [`linalg`](https://mlir.llvm.org/docs/Dialects/Linalg/) on tensors (+ `arith`, `tensor`, etc.) - [TOSA](https://mlir.llvm.org/docs/Dialects/TOSA/) -- [MHLO](https://github.com/tensorflow/mlir-hlo) +- [StableHLO](https://github.com/openxla/stablehlo) ### The Linalg Backend (Linalg-on-Tensors) @@ -297,15 +297,15 @@ many users (especially "hardware" or "hardware-adjacent" folks). Some of its cha - It is extremely solid with static shapes (and many of its users only care about static shapes, so that's fine). -### The MHLO Backend +### The StableHLO Backend -Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToMhlo +Code: https://github.com/llvm/torch-mlir/tree/main/lib/Conversion/TorchToStablehlo -The MHLO backend was the third backend that we added, and it offers a reasonable -blend of the benefits of the other two. +The StableHLO backend was the third backend that we added, and it offers a +reasonable blend of the benefits of the other two. - It is a coarse-grained named-op approach. - It has a pretty clear spec for most of the ops (with a bit of mental - translation and hoping that MHLO is the same as HLO): + translation and hoping that StableHLO is the same as HLO): https://www.tensorflow.org/xla/operation_semantics - It functionally supports dynamic shapes (though not as coherent and consistent as Linalg-on-Tensors, and the dynamic shape support falls outside the @@ -317,7 +317,7 @@ blend of the benefits of the other two. example, TOSA limits (for highly considered reasons) the number of dimensions that certain operators can handle to 1D-4D, when from a purely algebraic perspective there isn't a good reason to not be more general. Similarly, more - general forms of reduction and scatter also fall into MHLO nicely while + general forms of reduction and scatter also fall into StableHLO nicely while TOSA's principles tend to bias it away from that. ### Backend Implementation @@ -433,8 +433,9 @@ filling in some corners missing upstream and to pull together upstream functionality into a working system. The RefBackend accepts Linalg-on-Tensors as input. It mainly just bufferizes the -ops and lowers them to loops. Note that TOSA and MHLO support lowering to -Linalg-on-Tensors, so all our end-to-end testing bottoms out on RefBackend. +ops and lowers them to loops. Note that TOSA and StableHLO (via MHLO) support +lowering to Linalg-on-Tensors, so all our end-to-end testing bottoms out on +RefBackend. The RefBackend is absolutely not suitable for any production use case. It leaks memory, doesn't support any error handling, performs no optimizations, and diff --git a/docs/code_owners.md b/docs/code_owners.md index 299f19656978..d70b236b9911 100644 --- a/docs/code_owners.md +++ b/docs/code_owners.md @@ -34,7 +34,7 @@ and Clang's - Eric Kunze (@eric-k256) - Suraj Sudhir (@sjarus) -### TorchToMHLO +### TorchToStablehlo - Tianyo Kwok (@tanyokwok) - Ziheng Jiang (@ZihengJiang) diff --git a/docs/development.md b/docs/development.md index f6e976769eca..048f363c0763 100644 --- a/docs/development.md +++ b/docs/development.md @@ -139,7 +139,7 @@ Ex: module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") ``` -Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `MHLO`. +Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`. ## Jupyter diff --git a/docs/long_term_roadmap.md b/docs/long_term_roadmap.md index 1e8981da140e..62c3b6f94171 100644 --- a/docs/long_term_roadmap.md +++ b/docs/long_term_roadmap.md @@ -46,7 +46,7 @@ the ecosystem are: - The frontend work required to lower TorchScript to the backend contract. - The irregular support surface area of the large number of PyTorch ops across - the Linalg, TOSA, and MHLO backends. + the Linalg, TOSA, and StableHLO backends. Most of this document describes long-term ecosystem changes that will address these, drastically improving Torch-MLIR's ability to meet its goals. @@ -108,7 +108,7 @@ more advanced). ### Refactoring the backend Today in Torch-MLIR, we support 3 backends out of the box: Linalg-on-Tensors, -TOSA, and MHLO. These backends take IR in the backend contract form (see +TOSA, and StableHLO. These backends take IR in the backend contract form (see [architecture.md](architecture.md)) and lowers them to the respective dialects. Today, each backend is implemented completely independently. This leads to duplication and irregularity across the backends. @@ -120,12 +120,10 @@ lowering of so many ops across backends. Additionally, there are 3 forward-looking efforts that intersect with this effort: - [StableHLO](https://github.com/openxla/stablehlo) - this is a dialect - initially forked from MHLO which intends to create a stable support surface - area for what today is our "at head" dependency on MHLO. MHLO is a fairly - complete op set, so it is very attractive to have "almost all" models - bottleneck through a stable interface like StableHLO. StableHLO is currently - under relatively early development, but already delivers on many of the goals - of stability. + initially forked from MHLO. MHLO is a fairly complete op set, so it is very + attractive to have "almost all" models bottleneck through a stable interface + like StableHLO. StableHLO is currently under relatively early development, + but already delivers on many of the goals of stability. - [TCP](https://github.com/llvm/torch-mlir/issues/1366) - this is a dialect which could serve a role very similar to MHLO, while providing community ownership. TCP is still in early planning phases, but there is strong diff --git a/e2e_testing/main.py b/e2e_testing/main.py index d48223ad4aa9..770d32ca54b9 100644 --- a/e2e_testing/main.py +++ b/e2e_testing/main.py @@ -16,7 +16,7 @@ from torch_mlir_e2e_test.configs import ( LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, - MhloBackendTestConfig, + StablehloBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, @@ -24,17 +24,17 @@ ) from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend -from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend +from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend -from .xfail_sets import LINALG_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET +from .xfail_sets import LINALG_XFAIL_SET, STABLEHLO_PASS_SET, TOSA_PASS_SET, LTC_XFAIL_SET, TORCHDYNAMO_XFAIL_SET # Import tests to register them in the global registry. from torch_mlir_e2e_test.test_suite import register_all_tests register_all_tests() def _get_argparse(): - config_choices = ["native_torch", "torchscript", "linalg", "mhlo", "tosa", "lazy_tensor_core", "torchdynamo"] + config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "tosa", "lazy_tensor_core", "torchdynamo"] parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser.add_argument("-c", "--config", choices=config_choices, @@ -42,7 +42,7 @@ def _get_argparse(): help=f""" Meaning of options: "linalg": run through torch-mlir"s default Linalg-on-Tensors backend. -"mhlo": run through torch-mlir"s default MHLO backend. +"stablehlo": run through torch-mlir"s default StableHLO backend. "tosa": run through torch-mlir"s default TOSA backend. "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). @@ -80,9 +80,9 @@ def main(): if args.config == "tosa": config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) xfail_set = all_test_unique_names - TOSA_PASS_SET - if args.config == "mhlo": - config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend()) - xfail_set = all_test_unique_names - MHLO_PASS_SET + if args.config == "stablehlo": + config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend()) + xfail_set = all_test_unique_names - STABLEHLO_PASS_SET elif args.config == "native_torch": config = NativeTorchTestConfig() xfail_set = {} diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py index 52fe3467ba5f..29d8437c8acf 100644 --- a/e2e_testing/xfail_sets.py +++ b/e2e_testing/xfail_sets.py @@ -87,8 +87,10 @@ "StdCorrectionKeepDimModule_basic", } -MHLO_PASS_SET = { +STABLEHLO_PASS_SET = { "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AddSizeIntModule_basic", + "AddSizeIntNegDimModule_basic", "ArangeDtypeFloatModule_basic", "ArangeDtypeIntModule_basic", "ArangeFalsePinMemoryModule_basic", @@ -103,6 +105,7 @@ "ArangeStartStepFloatModule_basic", "ArangeStartStepIntModule_basic", "ArangeZeroElementOutputModule_basic", + "BatchMlpLayerModule_basic", "BmmModule_basic", "BroadcastToModule_basic", "BroadcastToSameRankStaticModule_basic", @@ -124,12 +127,15 @@ "ElementwiseClampMinModule_basic", "ElementwiseClampMaxModule_basic", "ElementwiseExpModule_basic", + "ElementwiseFlattenBroadcastModule_basic", + "ElementwiseLeakyReluModule_basic", "ElementwiseLogModule_basic", "ElementwiseNegModule_basic", "ElementwiseRsqrtModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSqrtModule_basic", "ElementwiseUnaryModule_basic", + "ElementwiseUnsqueezeBroadcastModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", "ElementwiseAddModule_basic", @@ -198,6 +204,8 @@ "Gather2DInputModdule_basic", "GatherRandomIndexModule_basic", "GeluBackwardModule_basic", + "HardswishModule_basic", + "HardswishRandomModule_basic", "HardTanhIntModule_basic", "HardTanhModule_basic", "HardsigmoidModule_basic", @@ -220,6 +228,8 @@ "MeanDynamicSizesModule_basic", "MeanLargeInputModule_basic", "MeanModule_basic", + "Mlp1LayerModule_basic", + "Mlp2LayerModule_basic", "MmTanhModule_basic", "Mv_basic", "NativeLayerNormModule4D_basic", @@ -251,6 +261,8 @@ "LiftFreshCopyModule_basic", "Mlp2LayerModuleNoBias_basic", "NumelModule_basic", + "SiluModule_basic", + "SquareModule_basic", "SqueezeModule_allUnitDim", "SqueezeDimModule_unitDim", "ViewCollapseOnesMiddleModule_basic", @@ -419,6 +431,7 @@ "UnsafeViewDynamicExpandModule_basic", "AtenRoundIntModule_basic", "TestF16Return_basic", + "_LogSoftmaxModuleStable_basic", } # Write the TOSA set as a "passing" set as it is very early in development diff --git a/examples/torchscript_mhlo_backend_resnet.py b/examples/torchscript_mhlo_backend_resnet.py deleted file mode 100644 index bb481f6c3366..000000000000 --- a/examples/torchscript_mhlo_backend_resnet.py +++ /dev/null @@ -1,14 +0,0 @@ -import torch -import torchvision.models as models -import torch_mlir - -model = models.resnet18(pretrained=True) -model.eval() -data = torch.randn(2,3,200,200) -out_mhlo_mlir_path = "./resnet18_mhlo.mlir" - -module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=False) -with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf: - outf.write(str(module)) - -print(f"MHLO IR of resent18 successfully written into {out_mhlo_mlir_path}") diff --git a/examples/torchscript_stablehlo_backend_resnet.py b/examples/torchscript_stablehlo_backend_resnet.py new file mode 100644 index 000000000000..7a97359cff62 --- /dev/null +++ b/examples/torchscript_stablehlo_backend_resnet.py @@ -0,0 +1,14 @@ +import torch +import torchvision.models as models +import torch_mlir + +model = models.resnet18(pretrained=True) +model.eval() +data = torch.randn(2,3,200,200) +out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir" + +module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False) +with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: + outf.write(str(module)) + +print(f"StableHLO IR of resent18 successfully written into {out_stablehlo_mlir_path}") diff --git a/examples/torchscript_mhlo_backend_tinybert.py b/examples/torchscript_stablehlo_backend_tinybert.py similarity index 69% rename from examples/torchscript_mhlo_backend_tinybert.py rename to examples/torchscript_stablehlo_backend_tinybert.py index 62827361e84f..c035be3a54fe 100644 --- a/examples/torchscript_mhlo_backend_tinybert.py +++ b/examples/torchscript_stablehlo_backend_tinybert.py @@ -15,10 +15,10 @@ def forward(self, data): model = BertTinyWrapper() model.eval() data = torch.randint(30522, (2, 128)) -out_mhlo_mlir_path = "./bert_tiny_mhlo.mlir" +out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir" -module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.MHLO, use_tracing=True) -with open(out_mhlo_mlir_path, "w", encoding="utf-8") as outf: +module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True) +with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: outf.write(str(module)) -print(f"MHLO IR of tiny bert successfully written into {out_mhlo_mlir_path}") +print(f"StableHLO IR of tiny bert successfully written into {out_stablehlo_mlir_path}") diff --git a/include/torch-mlir/Conversion/CMakeLists.txt b/include/torch-mlir/Conversion/CMakeLists.txt index 9ee80b304b66..d6552314999b 100644 --- a/include/torch-mlir/Conversion/CMakeLists.txt +++ b/include/torch-mlir/Conversion/CMakeLists.txt @@ -1,6 +1,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -if(TORCH_MLIR_ENABLE_MHLO) - mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO) +if(TORCH_MLIR_ENABLE_STABLEHLO) + mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) else() mlir_tablegen(Passes.h.inc -gen-pass-decls) endif() diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 7072b8d5f416..b5f30bfbe724 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -133,13 +133,13 @@ def ConvertTorchConversionToMLProgram : Pass<"convert-torch-conversion-to-mlprog let constructor = "mlir::torch::createConvertTorchConversionToMLProgramPass()"; } -#ifdef TORCH_MLIR_ENABLE_MHLO -def ConvertTorchToMhlo : Pass<"convert-torch-to-mhlo", "func::FuncOp"> { - let summary = "Convert Torch ops to MHLO ops"; +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +def ConvertTorchToStablehlo : Pass<"convert-torch-to-stablehlo", "func::FuncOp"> { + let summary = "Convert Torch ops to Stablehlo ops"; let description = [{ - Convert Torch ops to mhlo ops. + Convert Torch ops to Stablehlo ops. }]; - let constructor = "mlir::torch::createConvertTorchToMhloPass()"; + let constructor = "mlir::torch::createConvertTorchToStablehloPass()"; // Specify any options. let options = [ diff --git a/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h b/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h similarity index 64% rename from include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h rename to include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h index 8e2f5fc8630d..c1926015989e 100644 --- a/include/torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h +++ b/include/torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H -#define TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H +#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H +#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Pass/Pass.h" @@ -16,10 +16,11 @@ namespace mlir { namespace torch { -std::unique_ptr> createConvertTorchToMhloPass(); std::unique_ptr> -createConvertTorchToMhloPass(bool enableStaticShape, bool enableI32Index); +createConvertTorchToStablehloPass(); +std::unique_ptr> +createConvertTorchToStablehloPass(bool enableStaticShape, bool enableI32Index); } // namespace torch } // namespace mlir -#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_TORCHTOMHLO_H +#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_TORCHTOSTABLEHLO_H diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt b/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt index 00818899f775..77e46eb4be04 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -1,6 +1,6 @@ set(LLVM_TARGET_DEFINITIONS Passes.td) -if(TORCH_MLIR_ENABLE_MHLO) - mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_MHLO) +if(TORCH_MLIR_ENABLE_STABLEHLO) + mlir_tablegen(Passes.h.inc -gen-pass-decls -DTORCH_MLIR_ENABLE_STABLEHLO) else() mlir_tablegen(Passes.h.inc -gen-pass-decls) endif() diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index fd350da1d61e..e6493a154edd 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -30,10 +30,10 @@ void createTorchBackendToLinalgOnTensorsBackendPipeline(OpPassManager &pm); /// TOSA backend contract. void createTorchBackendToTosaBackendPipeline(OpPassManager &pm); -// Do not register the torch-to-mhlo pipeline if mhlo target is disabled -#ifdef TORCH_MLIR_ENABLE_MHLO -struct MhloBackendPipelineOptions - : public PassPipelineOptions { +// Do not register the stablehlo options if the stablehlo target is disabled +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +struct StablehloBackendPipelineOptions + : public PassPipelineOptions { Option enableStaticShape{ *this, "enable-static-shape", llvm::cl::desc("Enable static shape conversion."), llvm::cl::init(false)}; @@ -46,9 +46,10 @@ struct MhloBackendPipelineOptions llvm::cl::init(false)}; }; -void createTorchBackendToMhloBackendPipeline( - OpPassManager &pm, const MhloBackendPipelineOptions &options); -std::unique_ptr> createVerifyMhloBackendContractPass(); +void createTorchBackendToStablehloBackendPipeline( + OpPassManager &pm, const StablehloBackendPipelineOptions &options); +std::unique_ptr> +createVerifyStablehloBackendContractPass(); #endif std::unique_ptr> createFuncBackendTypeConversionPass(); diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index 4ce7cdadbb39..cb58dbbd998b 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -42,10 +42,10 @@ def VerifyTosaBackendContract : Pass<"torch-verify-tosa-backend-contract", "Modu let constructor = "mlir::torch::TorchConversion::createVerifyTosaBackendContractPass()"; } -#ifdef TORCH_MLIR_ENABLE_MHLO -def VerifyMhloBackendContract : Pass<"torch-verify-mhlo-backend-contract", "ModuleOp"> { - let summary = "Verifies conformity to the mhlo backend contract"; - let constructor = "mlir::torch::TorchConversion::createVerifyMhloBackendContractPass()"; +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contract", "ModuleOp"> { + let summary = "Verifies conformity to the stablehlo backend contract"; + let constructor = "mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass()"; } -#endif // TORCH_MLIR_ENABLE_MHLO +#endif // TORCH_MLIR_ENABLE_STABLEHLO #endif // TORCHMLIR_TORCHCONVERSION_PASSES diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index ec6ee8cee77a..4c37cca5efb4 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -3,13 +3,7 @@ add_subdirectory(Conversion) add_subdirectory(Dialect) add_subdirectory(RefBackend) -add_mlir_library(TorchMLIRInitAll - InitAll.cpp - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC +set(LinkedLibs MLIRFuncDialect MLIRIR MLIRSupport @@ -27,4 +21,22 @@ add_mlir_library(TorchMLIRInitAll TorchMLIRRefBackend ) +if(TORCH_MLIR_ENABLE_STABLEHLO) + list(APPEND LinkedLibs + MhloPasses + MhloToLinalg + StablehloToMhlo + ) +endif() + +add_mlir_library(TorchMLIRInitAll + InitAll.cpp + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + ${LinkedLibs} +) + torch_mlir_target_includes(TorchMLIRInitAll) diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index 29812d1feed4..d72563b1e697 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -2,8 +2,8 @@ add_subdirectory(TorchToLinalg) add_subdirectory(TorchToSCF) add_subdirectory(TorchToArith) add_subdirectory(TorchToTosa) -if(TORCH_MLIR_ENABLE_MHLO) - add_subdirectory(TorchToMhlo) +if(TORCH_MLIR_ENABLE_STABLEHLO) + add_subdirectory(TorchToStablehlo) endif() add_subdirectory(TorchToTMTensor) add_subdirectory(TorchConversionToMLProgram) @@ -17,10 +17,8 @@ set(linked_libs TorchMLIRTorchToLinalg TorchMLIRTorchToTMTensor TorchMLIRTorchConversionToMLProgram TorchMLIRConversionUtils) -if(TORCH_MLIR_ENABLE_MHLO) - list(APPEND linked_libs - MhloPasses - TorchMLIRTorchToMhlo) +if(TORCH_MLIR_ENABLE_STABLEHLO) + list(APPEND linked_libs TorchMLIRTorchToStablehlo) endif() add_mlir_library(TorchMLIRConversionPasses diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index f07a3afb3002..45714601ded0 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -9,15 +9,15 @@ #include "torch-mlir/Conversion/Passes.h" -#ifdef TORCH_MLIR_ENABLE_MHLO -#include "mhlo/transforms/passes.h" +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "transforms/passes.h" -#endif // TORCH_MLIR_ENABLE_MHLO +#endif // TORCH_MLIR_ENABLE_STABLEHLO + #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" @@ -32,12 +32,4 @@ namespace { void mlir::torch::registerConversionPasses() { ::registerPasses(); -#ifdef TORCH_MLIR_ENABLE_MHLO - ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { - return mlir::mhlo::createLegalizeHloToLinalgPass(); - }); - ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { - return mlir::mhlo::createSymbolicShapeOptimizationPass(); - }); -#endif // TORCH_MLIR_ENABLE_MHLO } diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt deleted file mode 100644 index 4c0929268ac9..000000000000 --- a/lib/Conversion/TorchToMhlo/CMakeLists.txt +++ /dev/null @@ -1,35 +0,0 @@ -add_mlir_conversion_library(TorchMLIRTorchToMhlo - TorchToMhlo.cpp - MhloLegalizeUtils.cpp - Basic.cpp - Gather.cpp - Linear.cpp - ViewLike.cpp - Reduction.cpp - Pooling.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToMhlo - - DEPENDS - MhloDialect - MhloToLinalg - MLIRMhloPassIncGen - LMHLOTransformsPassIncGen - TorchMLIRConversionPassIncGen - - LINK_COMPONENTS - Core - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - MhloDialect - MhloToLinalg - MLIRBufferTransforms - StablehloOps - TorchMLIRTorchDialect - TorchMLIRConversionUtils -) - -torch_mlir_target_includes(TorchMLIRTorchToMhlo) diff --git a/lib/Conversion/TorchToMhlo/PopulatePatterns.h b/lib/Conversion/TorchToMhlo/PopulatePatterns.h deleted file mode 100644 index 2e195a87fb77..000000000000 --- a/lib/Conversion/TorchToMhlo/PopulatePatterns.h +++ /dev/null @@ -1,74 +0,0 @@ -//===------------------------------------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// Also available under a BSD-style license. See LICENSE. -// -//===----------------------------------------------------------------------===// - -#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H -#define TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H - -#include "mlir/Transforms/DialectConversion.h" - -namespace mlir { -namespace torch { -namespace torch_to_mhlo { - -struct TorchToMhloOptions { - bool enableStaticShape = false; - size_t dimSizeIndexBits = 64; -}; - -template -class ConvertAtenOp : public OpConversionPattern { -public: - using OpAdaptor = typename AtenOpT::Adaptor; - ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context, - const TorchToMhloOptions &options) - : OpConversionPattern(typeConverter, context) { - this->options = options; - } - LogicalResult - matchAndRewrite(AtenOpT op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - return rewriter.notifyMatchFailure(op, "haven't been implemented"); - } - const TorchToMhloOptions &getOptions() const { return options; } - -private: - TorchToMhloOptions options; -}; - -void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target, - const TorchToMhloOptions &options); -void populateViewLikeOpPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target, - const TorchToMhloOptions &options); -void populateGatherOpPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target, - const TorchToMhloOptions &options); -void populateReductionOpPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target, - const TorchToMhloOptions &options); -void populateLinearOpPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target, - const TorchToMhloOptions &options); - -void populatePoolingOpPatternsAndLegality(TypeConverter &typeConverter, - RewritePatternSet &patterns, - ConversionTarget &target, - const TorchToMhloOptions &options); - -} // namespace torch_to_mhlo -} // namespace torch -} // namespace mlir - -#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOMHLO_POPULATEPATTERNS_H diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp similarity index 82% rename from lib/Conversion/TorchToMhlo/Basic.cpp rename to lib/Conversion/TorchToStablehlo/Basic.cpp index 35776bb88f9d..d82826d714cf 100644 --- a/lib/Conversion/TorchToMhlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -7,15 +7,16 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "../PassDetail.h" -#include "./MhloLegalizeUtils.h" -#include "./PopulatePatterns.h" -#include "mhlo/IR/hlo_ops.h" +#include "PopulatePatterns.h" +#include "StablehloLegalizeUtils.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -29,7 +30,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -using namespace mlir::torch::torch_to_mhlo; +using namespace mlir::torch::torch_to_stablehlo; LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, mlir::Value &self, mlir::Value &other, @@ -43,16 +44,16 @@ LogicalResult broadcastRanks(PatternRewriter &rewriter, Operation *op, if (selfRank > otherRank) { auto unsqueezeDims = llvm::to_vector<4>(llvm::seq(0, selfRank - otherRank)); - auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, other, - unsqueezeDims, dimSizeIndexBits); + auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, other, + unsqueezeDims, dimSizeIndexBits); if (failed(unsqueezeInfo)) return failure(); other = *unsqueezeInfo; } else if (otherRank > selfRank) { auto unsqueezeDims = llvm::to_vector<4>(llvm::seq(0, otherRank - selfRank)); - auto unsqueezeInfo = mhlo::unsqueezeTensor(rewriter, op, self, - unsqueezeDims, dimSizeIndexBits); + auto unsqueezeInfo = hlo::unsqueezeTensor(rewriter, op, self, unsqueezeDims, + dimSizeIndexBits); if (failed(unsqueezeInfo)) return failure(); self = *unsqueezeInfo; @@ -78,7 +79,8 @@ static FailureOr getMaxValueOfDtype(Operation *op, Type elementType, constType, APFloat::getInf(elementType.cast().getFloatSemantics(), /*negative=*/false)); - return rewriter.create(op->getLoc(), constType, constAttr) + return rewriter + .create(op->getLoc(), constType, constAttr) .getResult(); } if (elementType.isa()) { @@ -91,7 +93,8 @@ static FailureOr getMaxValueOfDtype(Operation *op, Type elementType, constAttr = SplatElementsAttr::get( constType, APInt::getSignedMaxValue(integerType.getWidth())); } - return rewriter.create(op->getLoc(), constType, constAttr) + return rewriter + .create(op->getLoc(), constType, constAttr) .getResult(); } return failure(); @@ -105,7 +108,8 @@ static FailureOr getMinValueOfDtype(Operation *op, Type elementType, constType, APFloat::getInf(elementType.cast().getFloatSemantics(), /*negative=*/true)); - return rewriter.create(op->getLoc(), constType, constAttr) + return rewriter + .create(op->getLoc(), constType, constAttr) .getResult(); } if (elementType.isa()) { @@ -118,7 +122,8 @@ static FailureOr getMinValueOfDtype(Operation *op, Type elementType, constAttr = SplatElementsAttr::get( constType, APInt::getSignedMinValue(integerType.getWidth())); } - return rewriter.create(op->getLoc(), constType, constAttr) + return rewriter + .create(op->getLoc(), constType, constAttr) .getResult(); } return failure(); @@ -126,7 +131,7 @@ static FailureOr getMinValueOfDtype(Operation *op, Type elementType, // These legalizations are for unary ops. namespace { -template +template class ConvertAtenUnaryOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -137,13 +142,13 @@ class ConvertAtenUnaryOp : public OpConversionPattern { Value self = adaptor.getSelf(); auto selfType = self.getType().cast(); if (!selfType) { - return op.emitError("only Tensor types supported in MHLO"); + return op.emitError("only Tensor types supported in StableHLO"); } auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); - self = mhlo::promoteType(rewriter, self, outType); - rewriter.replaceOpWithNewOp(op, outType, self); + self = hlo::promoteType(rewriter, self, outType); + rewriter.replaceOpWithNewOp(op, outType, self); return success(); } }; @@ -152,7 +157,7 @@ class ConvertAtenUnaryOp : public OpConversionPattern { // These legalizations are for unary ops with only for floating point datatypes. // There is no supported quantized integer mode for these. namespace { -template +template class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -164,10 +169,10 @@ class ConvertAtenUnaryFPOnlyOp : public OpConversionPattern { auto selfTy = self.getType().cast(); if (!selfTy) - return op.emitError("only Tensor types supported in MHLO"); + return op.emitError("only Tensor types supported in StableHLO"); if (selfTy.getElementType().isa()) { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), @@ -198,7 +203,7 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { .template dyn_cast(); if (!outType) - return op.emitError("only Tensor types supported in MHLO"); + return op.emitError("only Tensor types supported in StableHLO"); Type outElemTy = outType.getElementType(); if (!outElemTy.isIntOrFloat()) @@ -216,9 +221,9 @@ class ConvertAtenConstPatternOp : public OpConversionPattern { SmallVector values(size, fillVal); auto constOp = - mhlo::getConstTensor(rewriter, op, values, shape).value(); + hlo::getConstTensor(rewriter, op, values, shape).value(); - rewriter.replaceOpWithNewOp(op, outType, constOp); + rewriter.replaceOpWithNewOp(op, outType, constOp); return success(); } }; @@ -247,8 +252,8 @@ class ConvertAtenBinaryBroadcastOp : public OpConversionPattern { ->convertType(op.getType()) .template cast(); - lhs = mhlo::promoteType(rewriter, lhs, outTy); - rhs = mhlo::promoteType(rewriter, rhs, outTy); + lhs = hlo::promoteType(rewriter, lhs, outTy); + rhs = hlo::promoteType(rewriter, rhs, outTy); rewriter.replaceOpWithNewOp(op, outTy, lhs, rhs, /*broadcast_attr*/ nullptr); @@ -274,7 +279,7 @@ class ConvertAtenAddSubOp : public OpConversionPattern { RankedTensorType rhsType = rhs.getType().dyn_cast(); if (!lhsType) - return op.emitError("only Tensor types supported in MHLO"); + return op.emitError("only Tensor types supported in StableHLO"); TensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) @@ -287,18 +292,19 @@ class ConvertAtenAddSubOp : public OpConversionPattern { } if (!rhsType) { - rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy); + rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), + outElemTy); if (isa(op)) { std::swap(lhs, rhs); } } - lhs = mhlo::promoteType(rewriter, lhs, outType); - rhs = mhlo::promoteType(rewriter, rhs, outType); + lhs = hlo::promoteType(rewriter, lhs, outType); + rhs = hlo::promoteType(rewriter, rhs, outType); if (!skipMultiplyAlpha(op.getAlpha())) { - Value alpha = - mhlo::scalarToMhloTensor(rewriter, op, adaptor.getAlpha(), outElemTy); + Value alpha = hlo::scalarToStablehloTensor(rewriter, op, + adaptor.getAlpha(), outElemTy); DenseIntElementsAttr bcastDimensions; rhs = rewriter.create(op->getLoc(), rhs, alpha, bcastDimensions); @@ -328,7 +334,7 @@ class ConvertAtenMulDivOp : public OpConversionPattern { TensorType rhsType = rhs.getType().dyn_cast(); if (!lhsType) - return op.emitError("only Tensor types supported in MHLO"); + return op.emitError("only Tensor types supported in StableHLO"); auto outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) @@ -343,11 +349,12 @@ class ConvertAtenMulDivOp : public OpConversionPattern { if (std::is_same()) { rhs = lhs; } else if (!rhsType) { - rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), outElemTy); + rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), + outElemTy); } DenseIntElementsAttr bcastDimensions; - lhs = mhlo::promoteType(rewriter, lhs, outType); - rhs = mhlo::promoteType(rewriter, rhs, outType); + lhs = hlo::promoteType(rewriter, lhs, outType); + rhs = hlo::promoteType(rewriter, rhs, outType); auto loc = op.getLoc(); Value result = rewriter.create(loc, outType, lhs, rhs, bcastDimensions); @@ -368,15 +375,15 @@ class ConvertAtenMulDivOp : public OpConversionPattern { if (roundingMode == "trunc") { // "trunc" - rounds the results of the division towards zero. Equivalent // to C-style integer division. - auto sign = rewriter.create(loc, result); - auto abs = rewriter.create(loc, result); - auto floor = rewriter.create(loc, abs); - result = rewriter.create(loc, sign, floor).getResult(); + auto sign = rewriter.create(loc, result); + auto abs = rewriter.create(loc, result); + auto floor = rewriter.create(loc, abs); + result = rewriter.create(loc, sign, floor).getResult(); } if (roundingMode == "floor") { // "floor" - rounds the results of the division down. Equivalent to // floor division in Python (the // operator) - result = rewriter.create(loc, result).getResult(); + result = rewriter.create(loc, result).getResult(); } rewriter.replaceOp(op, result); return success(); @@ -401,7 +408,7 @@ class ConvertAtenCompareOp : public OpConversionPattern { RankedTensorType rhsTy = rhs.getType().dyn_cast(); if (!lhsTy) - return op.emitError("only Tensor types supported in MHLO"); + return op.emitError("only Tensor types supported in StableHLO"); RankedTensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) @@ -414,11 +421,12 @@ class ConvertAtenCompareOp : public OpConversionPattern { } if (!rhsTy) { - rhs = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getOther(), lhsElemTy); + rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), + lhsElemTy); } // TODO: what is the PyTorch default type promotion? - rhs = mhlo::promoteType(rewriter, rhs, lhsTy); + rhs = hlo::promoteType(rewriter, rhs, lhsTy); chlo::ComparisonTypeAttr compareTypeAttr; chlo::ComparisonDirectionAttr compareDirectionAttr; @@ -485,8 +493,8 @@ class ConvertAtenLogicalBinaryOp : public OpConversionPattern { TensorType outType = OpConversionPattern::getTypeConverter() ->convertType(op.getType()) .template cast(); - Value lhs = mhlo::promoteType(rewriter, adaptor.getSelf(), outType); - Value rhs = mhlo::promoteType(rewriter, adaptor.getOther(), outType); + Value lhs = hlo::promoteType(rewriter, adaptor.getSelf(), outType); + Value rhs = hlo::promoteType(rewriter, adaptor.getOther(), outType); DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, @@ -537,8 +545,8 @@ class ConvertAtenTransposeIntOp RankedTensorType::get({static_cast(permValues.size())}, rewriter.getI64Type()), permValues); - rewriter.replaceOpWithNewOp(op, outType, self, - permutation); + rewriter.replaceOpWithNewOp(op, outType, self, + permutation); return success(); } }; @@ -552,7 +560,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value self = adaptor.getSelf(); auto outType = getTypeConverter()->convertType(op.getType()).cast(); - rewriter.replaceOpWithNewOp(op, outType, self); + rewriter.replaceOpWithNewOp(op, outType, self); return success(); } @@ -573,7 +581,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } else { Value inputRank = rewriter.create( op.getLoc(), rewriter.getI64IntegerAttr(selfType.getRank())); - dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), inputRank); + dim = toPositiveDimDynamic(rewriter, op.getLoc(), adaptor.getDim(), + inputRank); dim = rewriter.create(op.getLoc(), rewriter.getIndexType(), dim); } @@ -589,9 +598,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( template <> LogicalResult ConvertAtenOp::matchAndRewrite( - AtenWhereSelfOp op, - OpAdaptor adaptor, - ConversionPatternRewriter& rewriter) const { + AtenWhereSelfOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const { Value self = adaptor.getSelf(); Value cond = adaptor.getCondition(); Value other = adaptor.getOther(); @@ -605,8 +613,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return op.emitError("failed broadcast other and condition ranks"); rewriter.replaceOpWithNewOp( - op, - getTypeConverter()->convertType(op.getType()), + op, getTypeConverter()->convertType(op.getType()), ArrayRef{cond, self, other}); return success(); } @@ -623,7 +630,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( .cast(); if (options.enableStaticShape && selfTy.hasStaticShape()) { - Value bcastOp = mhlo::promoteAndBroadcast(rewriter, self, outType); + Value bcastOp = hlo::promoteAndBroadcast(rewriter, self, outType); rewriter.replaceOp(op, bcastOp); return success(); } @@ -670,7 +677,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op->getLoc(), ValueRange{bcastShapeVec}); auto dimensionNumbers = llvm::to_vector<4>(llvm::seq(leadingRank, totalRank)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, outType, self, bcastShapeTensor, rewriter.getI64TensorAttr(dimensionNumbers)); } @@ -708,8 +715,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( RankedTensorType::get({static_cast(permValues.size())}, rewriter.getI64Type()), permValues); - rewriter.replaceOpWithNewOp(op, outType, self, - permutation); + rewriter.replaceOpWithNewOp(op, outType, self, + permutation); return success(); } @@ -721,7 +728,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value self = adaptor.getSelf(); auto selfTy = self.getType().cast(); if (selfTy && selfTy.getElementType().isa()) { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); } else { @@ -751,16 +758,16 @@ LogicalResult ConvertAtenOp::matchAndRewrite( elements.mapValues(builtinTensorElemTy, [&](const APInt &v) { return APInt(bitWidth, v.getSExtValue()); }); - rewriter.replaceOpWithNewOp(op, resultType, valueAttr); + rewriter.replaceOpWithNewOp(op, resultType, + valueAttr); return success(); } - rewriter.replaceOpWithNewOp(op, resultType, - adaptor.getValue()); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getValue()); return success(); } - // AtenReciprocalOp // Reciprocal(x) = Div(1, x) template <> @@ -777,7 +784,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } Value oneTensor = chlo::getConstantLike(rewriter, op->getLoc(), 1, input); - rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input); + rewriter.replaceOpWithNewOp(op, outTy, oneTensor, input); return success(); } @@ -790,9 +797,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ->convertType(op->getResult(0).getType()) .cast(); auto outputElemType = outputType.getElementType(); - Value mhloTensor = - mhlo::scalarToMhloTensor(rewriter, op, adaptor.getA(), outputElemType); - rewriter.replaceOp(op, mhloTensor); + Value stablehloTensor = hlo::scalarToStablehloTensor( + rewriter, op, adaptor.getA(), outputElemType); + rewriter.replaceOp(op, stablehloTensor); return success(); } @@ -815,7 +822,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } - // AtenReluOp // Relu(x) = Max(0, x) template <> @@ -836,11 +842,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( APFloat::getZero(lhsElemTy.cast().getFloatSemantics(), false), lhs); - rewriter.replaceOpWithNewOp(op, lhs, zeroTensor); + rewriter.replaceOpWithNewOp(op, lhs, zeroTensor); return success(); } - // Convert a Aten::GELU to HLO // Gelu(x) = x * 1/2 * [1 + erf(x/(sqrt(2)))] template <> @@ -857,12 +862,12 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value one = chlo::getConstantLike(rewriter, loc, 1.0, input); Value two = chlo::getConstantLike(rewriter, loc, 2.0, input); Value half = chlo::getConstantLike(rewriter, loc, 0.5, input); - auto rsqrtTwo = rewriter.create(loc, two); - auto erfElement = rewriter.create(loc, input, rsqrtTwo); + auto rsqrtTwo = rewriter.create(loc, two); + auto erfElement = rewriter.create(loc, input, rsqrtTwo); auto erf = rewriter.create(loc, erfElement); - auto erfAdd = rewriter.create(loc, erf, one); - auto halfMul = rewriter.create(loc, erfAdd, half); - rewriter.replaceOpWithNewOp(op, input, halfMul); + auto erfAdd = rewriter.create(loc, erf, one); + auto halfMul = rewriter.create(loc, erfAdd, half); + rewriter.replaceOpWithNewOp(op, input, halfMul); return success(); } @@ -881,7 +886,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } - // AtenBatchNormOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -919,28 +923,28 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value channelShape = rewriter.create( op->getLoc(), ValueRange{channelDim}); if (failed(checkNotNone(rewriter, op, weight))) { - weight = mhlo::getConstantOfShape( + weight = hlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, bias))) { - bias = mhlo::getConstantOfShape( + bias = hlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, runningVar))) { - runningVar = mhlo::getConstantOfShape( + runningVar = hlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 1), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, inputTy.getElementType())); } if (failed(checkNotNone(rewriter, op, runningMean))) { - runningMean = mhlo::getConstantOfShape( + runningMean = hlo::getConstantOfShape( rewriter, op->getLoc(), APFloat(inputElemTy.getFloatSemantics(), 0), channelShape, RankedTensorType::get({inputTy.getShape()[1]}, @@ -983,10 +987,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Type outputTy = getTypeConverter()->convertType(op.getType()); Type batchMeanOrVarTy = RankedTensorType::get(weightTy.getShape(), inputTy.getElementType()); - auto batchNormTrainingResult = rewriter.create( - op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, - weight, bias, rewriter.getF32FloatAttr(eps), - rewriter.getI64IntegerAttr(1)); + auto batchNormTrainingResult = + rewriter.create( + op.getLoc(), outputTy, batchMeanOrVarTy, batchMeanOrVarTy, input, + weight, bias, rewriter.getF32FloatAttr(eps), + rewriter.getI64IntegerAttr(1)); rewriter.replaceOp(op, batchNormTrainingResult.getResult(0)); return success(); } else { @@ -995,10 +1000,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( inputTy.getShape().end()}; castShape[1] = weightTy.getShape()[0]; auto castTy = RankedTensorType::get(castShape, inputTy.getElementType()); - // Feature counts must match among operands of mhlo::BatchNormInferenceOp. + // Feature counts must match among operands of + // stablehlo::BatchNormInferenceOp. Value inputCasted = rewriter.create(op.getLoc(), castTy, input); - Value output = rewriter.create( + Value output = rewriter.create( op.getLoc(), inputCasted.getType(), inputCasted, weight, bias, runningMean, runningVar, // 'epsilon' must satisfy constraint: 32-bit float attribute. @@ -1008,7 +1014,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } } - // AtenNativeLayerNormOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -1076,21 +1081,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } SmallVector inputFlattenShape{1, numFeatureDimSize, numEmbeddingDimSize}; - SmallVector meanOrVarMhloOutShape{numFeatureDimSize}; + SmallVector meanOrVarStablehloOutShape{numFeatureDimSize}; - auto mhloBatchNormOutTy = + auto stablehloBatchNormOutTy = RankedTensorType::get(inputFlattenShape, inputTy.getElementType()); - auto mhloBathNormOutMeanOrVarTy = - RankedTensorType::get(meanOrVarMhloOutShape, inputTy.getElementType()); + auto stablehloBathNormOutMeanOrVarTy = RankedTensorType::get( + meanOrVarStablehloOutShape, inputTy.getElementType()); // Reshape input - auto mhloInput = rewriter.create( - op->getLoc(), mhloBatchNormOutTy, input, - mhlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape), - {static_cast(inputFlattenShape.size())}) + auto stablehloInput = rewriter.create( + op->getLoc(), stablehloBatchNormOutTy, input, + hlo::getConstTensor(rewriter, op, llvm::ArrayRef(inputFlattenShape), + {static_cast(inputFlattenShape.size())}) .value()); - // Generate "scale" and "offset" Value for mhlo.BatchNormTrainingOp. + // Generate "scale" and "offset" Value for stablehlo.BatchNormTrainingOp. SmallVector zeroConstVec( numFeatureDimSize, APFloat::getZero(inputTy.getElementType() .cast() @@ -1103,16 +1108,18 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto oneOrZeroConstType = RankedTensorType::get({numFeatureDimSize}, inputTy.getElementType()); - Value scale = rewriter.create( + Value scale = rewriter.create( op->getLoc(), oneOrZeroConstType, DenseElementsAttr::get(oneOrZeroConstType, oneConstVec)); - Value offset = rewriter.create( + Value offset = rewriter.create( op->getLoc(), oneOrZeroConstType, DenseElementsAttr::get(oneOrZeroConstType, zeroConstVec)); - auto batchNormTrainingResult = rewriter.create( - op->getLoc(), mhloBatchNormOutTy, mhloBathNormOutMeanOrVarTy, - mhloBathNormOutMeanOrVarTy, mhloInput, scale, offset, - rewriter.getF32FloatAttr(eps), rewriter.getI64IntegerAttr(1)); + auto batchNormTrainingResult = + rewriter.create( + op->getLoc(), stablehloBatchNormOutTy, + stablehloBathNormOutMeanOrVarTy, stablehloBathNormOutMeanOrVarTy, + stablehloInput, scale, offset, rewriter.getF32FloatAttr(eps), + rewriter.getI64IntegerAttr(1)); // Reshape back auto outputTy = @@ -1120,36 +1127,35 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto outputMeanOrVarTy = getTypeConverter()->convertType(op.getType(1)).cast(); - auto output = rewriter.create( + auto output = rewriter.create( op->getLoc(), outputTy, batchNormTrainingResult.getResult(0), - mhlo::getConstTensor(rewriter, op, outputTy.getShape(), - {static_cast(outputTy.getShape().size())}) + hlo::getConstTensor(rewriter, op, outputTy.getShape(), + {static_cast(outputTy.getShape().size())}) .value()); - auto mean = rewriter.create( + auto mean = rewriter.create( op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(1), - mhlo::getConstTensor( + hlo::getConstTensor( rewriter, op, outputMeanOrVarTy.getShape(), {static_cast(outputMeanOrVarTy.getShape().size())}) .value()); - auto var = rewriter.create( + auto var = rewriter.create( op->getLoc(), outputMeanOrVarTy, batchNormTrainingResult.getResult(2), - mhlo::getConstTensor( + hlo::getConstTensor( rewriter, op, outputMeanOrVarTy.getShape(), {static_cast(outputMeanOrVarTy.getShape().size())}) .value()); // Apply affine transform: output x weight + bias [element-wise] - auto bcastedWeight = mhlo::promoteAndBroadcast(rewriter, weight, outputTy); - auto bcastedBias = mhlo::promoteAndBroadcast(rewriter, bias, outputTy); + auto bcastedWeight = hlo::promoteAndBroadcast(rewriter, weight, outputTy); + auto bcastedBias = hlo::promoteAndBroadcast(rewriter, bias, outputTy); auto outputMulWeight = - rewriter.create(op->getLoc(), output, bcastedWeight); - auto finalOuput = - rewriter.create(op->getLoc(), outputMulWeight, bcastedBias); + rewriter.create(op->getLoc(), output, bcastedWeight); + auto finalOuput = rewriter.create( + op->getLoc(), outputMulWeight, bcastedBias); rewriter.replaceOp(op, {finalOuput, mean, var}); return success(); } - // AtenCatOp template <> LogicalResult ConvertAtenOp::matchAndRewrite( @@ -1173,11 +1179,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Promote type for (auto &v : builtinTensors) { - v = mhlo::promoteType(rewriter, v, outType); + v = hlo::promoteType(rewriter, v, outType); } size_t posDim = toPositiveDim(dim, outType.getRank()); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, outType, ValueRange(builtinTensors), posDim); return success(); } @@ -1225,7 +1231,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "this op should be folded as its `min` and `max` both are none"); } else if (failed(checkNotNone(rewriter, op, minValue))) { - maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType); + maxValue = + hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType); auto minInfo = getMinValueOfDtype(op, inputElemType, rewriter); if (failed(minInfo)) { return rewriter.notifyMatchFailure( @@ -1233,7 +1240,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } minValue = *minInfo; } else if (failed(checkNotNone(rewriter, op, maxValue))) { - minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType); + minValue = + hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType); auto maxInfo = getMaxValueOfDtype(op, inputElemType, rewriter); if (failed(maxInfo)) { return rewriter.notifyMatchFailure( @@ -1241,10 +1249,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } maxValue = *maxInfo; } else { - minValue = mhlo::scalarToMhloTensor(rewriter, op, minValue, inputElemType); - maxValue = mhlo::scalarToMhloTensor(rewriter, op, maxValue, inputElemType); + minValue = + hlo::scalarToStablehloTensor(rewriter, op, minValue, inputElemType); + maxValue = + hlo::scalarToStablehloTensor(rewriter, op, maxValue, inputElemType); } - rewriter.replaceOpWithNewOp(op, minValue, input, maxValue); + rewriter.replaceOpWithNewOp(op, minValue, input, + maxValue); return success(); } @@ -1266,24 +1277,27 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: only int or float dtype supported"); } - Value start = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStart(), dtype); - Value end = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getEnd(), dtype); - Value step = mhlo::scalarToMhloTensor(rewriter, op, adaptor.getStep(), dtype); + Value start = + hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStart(), dtype); + Value end = + hlo::scalarToStablehloTensor(rewriter, op, adaptor.getEnd(), dtype); + Value step = + hlo::scalarToStablehloTensor(rewriter, op, adaptor.getStep(), dtype); // Get length of the 1-d output tensor - Value subOut = rewriter.create(loc, end, start); - Value divOut = rewriter.create(loc, subOut, step); + Value subOut = rewriter.create(loc, end, start); + Value divOut = rewriter.create(loc, subOut, step); - Value resultLength = rewriter.create( + Value resultLength = rewriter.create( loc, RankedTensorType::get({1}, dtype), divOut); if (dtype.isa()) { - resultLength = rewriter.create(loc, resultLength); - resultLength = rewriter.create( + resultLength = rewriter.create(loc, resultLength); + resultLength = rewriter.create( loc, RankedTensorType::get({1}, rewriter.getI64Type()), resultLength); } Value window = - rewriter.create(loc, outType, resultLength, 0); + rewriter.create(loc, outType, resultLength, 0); DenseIntElementsAttr broadcastDimensions; Value mulOut = rewriter.create(loc, window, step, broadcastDimensions); @@ -1298,9 +1312,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Location loc = op.getLoc(); Value input = adaptor.getSelf(); - auto outType = this->getTypeConverter() - ->convertType(op.getType()) - .cast(); + auto outType = + this->getTypeConverter()->convertType(op.getType()).cast(); if (!outType) { return op.emitError("only tensor type is supported"); } @@ -1320,26 +1333,27 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value negHalf = chlo::getConstantLike(rewriter, loc, -0.5, input); // Compute - Value kBeta0 = rewriter.create(loc, outType, kAlpha, cstAlpha0); - Value kBeta = rewriter.create(loc, outType, kBeta0, half); - Value erfArg = - rewriter.create(loc, outType, kAlpha, adaptor.getSelf()); + Value kBeta0 = + rewriter.create(loc, outType, kAlpha, cstAlpha0); + Value kBeta = rewriter.create(loc, outType, kBeta0, half); + Value erfArg = rewriter.create(loc, outType, kAlpha, + adaptor.getSelf()); Value erf = rewriter.create(loc, outType, erfArg); - Value erfAdd = rewriter.create(loc, outType, erf, one); - Value cdf = rewriter.create(loc, outType, erfAdd, half); - Value inputSquared = rewriter.create( + Value erfAdd = rewriter.create(loc, outType, erf, one); + Value cdf = rewriter.create(loc, outType, erfAdd, half); + Value inputSquared = rewriter.create( loc, outType, adaptor.getSelf(), adaptor.getSelf()); Value negHalfInputSquared = - rewriter.create(loc, outType, inputSquared, negHalf); + rewriter.create(loc, outType, inputSquared, negHalf); Value expRes = - rewriter.create(loc, outType, negHalfInputSquared); - Value pdf = rewriter.create(loc, outType, kBeta, expRes); + rewriter.create(loc, outType, negHalfInputSquared); + Value pdf = rewriter.create(loc, outType, kBeta, expRes); Value pdfTimesInput = - rewriter.create(loc, outType, pdf, adaptor.getSelf()); + rewriter.create(loc, outType, pdf, adaptor.getSelf()); Value pdfTimesInputAddCdf = - rewriter.create(loc, outType, pdfTimesInput, cdf); - rewriter.replaceOpWithNewOp(op, outType, adaptor.getGradOutput(), - pdfTimesInputAddCdf); + rewriter.create(loc, outType, pdfTimesInput, cdf); + rewriter.replaceOpWithNewOp( + op, outType, adaptor.getGradOutput(), pdfTimesInputAddCdf); return success(); } @@ -1366,9 +1380,9 @@ class ConvertRuntimeAssertOp : public OpConversionPattern { }; } // namespace -void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( +void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToMhloOptions &options) { + ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); @@ -1376,23 +1390,24 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( target.addIllegalOp(); patterns.add(typeConverter, context); -#define INSERT_UNARY_PATTERN(AtenOp, MhloOp) \ +#define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context) - INSERT_UNARY_PATTERN(AtenCloneOp, mhlo::CopyOp); - INSERT_UNARY_PATTERN(AtenNegOp, mhlo::NegOp); - INSERT_UNARY_PATTERN(AtenLogicalNotOp, mhlo::NotOp); - INSERT_UNARY_PATTERN(AtenBitwiseNotOp, mhlo::NotOp); + patterns.add>(typeConverter, context) + INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp); + INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp); + INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp); + INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp); #undef INSERT_UNARY_PATTERN -#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, MhloOp) \ +#define INSERT_UNARY_FPONLY_PATTERN(AtenOp, StablehloOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, context) - INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, mhlo::LogOp); - INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, mhlo::ExpOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, mhlo::SqrtOp); - INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, mhlo::RsqrtOp); - INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, mhlo::LogisticOp); + patterns.add>(typeConverter, \ + context) + INSERT_UNARY_FPONLY_PATTERN(AtenLogOp, stablehlo::LogOp); + INSERT_UNARY_FPONLY_PATTERN(AtenExpOp, stablehlo::ExpOp); + INSERT_UNARY_FPONLY_PATTERN(AtenSqrtOp, stablehlo::SqrtOp); + INSERT_UNARY_FPONLY_PATTERN(AtenRsqrtOp, stablehlo::RsqrtOp); + INSERT_UNARY_FPONLY_PATTERN(AtenSigmoidOp, stablehlo::LogisticOp); #undef INSERT_UNARY_FPONLY_PATTERN #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ @@ -1482,10 +1497,10 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenWhereSelfOp); #undef INSERT_ATENOP_PATTERN -#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, MhloOp) \ +#define INSERT_BINARY_BROADCAST_PATTERN(AtenOp, StablehloOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, \ - context) + patterns.add>( \ + typeConverter, context) INSERT_BINARY_BROADCAST_PATTERN(AtenMaximumOp, chlo::BroadcastMaxOp); INSERT_BINARY_BROADCAST_PATTERN(AtenMinimumOp, chlo::BroadcastMinOp); INSERT_BINARY_BROADCAST_PATTERN(Aten__And__TensorOp, chlo::BroadcastAndOp); diff --git a/lib/Conversion/TorchToStablehlo/CMakeLists.txt b/lib/Conversion/TorchToStablehlo/CMakeLists.txt new file mode 100644 index 000000000000..237512980562 --- /dev/null +++ b/lib/Conversion/TorchToStablehlo/CMakeLists.txt @@ -0,0 +1,29 @@ +add_mlir_conversion_library(TorchMLIRTorchToStablehlo + TorchToStablehlo.cpp + StablehloLegalizeUtils.cpp + Basic.cpp + Gather.cpp + Linear.cpp + ViewLike.cpp + Reduction.cpp + Pooling.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo + + DEPENDS + TorchMLIRConversionPassIncGen + + LINK_COMPONENTS + Core + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRBufferTransforms + StablehloOps + TorchMLIRTorchDialect + TorchMLIRConversionUtils +) + +torch_mlir_target_includes(TorchMLIRTorchToStablehlo) diff --git a/lib/Conversion/TorchToMhlo/Gather.cpp b/lib/Conversion/TorchToStablehlo/Gather.cpp similarity index 87% rename from lib/Conversion/TorchToMhlo/Gather.cpp rename to lib/Conversion/TorchToStablehlo/Gather.cpp index 8d7a3f5c0457..4373327036c7 100644 --- a/lib/Conversion/TorchToMhlo/Gather.cpp +++ b/lib/Conversion/TorchToStablehlo/Gather.cpp @@ -7,14 +7,15 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "../PassDetail.h" -#include "./MhloLegalizeUtils.h" -#include "./PopulatePatterns.h" -#include "mhlo/IR/hlo_ops.h" +#include "PopulatePatterns.h" +#include "StablehloLegalizeUtils.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -24,7 +25,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -using namespace mlir::torch::torch_to_mhlo; +using namespace mlir::torch::torch_to_stablehlo; namespace { Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, @@ -69,7 +70,7 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, SmallVector startIndexMap(1, axis); // indexVecDim int64_t indexVecDim = indicesRank; - auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get( + auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), /*offsetDims=*/offsetDims, /*collapsedSliceDims=*/collapsedSliceDims, @@ -91,17 +92,18 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op, auto outputTy = RankedTensorType::get(outputShape, inputRankTy.getElementType()); return rewriter - .create(loc, outputTy, input, indices, - sliceSizesTensor, dimsAttr) + .create(loc, outputTy, input, indices, + sliceSizesTensor, dimsAttr) .getResult(); } } // namespace -// Ref: https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html +// Ref: +// https://pytorch.org/docs/stable/generated/torch.nn.functional.embedding.html // padding_idx (int, optional) -// – If specified, the entries at padding_idx do not contribute to the gradient; -// therefore, the embedding vector at padding_idx is not updated during training, -// i.e. it remains as a fixed “pad”. +// – If specified, the entries at padding_idx do not contribute to the +// gradient; therefore, the embedding vector at padding_idx is not updated +// during training, i.e. it remains as a fixed “pad”. // scale_grad_by_freq (boolean, optional) // – If given, this will scale gradients by the inverse of frequency of the // words in the mini-batch. Default False. @@ -139,7 +141,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value output = gatherTensorAlongSingleAxis( rewriter, op, weight, adaptor.getIndices(), 0, options.dimSizeIndexBits); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output); return success(); @@ -161,7 +163,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value output = gatherTensorAlongSingleAxis( rewriter, op, self, adaptor.getIndex(), dim, options.dimSizeIndexBits); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output); return success(); @@ -200,7 +202,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto options = getOptions(); auto indexShapeInfo = - mhlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits); + hlo::getDimSizesOfTensor(rewriter, op, index, options.dimSizeIndexBits); if (failed(indexShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dim sizes of `index` param"); @@ -223,15 +225,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector toConcat; for (int64_t i = 0; i < inputType.getRank(); ++i) { if (i == dim) { - toConcat.push_back(rewriter.create( + toConcat.push_back(rewriter.create( loc, toConcatIndexType, index, toConcatIndexShape)); } else { - toConcat.push_back(rewriter.create( + toConcat.push_back(rewriter.create( loc, toConcatIndexType, toConcatIndexShape, rewriter.getI64IntegerAttr(i))); } } - auto gatherIndicies = rewriter.create( + auto gatherIndicies = rewriter.create( loc, toConcat, static_cast(inputType.getRank())); SmallVector sliceSizes(inputType.getRank(), 1); @@ -243,22 +245,22 @@ LogicalResult ConvertAtenOp::matchAndRewrite( startIndexMap.push_back(i); } - auto dimsAttr = mhlo::GatherDimensionNumbersAttr::get( + auto dimsAttr = stablehlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), /*offsetDims=*/{}, /*collapsedSliceDims=*/collapsedDims, /*startIndexMap=*/startIndexMap, /*indexVecDim=*/indexVecDim); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, input, gatherIndicies, dimsAttr, rewriter.getI64TensorAttr(sliceSizes)); return success(); } -void mlir::torch::torch_to_mhlo::populateGatherOpPatternsAndLegality( +void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToMhloOptions &options) { + ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Conversion/TorchToMhlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp similarity index 83% rename from lib/Conversion/TorchToMhlo/Linear.cpp rename to lib/Conversion/TorchToStablehlo/Linear.cpp index 8632af4bac68..fbc3d6ee4eb8 100644 --- a/lib/Conversion/TorchToMhlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -7,15 +7,16 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "../PassDetail.h" -#include "./MhloLegalizeUtils.h" -#include "./PopulatePatterns.h" -#include "mhlo/IR/hlo_ops.h" +#include "PopulatePatterns.h" +#include "StablehloLegalizeUtils.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -25,7 +26,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -using namespace mlir::torch::torch_to_mhlo; +using namespace mlir::torch::torch_to_stablehlo; namespace { Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, @@ -33,7 +34,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, ArrayRef broadcastDims) { auto tensorTy = tensor.getType().dyn_cast(); auto loc = op->getLoc(); - Value mhloShape = rewriter.create(loc, dimSizes); + Value stablehloShape = rewriter.create(loc, dimSizes); RankedTensorType outTy = RankedTensorType::get(shape, tensorTy.getElementType()); @@ -43,8 +44,8 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, rewriter.getIntegerType(64)); auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims); - auto broadcast = rewriter.create( - loc, outTy, tensor, mhloShape, broadcastAttr); + auto broadcast = rewriter.create( + loc, outTy, tensor, stablehloShape, broadcastAttr); return broadcast; } @@ -52,7 +53,7 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, ArrayRef inpTransDims) { auto inputTy = input.getType().dyn_cast(); auto rank = inputTy.getRank(); - auto transDims = mhlo::toPositiveDims(inpTransDims, rank); + auto transDims = hlo::toPositiveDims(inpTransDims, rank); auto inpShape = inputTy.getShape(); std::vector newShape; newShape.reserve(rank); @@ -66,8 +67,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, auto permuteAttr = DenseIntElementsAttr::get(attrTy, transDims); auto outTy = RankedTensorType::get(newShape, inputTy.getElementType()); - auto result = rewriter.create(op->getLoc(), outTy, input, - permuteAttr); + auto result = rewriter.create(op->getLoc(), outTy, + input, permuteAttr); return result.getResult(); } @@ -119,10 +120,12 @@ RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op, } // set result dimensions - if (lhsResultDim < static_cast(lhsShape.size()) && lhsResultDim >= 0) { + if (lhsResultDim < static_cast(lhsShape.size()) && + lhsResultDim >= 0) { outShape.push_back(lhsShape[lhsResultDim]); } - if (rhsResultDim < static_cast(rhsShape.size()) && rhsResultDim >= 0) { + if (rhsResultDim < static_cast(rhsShape.size()) && + rhsResultDim >= 0) { outShape.push_back(rhsShape[rhsResultDim]); } return RankedTensorType::get(outShape, lhsTy.getElementType()); @@ -151,10 +154,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, std::vector newShape(rhsShape.begin(), rhsShape.begin() + leadingRank); newShape.insert(newShape.end(), lhsShape.begin(), lhsShape.end()); - auto newDimSizes = *mhlo::getDimSizesOfTensor( - rewriter, op, rhs, leadingDims, dimSizeIndexBits); + auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, rhs, leadingDims, + dimSizeIndexBits); auto lhsDimSizes = - *mhlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits); + *hlo::getDimSizesOfTensor(rewriter, op, lhs, dimSizeIndexBits); newDimSizes.insert(newDimSizes.end(), lhsDimSizes.begin(), lhsDimSizes.end()); lhs = getBroadcastTensor(rewriter, op, lhs, newShape, newDimSizes, @@ -163,10 +166,10 @@ void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, std::vector newShape(lhsShape.begin(), lhsShape.begin() + leadingRank); newShape.insert(newShape.end(), rhsShape.begin(), rhsShape.end()); - auto newDimSizes = *mhlo::getDimSizesOfTensor( - rewriter, op, lhs, leadingDims, dimSizeIndexBits); + auto newDimSizes = *hlo::getDimSizesOfTensor(rewriter, op, lhs, leadingDims, + dimSizeIndexBits); auto rhsDimSizes = - *mhlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits); + *hlo::getDimSizesOfTensor(rewriter, op, rhs, dimSizeIndexBits); newDimSizes.insert(newDimSizes.end(), rhsDimSizes.begin(), rhsDimSizes.end()); rhs = getBroadcastTensor(rewriter, op, rhs, newShape, newDimSizes, @@ -218,8 +221,8 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { if (lhsRank <= 2 && rhsRank <= 2) { auto tensorType = ConvertAtenOp::getTypeConverter()->convertType(op.getType()); - output = rewriter.create(op->getLoc(), tensorType, lhs, rhs, - nullptr); + output = rewriter.create(op->getLoc(), tensorType, lhs, + rhs, nullptr); return success(); } @@ -253,8 +256,8 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { lhsContractingDim = nBatchDims; } - mhlo::DotDimensionNumbersAttr dotDimensionNumbers = - mhlo::DotDimensionNumbersAttr::get( + stablehlo::DotDimensionNumbersAttr dotDimensionNumbers = + stablehlo::DotDimensionNumbersAttr::get( rewriter.getContext(), /*lhsBatchingDimensions=*/batchDims, /*rhsBatchingDimensions=*/batchDims, @@ -264,8 +267,8 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, lhsContractingDim, rhsContractingDim); output = rewriter - .create(op->getLoc(), outTy, lhs, rhs, - dotDimensionNumbers, nullptr) + .create(op->getLoc(), outTy, lhs, rhs, + dotDimensionNumbers, nullptr) .getResult(); return success(); } @@ -312,7 +315,7 @@ class ConvertAtenMatMulOp : public ConvertAtenMatmulBaseOp { if (!lhsTy || !rhsTy) return op.emitError( - "only ranked tensor types are supported in MHLO matmul"); + "only ranked tensor types are supported in StableHLO matmul"); return success(); } @@ -335,7 +338,7 @@ class ConvertAtenMmOp : public ConvertAtenMatmulBaseOp { if (!lhsTy || !rhsTy) return op.emitError( - "only ranked tensor types are supported in MHLO matmul"); + "only ranked tensor types are supported in StableHLO matmul"); auto lhsRank = lhsTy.getRank(); auto rhsRank = rhsTy.getRank(); @@ -371,7 +374,7 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { if (!lhsTy || !rhsTy) return op.emitError( - "only ranked tensor types are supported in MHLO matmul"); + "only ranked tensor types are supported in StableHLO matmul"); auto lhsRank = lhsTy.getRank(); auto rhsRank = rhsTy.getRank(); @@ -398,10 +401,10 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { auto bias = adaptor.getBias(); auto biasTy = bias.getType(); - // MHLO does not mandate that elementwise op tensors need to be ranked. + // StableHLO does not mandate that elementwise op tensors need to be ranked. if (!biasTy.template isa() && !biasTy.template isa()) - return op.emitError("only ranked tensor types are supported in MHLO " + return op.emitError("only ranked tensor types are supported in StableHLO " "matmul for bias tensor"); // weight.T @@ -427,14 +430,14 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { auto outTy = castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim, lhsContractingDim, rhsContractingDim); - mhlo::DotDimensionNumbersAttr dotDimensionNumbers = - mhlo::DotDimensionNumbersAttr::get( + stablehlo::DotDimensionNumbersAttr dotDimensionNumbers = + stablehlo::DotDimensionNumbersAttr::get( rewriter.getContext(), /*lhsBatchingDimensions=*/batchDims, /*rhsBatchingDimensions=*/batchDims, /*lhsContractingDimensions=*/{lhsContractingDim}, /*rhsContractingDimensions=*/{rhsContractingDim}); - Value matmulOutput = rewriter.create( + Value matmulOutput = rewriter.create( op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr); Value matmulPlusBias = matmulOutput; @@ -464,7 +467,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { auto weightElemTy = weightTy.getElementType(); auto rank = weightTy.getRank(); const auto &options = getOptions(); - SmallVector weightShapeVec = *mhlo::getDimSizesOfTensor( + SmallVector weightShapeVec = *hlo::getDimSizesOfTensor( rewriter, op, weight, options.dimSizeIndexBits); auto weightShape = weightTy.getShape(); SmallVector weightShapeInt(rank); @@ -488,7 +491,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { } Value weightShapeTensor = rewriter.create( op->getLoc(), weightShapeVec); - weight = rewriter.create( + weight = rewriter.create( op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), weight, weightShapeTensor); @@ -497,7 +500,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { for (int64_t i = 0; i <= rank; i++) transposeDims[i] = i; std::swap(transposeDims[1], transposeDims[0]); - weight = rewriter.create( + weight = rewriter.create( op->getLoc(), weight, rewriter.getI64TensorAttr(transposeDims)); // 3. [IC//G, G, OC, H, W, ...] => [IC//G, G*OC, H, W, ...] @@ -509,7 +512,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { weightShapeVec[1] = OCMulGValue; weightShapeTensor = rewriter.create( op->getLoc(), weightShapeVec); - weight = rewriter.create( + weight = rewriter.create( op->getLoc(), RankedTensorType::get(weightShapeInt, weightElemTy), weight, weightShapeTensor); return weight; @@ -544,25 +547,27 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { } // Prepare for transposed convolution - SmallVector mhloStrideVec(nSpatialDims, 1); - DenseIntElementsAttr mhloStride = rewriter.getI64TensorAttr(mhloStrideVec); - SmallVector mhloPaddingVec(nSpatialDims * 2, 0); + SmallVector stablehloStrideVec(nSpatialDims, 1); + DenseIntElementsAttr stablehloStride = + rewriter.getI64TensorAttr(stablehloStrideVec); + SmallVector stablehloPaddingVec(nSpatialDims * 2, 0); for (int i = 0; i < nSpatialDims; ++i) { int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i]; - mhloPaddingVec[i * 2] = padInt; - mhloPaddingVec[i * 2 + 1] = padInt; + stablehloPaddingVec[i * 2] = padInt; + stablehloPaddingVec[i * 2 + 1] = padInt; } - DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get( + DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get( RankedTensorType::get({nSpatialDims, 2}, rewriter.getI64Type()), - mhloPaddingVec); - SmallVector mhloLhsDilationVec(nSpatialDims); - std::copy(stride.begin(), stride.end(), mhloLhsDilationVec.begin()); - DenseIntElementsAttr mhloLhsDilation = - rewriter.getI64TensorAttr(mhloLhsDilationVec); - SmallVector mhloRhsDilationVec(nSpatialDims); - std::copy(dilation.begin(), dilation.end(), mhloRhsDilationVec.begin()); - DenseIntElementsAttr mhloRhsDilation = - rewriter.getI64TensorAttr(mhloRhsDilationVec); + stablehloPaddingVec); + SmallVector stablehloLhsDilationVec(nSpatialDims); + std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin()); + DenseIntElementsAttr stablehloLhsDilation = + rewriter.getI64TensorAttr(stablehloLhsDilationVec); + SmallVector stablehloRhsDilationVec(nSpatialDims); + std::copy(dilation.begin(), dilation.end(), + stablehloRhsDilationVec.begin()); + DenseIntElementsAttr stablehloRhsDilation = + rewriter.getI64TensorAttr(stablehloRhsDilationVec); DenseElementsAttr windowReversal; ArrayAttr precisionConfig; @@ -571,8 +576,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { for (int i = 0; i < nSpatialDims; ++i) { spatialDims.push_back(i + 2); } - mhlo::ConvDimensionNumbersAttr dimensionNumbers = - mhlo::ConvDimensionNumbersAttr::get( + stablehlo::ConvDimensionNumbersAttr dimensionNumbers = + stablehlo::ConvDimensionNumbersAttr::get( /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0, /*inputFeatureDimension=*/1, /*inputSpatialDimensions=*/spatialDims, @@ -583,17 +588,18 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { /*outputSpatialDimensions=*/spatialDims); // Reverse and transpose weight - weight = rewriter.create( + weight = rewriter.create( op->getLoc(), weight, rewriter.getI64TensorAttr(spatialDims)); if (groups != 1) { weight = reshapeConvWeight(rewriter, op, weight, groups); } // Create transposed convolution - auto transposedConvOp = rewriter.create( - op->getLoc(), convOutTy, input, weight, mhloStride, mhloPadding, - mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers, - static_cast(groups), 1, precisionConfig); + auto transposedConvOp = rewriter.create( + op->getLoc(), convOutTy, input, weight, stablehloStride, + stablehloPadding, stablehloLhsDilation, stablehloRhsDilation, + windowReversal, dimensionNumbers, static_cast(groups), 1, + precisionConfig); // Handle output padding if (!needHandleOutputPadding) { @@ -605,8 +611,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { std::copy(outputPadding.begin(), outputPadding.end(), edgePaddingHighVec.begin() + 2); Value paddingValue = - mhlo::getConstTensor(rewriter, op, {0.0}, {}).value(); - paddingValue = mhlo::promoteType(rewriter, paddingValue, inputTy); + hlo::getConstTensor(rewriter, op, {0.0}, {}).value(); + paddingValue = hlo::promoteType(rewriter, paddingValue, inputTy); mlir::DenseIntElementsAttr edgePaddingLow = rewriter.getI64VectorAttr(edgePaddingLowVec); mlir::DenseIntElementsAttr edgePaddingHigh = @@ -614,7 +620,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { mlir::DenseIntElementsAttr interiorPadding = rewriter.getI64VectorAttr(interiorPaddingVec); - auto paddedOutput = rewriter.create( + auto paddedOutput = rewriter.create( op->getLoc(), outType, transposedConvOp, paddingValue, edgePaddingLow, edgePaddingHigh, interiorPadding); @@ -628,22 +634,22 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { ArrayRef dilation, int64_t groups) const { int64_t nDims = outType.getRank(); - // Get mhlo::ConvolutionOp attributes - DenseIntElementsAttr mhloWindowStride = DenseIntElementsAttr::get( + // Get stablehlo::ConvolutionOp attributes + DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stride.size())}, rewriter.getI64Type()), stride); - std::vector mhloPaddingVec; + std::vector stablehloPaddingVec; for (size_t i = 0; i < padding.size(); i++) { - mhloPaddingVec.emplace_back(padding[i]); - mhloPaddingVec.emplace_back(padding[i]); + stablehloPaddingVec.emplace_back(padding[i]); + stablehloPaddingVec.emplace_back(padding[i]); } - DenseIntElementsAttr mhloPadding = DenseIntElementsAttr::get( + DenseIntElementsAttr stablehloPadding = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(padding.size()), static_cast(2)}, rewriter.getI64Type()), - mhloPaddingVec); - DenseIntElementsAttr mhloRhsDilation = DenseIntElementsAttr::get( + stablehloPaddingVec); + DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(dilation.size())}, rewriter.getI64Type()), dilation); @@ -651,8 +657,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { for (int64_t i = 2; i < nDims; i++) { spatialDimensions.emplace_back(i); } - mhlo::ConvDimensionNumbersAttr dimensionNumbers = - mhlo::ConvDimensionNumbersAttr::get( + stablehlo::ConvDimensionNumbersAttr dimensionNumbers = + stablehlo::ConvDimensionNumbersAttr::get( /*context=*/rewriter.getContext(), /*inputBatchDimension=*/0, /*inputFeatureDimension=*/1, /*inputSpatialDimensions=*/spatialDimensions, @@ -662,17 +668,18 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { /*outputBatchDimension=*/0, /*outputFeatureDimension=*/1, /*outputSpatialDimensions=*/spatialDimensions); - // mhlo::ConvolutionOp's optional attributes, leave them as default - DenseIntElementsAttr mhloLhsDilation; + // stablehlo::ConvolutionOp's optional attributes, leave them as default + DenseIntElementsAttr stablehloLhsDilation; DenseElementsAttr windowReversal; ArrayAttr precisionConfig; - auto mhloConvOp = rewriter.create( - op->getLoc(), outType, input, weight, mhloWindowStride, mhloPadding, - mhloLhsDilation, mhloRhsDilation, windowReversal, dimensionNumbers, - static_cast(groups), 1, precisionConfig); + auto stablehloConvOp = rewriter.create( + op->getLoc(), outType, input, weight, stablehloWindowStride, + stablehloPadding, stablehloLhsDilation, stablehloRhsDilation, + windowReversal, dimensionNumbers, static_cast(groups), 1, + precisionConfig); - return mhloConvOp.getResult(); + return stablehloConvOp.getResult(); } LogicalResult @@ -754,21 +761,22 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { } } - Value mhloConvResult; + Value stablehloConvResult; if (transposed) { - mhloConvResult = convertTransposedConv( + stablehloConvResult = convertTransposedConv( op, rewriter, outTy, input, weight, stride, padding, dilation, outputPadding, groups, needHandleOutputPadding); } else { - mhloConvResult = convertNormalConv(op, rewriter, outTy, input, weight, - stride, padding, dilation, groups); + stablehloConvResult = + convertNormalConv(op, rewriter, outTy, input, weight, stride, padding, + dilation, groups); } auto bias = adaptor.getBias(); // No bias provided if (failed(checkNotNone(rewriter, op, op.getBias()))) { - rewriter.replaceOp(op, mhloConvResult); + rewriter.replaceOp(op, stablehloConvResult); return success(); } @@ -790,21 +798,21 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { llvm::to_vector<4>(llvm::seq(-nSpatialDims, 0)); const auto &options = getOptions(); - bias = *mhlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims, - options.dimSizeIndexBits); - bias = mhlo::promoteType(rewriter, bias, outTy); + bias = *hlo::unsqueezeTensor(rewriter, op, bias, inputUnsqzDims, + options.dimSizeIndexBits); + bias = hlo::promoteType(rewriter, bias, outTy); DenseIntElementsAttr bcastDimensions; - rewriter.replaceOpWithNewOp(op, outTy, mhloConvResult, - bias, bcastDimensions); + rewriter.replaceOpWithNewOp( + op, outTy, stablehloConvResult, bias, bcastDimensions); return success(); } }; } // namespace -void mlir::torch::torch_to_mhlo::populateLinearOpPatternsAndLegality( +void mlir::torch::torch_to_stablehlo::populateLinearOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToMhloOptions &options) { + ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_MATMUL_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Conversion/TorchToMhlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp similarity index 72% rename from lib/Conversion/TorchToMhlo/Pooling.cpp rename to lib/Conversion/TorchToStablehlo/Pooling.cpp index 6693bd0971e8..90044cc8b81e 100644 --- a/lib/Conversion/TorchToMhlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -7,15 +7,16 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "../PassDetail.h" -#include "./MhloLegalizeUtils.h" -#include "./PopulatePatterns.h" -#include "mhlo/IR/hlo_ops.h" +#include "PopulatePatterns.h" +#include "StablehloLegalizeUtils.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -28,7 +29,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -using namespace mlir::torch::torch_to_mhlo; +using namespace mlir::torch::torch_to_stablehlo; static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { @@ -40,14 +41,14 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, constType, {APFloat::getZero( elementTy.cast().getFloatSemantics(), /*negative=*/false)}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } } @@ -58,15 +59,15 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, constType, {APFloat::getLargest( elementTy.cast().getFloatSemantics(), /*negative=*/true)}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } } op->emitError("unimplemented lowering in AtenPoolingOp"); @@ -116,42 +117,43 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // prepend 1 to kernelSize, stride, dilation until they are of same rank as // input - SmallVector mhloStride(inputRank, 1); - SmallVector mhloDilation(inputRank, 1); - SmallVector mhloKernelSize(inputRank, 1); - SmallVector mhloPadding(inputRank * 2, 0); + SmallVector stablehloStride(inputRank, 1); + SmallVector stablehloDilation(inputRank, 1); + SmallVector stablehloKernelSize(inputRank, 1); + SmallVector stablehloPadding(inputRank * 2, 0); std::copy(dilation.begin(), dilation.end(), - mhloDilation.begin() + inputRank - 2); - std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); + stablehloDilation.begin() + inputRank - 2); + std::copy(stride.begin(), stride.end(), + stablehloStride.begin() + inputRank - 2); std::copy(kernelSize.begin(), kernelSize.end(), - mhloKernelSize.begin() + inputRank - 2); + stablehloKernelSize.begin() + inputRank - 2); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - mhloPadding[mhloPadding.size() - 4] = padding[0]; - mhloPadding[mhloPadding.size() - 3] = padding[0]; - mhloPadding[mhloPadding.size() - 2] = padding[1]; - mhloPadding[mhloPadding.size() - 1] = padding[1]; + stablehloPadding[stablehloPadding.size() - 4] = padding[0]; + stablehloPadding[stablehloPadding.size() - 3] = padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[1]; + stablehloPadding[stablehloPadding.size() - 1] = padding[1]; DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(mhloKernelSize.size())}, + RankedTensorType::get({static_cast(stablehloKernelSize.size())}, rewriter.getI64Type()), - mhloKernelSize); + stablehloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(mhloStride.size())}, + RankedTensorType::get({static_cast(stablehloStride.size())}, rewriter.getI64Type()), - mhloStride); + stablehloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(mhloDilation.size())}, + RankedTensorType::get({static_cast(stablehloDilation.size())}, rewriter.getI64Type()), - mhloDilation); + stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), - mhloPadding); - auto reduceWindowOp = rewriter.create( + stablehloPadding); + auto reduceWindowOp = rewriter.create( op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, baseDilations, windowDilations, pad); @@ -168,8 +170,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); Value result = - rewriter.create(op->getLoc(), *firstArg, *secondArg); - rewriter.create(op->getLoc(), result); + rewriter.create(op->getLoc(), *firstArg, *secondArg); + rewriter.create(op->getLoc(), result); } rewriter.replaceOp(op, reduceWindowOp.getResults()); @@ -221,45 +223,46 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // prepend 1 to kernelSize, stride, dilation until they are of same rank as // input - SmallVector mhloStride(inputRank, 1); - SmallVector mhloDilation(inputRank, 1); - SmallVector mhloKernelSize(inputRank, 1); - SmallVector mhloPadding(inputRank * 2, 0); + SmallVector stablehloStride(inputRank, 1); + SmallVector stablehloDilation(inputRank, 1); + SmallVector stablehloKernelSize(inputRank, 1); + SmallVector stablehloPadding(inputRank * 2, 0); std::copy(dilation.begin(), dilation.end(), - mhloDilation.begin() + inputRank - 2); - std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); + stablehloDilation.begin() + inputRank - 2); + std::copy(stride.begin(), stride.end(), + stablehloStride.begin() + inputRank - 2); std::copy(kernelSize.begin(), kernelSize.end(), - mhloKernelSize.begin() + inputRank - 2); + stablehloKernelSize.begin() + inputRank - 2); Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - mhloPadding[mhloPadding.size() - 4] = padding[0]; - mhloPadding[mhloPadding.size() - 3] = padding[0]; - mhloPadding[mhloPadding.size() - 2] = padding[1]; - mhloPadding[mhloPadding.size() - 1] = padding[1]; + stablehloPadding[stablehloPadding.size() - 4] = padding[0]; + stablehloPadding[stablehloPadding.size() - 3] = padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[1]; + stablehloPadding[stablehloPadding.size() - 1] = padding[1]; DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(mhloKernelSize.size())}, + RankedTensorType::get({static_cast(stablehloKernelSize.size())}, rewriter.getI64Type()), - mhloKernelSize); + stablehloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(mhloStride.size())}, + RankedTensorType::get({static_cast(stablehloStride.size())}, rewriter.getI64Type()), - mhloStride); + stablehloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(mhloDilation.size())}, + RankedTensorType::get({static_cast(stablehloDilation.size())}, rewriter.getI64Type()), - mhloDilation); + stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), - mhloPadding); + stablehloPadding); const auto &options = getOptions(); auto inputShapeInfo = - mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -289,7 +292,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto initIndexTensor = rewriter - .create( + .create( op->getLoc(), RankedTensorType::get(initIndexShapeForType, rewriter.getI64Type()), @@ -298,15 +301,15 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto indexTensor = rewriter - .create( + .create( op->getLoc(), RankedTensorType::get(inputShape, rewriter.getI64Type()), initIndexTensor, inputShapeTensor) .getResult(); - Value initIdx = mhlo::getConstTensor(rewriter, op, {0}, {}).value(); + Value initIdx = hlo::getConstTensor(rewriter, op, {0}, {}).value(); - auto reduceWindowOp = rewriter.create( + auto reduceWindowOp = rewriter.create( op->getLoc(), mlir::TypeRange{outValTy, outIdxTy}, mlir::ValueRange{input, indexTensor}, mlir::ValueRange{initVal, initIdx}, windowDimensions, windowStrides, baseDilations, windowDilations, pad); @@ -326,43 +329,43 @@ LogicalResult ConvertAtenOp::matchAndRewrite( auto *secondValArg = std::next(firstIdxArg); auto *secondIdxArg = std::next(secondValArg); - mhlo::ComparisonTypeAttr compareTypeAttr; + stablehlo::ComparisonTypeAttr compareTypeAttr; if (inputTy.getElementType().isa()) { - compareTypeAttr = mhlo::ComparisonTypeAttr::get( - rewriter.getContext(), mhlo::ComparisonType::FLOAT); + compareTypeAttr = stablehlo::ComparisonTypeAttr::get( + rewriter.getContext(), stablehlo::ComparisonType::FLOAT); } else if (inputTy.getElementType().isa()) { - compareTypeAttr = mhlo::ComparisonTypeAttr::get( - rewriter.getContext(), mhlo::ComparisonType::SIGNED); + compareTypeAttr = stablehlo::ComparisonTypeAttr::get( + rewriter.getContext(), stablehlo::ComparisonType::SIGNED); } - mhlo::ComparisonDirectionAttr compareGeDirectionAttr = - mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), - mhlo::ComparisonDirection::GE); - mhlo::ComparisonDirectionAttr compareEqDirectionAttr = - mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), - mhlo::ComparisonDirection::EQ); + stablehlo::ComparisonDirectionAttr compareGeDirectionAttr = + stablehlo::ComparisonDirectionAttr::get( + rewriter.getContext(), stablehlo::ComparisonDirection::GE); + stablehlo::ComparisonDirectionAttr compareEqDirectionAttr = + stablehlo::ComparisonDirectionAttr::get( + rewriter.getContext(), stablehlo::ComparisonDirection::EQ); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value compareGeResult = rewriter.create( + Value compareGeResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareGeDirectionAttr, compareTypeAttr); - Value retValResult = rewriter.create( + Value retValResult = rewriter.create( op->getLoc(), compareGeResult, *firstValArg, *secondValArg); // Get smaller index if compared values are equal. - Value compareEqResult = rewriter.create( + Value compareEqResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareEqDirectionAttr, compareTypeAttr); - Value minIdx = - rewriter.create(op->getLoc(), *firstIdxArg, *secondIdxArg); - Value idxWithGeVal = rewriter.create( + Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, + *secondIdxArg); + Value idxWithGeVal = rewriter.create( op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); - Value retIdxResult = rewriter.create( + Value retIdxResult = rewriter.create( op->getLoc(), compareEqResult, minIdx, idxWithGeVal); - rewriter.create( + rewriter.create( op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); } @@ -419,41 +422,42 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // prepend 1 to kernelSize, stride, dilation until they are of same rank as // input - SmallVector mhloStride(inputRank, 1); - SmallVector mhloDilation(inputRank, 1); - SmallVector mhloKernelSize(inputRank, 1); - SmallVector mhloPadding(inputRank * 2, 0); + SmallVector stablehloStride(inputRank, 1); + SmallVector stablehloDilation(inputRank, 1); + SmallVector stablehloKernelSize(inputRank, 1); + SmallVector stablehloPadding(inputRank * 2, 0); - std::copy(stride.begin(), stride.end(), mhloStride.begin() + inputRank - 2); + std::copy(stride.begin(), stride.end(), + stablehloStride.begin() + inputRank - 2); std::copy(kernelSize.begin(), kernelSize.end(), - mhloKernelSize.begin() + inputRank - 2); - mhloPadding[mhloPadding.size() - 4] = padding[0]; - mhloPadding[mhloPadding.size() - 3] = padding[0]; - mhloPadding[mhloPadding.size() - 2] = padding[1]; - mhloPadding[mhloPadding.size() - 1] = padding[1]; + stablehloKernelSize.begin() + inputRank - 2); + stablehloPadding[stablehloPadding.size() - 4] = padding[0]; + stablehloPadding[stablehloPadding.size() - 3] = padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[1]; + stablehloPadding[stablehloPadding.size() - 1] = padding[1]; Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(mhloKernelSize.size())}, + RankedTensorType::get({static_cast(stablehloKernelSize.size())}, rewriter.getI64Type()), - mhloKernelSize); + stablehloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(mhloStride.size())}, + RankedTensorType::get({static_cast(stablehloStride.size())}, rewriter.getI64Type()), - mhloStride); + stablehloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(mhloDilation.size())}, + RankedTensorType::get({static_cast(stablehloDilation.size())}, rewriter.getI64Type()), - mhloDilation); + stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), - mhloPadding); + stablehloPadding); - auto reduceWindowSum = rewriter.create( + auto reduceWindowSum = rewriter.create( op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, baseDilations, windowDilations, pad); @@ -471,39 +475,39 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.setInsertionPointToStart(&sumBlock); Value sumResult = - rewriter.create(op->getLoc(), *firstArg, *secondArg); - rewriter.create(op->getLoc(), sumResult); + rewriter.create(op->getLoc(), *firstArg, *secondArg); + rewriter.create(op->getLoc(), sumResult); } // Use kernel size as the divisor if (countIncludePad) { - Value divisor = mhlo::getConstTensor( + Value divisor = hlo::getConstTensor( rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) .value(); - divisor = mhlo::promoteType(rewriter, divisor, outTy); + divisor = hlo::promoteType(rewriter, divisor, outTy); DenseIntElementsAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); return success(); } - // Use another mhlo.ReduceWindowOp to get the divisor + // Use another stablehlo.ReduceWindowOp to get the divisor Value windowSizeConst = - mhlo::getConstTensor(rewriter, op, {1.0}, {}).value(); - windowSizeConst = mhlo::promoteType(rewriter, windowSizeConst, outTy); + hlo::getConstTensor(rewriter, op, {1.0}, {}).value(); + windowSizeConst = hlo::promoteType(rewriter, windowSizeConst, outTy); const auto &options = getOptions(); auto inputShapeVec = - *mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); - windowSizeConst = rewriter.create( + windowSizeConst = rewriter.create( op->getLoc(), RankedTensorType::get(inputTy.getShape(), outTy.getElementType()), windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({})); Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - auto reduceWindowSize = rewriter.create( + auto reduceWindowSize = rewriter.create( op->getLoc(), RankedTensorType::get(outShape, inputElemTy), windowSizeConst, zero, windowDimensions, windowStrides, baseDilations, windowDilations, pad); @@ -522,11 +526,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.setInsertionPointToStart(&sizeBlock); Value sumResult = - rewriter.create(op->getLoc(), *firstArg, *secondArg); - rewriter.create(op->getLoc(), sumResult); + rewriter.create(op->getLoc(), *firstArg, *secondArg); + rewriter.create(op->getLoc(), sumResult); } - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0)); return success(); } @@ -560,33 +564,33 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - SmallVector mhloKernelSize(inputRank, 1); - mhloKernelSize[dim] = inputShape[dim]; - SmallVector mhloStride(inputRank, 1); - SmallVector mhloDilation(inputRank, 1); - SmallVector mhloPadding(inputRank * 2, 0); - mhloPadding[dim * 2] = inputShape[dim] - 1; + SmallVector stablehloKernelSize(inputRank, 1); + stablehloKernelSize[dim] = inputShape[dim]; + SmallVector stablehloStride(inputRank, 1); + SmallVector stablehloDilation(inputRank, 1); + SmallVector stablehloPadding(inputRank * 2, 0); + stablehloPadding[dim * 2] = inputShape[dim] - 1; DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(mhloKernelSize.size())}, + RankedTensorType::get({static_cast(stablehloKernelSize.size())}, rewriter.getI64Type()), - mhloKernelSize); + stablehloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(mhloStride.size())}, + RankedTensorType::get({static_cast(stablehloStride.size())}, rewriter.getI64Type()), - mhloStride); + stablehloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(mhloDilation.size())}, + RankedTensorType::get({static_cast(stablehloDilation.size())}, rewriter.getI64Type()), - mhloDilation); + stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, rewriter.getI64Type()), - mhloPadding); + stablehloPadding); - auto reduceWindowSum = rewriter.create( + auto reduceWindowSum = rewriter.create( op->getLoc(), outTy, input, initVal, windowDimensions, windowStrides, baseDilations, windowDilations, pad); @@ -604,17 +608,17 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.setInsertionPointToStart(&sumBlock); Value sumResult = - rewriter.create(op->getLoc(), *firstArg, *secondArg); - rewriter.create(op->getLoc(), sumResult); + rewriter.create(op->getLoc(), *firstArg, *secondArg); + rewriter.create(op->getLoc(), sumResult); } rewriter.replaceOp(op, reduceWindowSum.getResults()); return success(); } -void mlir::torch::torch_to_mhlo::populatePoolingOpPatternsAndLegality( +void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToMhloOptions &options) { + ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add>(typeConverter, context, options); diff --git a/lib/Conversion/TorchToStablehlo/PopulatePatterns.h b/lib/Conversion/TorchToStablehlo/PopulatePatterns.h new file mode 100644 index 000000000000..b6322efd6897 --- /dev/null +++ b/lib/Conversion/TorchToStablehlo/PopulatePatterns.h @@ -0,0 +1,69 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H +#define TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H + +#include "mlir/Transforms/DialectConversion.h" + +namespace mlir { +namespace torch { +namespace torch_to_stablehlo { + +struct TorchToStablehloOptions { + bool enableStaticShape = false; + size_t dimSizeIndexBits = 64; +}; + +template +class ConvertAtenOp : public OpConversionPattern { +public: + using OpAdaptor = typename AtenOpT::Adaptor; + ConvertAtenOp(TypeConverter &typeConverter, MLIRContext *context, + const TorchToStablehloOptions &options) + : OpConversionPattern(typeConverter, context) { + this->options = options; + } + LogicalResult + matchAndRewrite(AtenOpT op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + return rewriter.notifyMatchFailure(op, "haven't been implemented"); + } + const TorchToStablehloOptions &getOptions() const { return options; } + +private: + TorchToStablehloOptions options; +}; + +void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter, + RewritePatternSet &patterns, + ConversionTarget &target, + const TorchToStablehloOptions &options); +void populateViewLikeOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const TorchToStablehloOptions &options); +void populateGatherOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const TorchToStablehloOptions &options); +void populateReductionOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const TorchToStablehloOptions &options); +void populateLinearOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const TorchToStablehloOptions &options); + +void populatePoolingOpPatternsAndLegality( + TypeConverter &typeConverter, RewritePatternSet &patterns, + ConversionTarget &target, const TorchToStablehloOptions &options); + +} // namespace torch_to_stablehlo +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_LIB_CONVERSION_TORCHTOSTABLEHLO_POPULATEPATTERNS_H diff --git a/lib/Conversion/TorchToMhlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp similarity index 74% rename from lib/Conversion/TorchToMhlo/Reduction.cpp rename to lib/Conversion/TorchToStablehlo/Reduction.cpp index 1196933f1ea3..eb4e11116c71 100644 --- a/lib/Conversion/TorchToMhlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -7,14 +7,15 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "../PassDetail.h" -#include "./MhloLegalizeUtils.h" -#include "./PopulatePatterns.h" -#include "mhlo/IR/hlo_ops.h" +#include "PopulatePatterns.h" +#include "StablehloLegalizeUtils.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -25,7 +26,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; -using namespace mlir::torch::torch_to_mhlo; +using namespace mlir::torch::torch_to_stablehlo; static Value createInitialValueForReduceOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { @@ -36,14 +37,14 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, constType, {APFloat::getZero( elementTy.cast().getFloatSemantics(), /*negative=*/false)}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getZero(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } } @@ -53,15 +54,15 @@ static Value createInitialValueForReduceOp(Operation *op, Type elementTy, constType, {APFloat::getLargest( elementTy.cast().getFloatSemantics(), /*negative=*/true)}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } else if (elementTy.isa() && elementTy.getIntOrFloatBitWidth() != 8) { auto constAttr = DenseElementsAttr::get( constType, {APInt::getSignedMinValue(elementTy.getIntOrFloatBitWidth())}); - return rewriter.create(op->getLoc(), constType, - constAttr); + return rewriter.create(op->getLoc(), constType, + constAttr); } } @@ -90,9 +91,9 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, return std::nullopt; Value initIndex; if (dimSizeIndexBits == 32) { - initIndex = mhlo::getConstTensor(rewriter, op, {0}, {}).value(); + initIndex = hlo::getConstTensor(rewriter, op, {0}, {}).value(); } else { - initIndex = mhlo::getConstTensor(rewriter, op, {0}, {}).value(); + initIndex = hlo::getConstTensor(rewriter, op, {0}, {}).value(); } DenseIntElementsAttr dimensions = DenseIntElementsAttr::get( @@ -100,13 +101,13 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); - auto indexTensor = rewriter.create( + auto indexTensor = rewriter.create( op->getLoc(), RankedTensorType::get(inputShape, rewriter.getIntegerType(dimSizeIndexBits)), inputShapeTensor, static_cast(dim)); - auto mhloReduceOp = rewriter.create( + auto stablehloReduceOp = rewriter.create( op->getLoc(), ValueRange{input, indexTensor}, ValueRange{ initValue, @@ -114,7 +115,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, }, dimensions); - Block &block = mhloReduceOp.getBody().emplaceBlock(); + Block &block = stablehloReduceOp.getBody().emplaceBlock(); // Add block arguments auto blockValArgumentType = @@ -133,46 +134,46 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, auto *secondValArg = std::next(firstIdxArg); auto *secondIdxArg = std::next(secondValArg); - mhlo::ComparisonTypeAttr compareTypeAttr; + stablehlo::ComparisonTypeAttr compareTypeAttr; if (inputTy.getElementType().isa()) { - compareTypeAttr = mhlo::ComparisonTypeAttr::get( - rewriter.getContext(), mhlo::ComparisonType::FLOAT); + compareTypeAttr = stablehlo::ComparisonTypeAttr::get( + rewriter.getContext(), stablehlo::ComparisonType::FLOAT); } else if (inputTy.getElementType().isa()) { - compareTypeAttr = mhlo::ComparisonTypeAttr::get( - rewriter.getContext(), mhlo::ComparisonType::SIGNED); + compareTypeAttr = stablehlo::ComparisonTypeAttr::get( + rewriter.getContext(), stablehlo::ComparisonType::SIGNED); } - mhlo::ComparisonDirectionAttr compareGeDirectionAttr = - mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), - mhlo::ComparisonDirection::GE); - mhlo::ComparisonDirectionAttr compareEqDirectionAttr = - mhlo::ComparisonDirectionAttr::get(rewriter.getContext(), - mhlo::ComparisonDirection::EQ); + stablehlo::ComparisonDirectionAttr compareGeDirectionAttr = + stablehlo::ComparisonDirectionAttr::get( + rewriter.getContext(), stablehlo::ComparisonDirection::GE); + stablehlo::ComparisonDirectionAttr compareEqDirectionAttr = + stablehlo::ComparisonDirectionAttr::get( + rewriter.getContext(), stablehlo::ComparisonDirection::EQ); { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value compareGeResult = rewriter.create( + Value compareGeResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareGeDirectionAttr, compareTypeAttr); - Value retValResult = rewriter.create( + Value retValResult = rewriter.create( op->getLoc(), compareGeResult, *firstValArg, *secondValArg); // get smaller index value if compared nums are equal. - Value compareEqResult = rewriter.create( + Value compareEqResult = rewriter.create( op->getLoc(), compareResultType, *firstValArg, *secondValArg, compareEqDirectionAttr, compareTypeAttr); - Value minIdx = - rewriter.create(op->getLoc(), *firstIdxArg, *secondIdxArg); - Value idxWithGeVal = rewriter.create( + Value minIdx = rewriter.create(op->getLoc(), *firstIdxArg, + *secondIdxArg); + Value idxWithGeVal = rewriter.create( op->getLoc(), compareGeResult, *firstIdxArg, *secondIdxArg); - Value retIdxResult = rewriter.create( + Value retIdxResult = rewriter.create( op->getLoc(), compareEqResult, minIdx, idxWithGeVal); - rewriter.create( + rewriter.create( op->getLoc(), mlir::ValueRange{retValResult, retIdxResult}); } - return mhloReduceOp.getResults(); + return stablehloReduceOp.getResults(); } namespace { @@ -196,7 +197,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( Value input = adaptor.getSelf(); auto inputTy = input.getType().template cast(); if (!inputTy) { - return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); } auto inputElemTy = inputTy.getElementType(); @@ -209,7 +211,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenArgmaxOp to MHLO"); + "AtenArgmaxOp to StableHLO"); } int64_t dim; @@ -228,15 +230,15 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( const auto &options = getOptions(); auto inputShapeInfo = - mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; - auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim, - options.dimSizeIndexBits) - .value(); + auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, + dim, options.dimSizeIndexBits) + .value(); if (keepDim) { auto outShapeVec = inputShapeVec; @@ -247,13 +249,13 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( - op, typeConverter->convertType(op.getType()), mhloReduceResults[1], + rewriter.replaceOpWithNewOp( + op, typeConverter->convertType(op.getType()), stablehloReduceResults[1], outShapeTensor); return success(); } - rewriter.replaceOp(op, mhloReduceResults[1]); + rewriter.replaceOp(op, stablehloReduceResults[1]); return success(); } } // namespace @@ -267,7 +269,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( Value input = adaptor.getSelf(); auto inputTy = input.getType().template dyn_cast(); if (!inputTy) { - return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { @@ -279,7 +282,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenMaxDimOp to MHLO"); + "AtenMaxDimOp to StableHLO"); } RankedTensorType valResultType = getTypeConverter() @@ -308,15 +311,15 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( const auto &options = getOptions(); auto inputShapeInfo = - mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(inputShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto inputShapeVec = *inputShapeInfo; - auto mhloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, dim, - options.dimSizeIndexBits) - .value(); + auto stablehloReduceResults = getMaxInDim(rewriter, op, input, inputShapeVec, + dim, options.dimSizeIndexBits) + .value(); if (keepDim) { auto outShapeVec = inputShapeVec; @@ -327,15 +330,21 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); - auto mhloReduceValueResult = rewriter.create( - op->getLoc(), valResultType, mhloReduceResults[0], outShapeTensor); - auto mhloReduceIndexResult = rewriter.create( - op->getLoc(), idxResultType, mhloReduceResults[1], outShapeTensor); - rewriter.replaceOp(op, {mhloReduceValueResult, mhloReduceIndexResult}); + auto stablehloReduceValueResult = + rewriter.create( + op->getLoc(), valResultType, stablehloReduceResults[0], + outShapeTensor); + auto stablehloReduceIndexResult = + rewriter.create( + op->getLoc(), idxResultType, stablehloReduceResults[1], + outShapeTensor); + rewriter.replaceOp( + op, {stablehloReduceValueResult, stablehloReduceIndexResult}); return success(); } - rewriter.replaceOp(op, {mhloReduceResults[0], mhloReduceResults[1]}); + rewriter.replaceOp(op, + {stablehloReduceResults[0], stablehloReduceResults[1]}); return success(); } } // namespace @@ -352,12 +361,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( ->convertType(op.getType()) .template dyn_cast(); if (!inputTy) { - return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); } if (inputTy.getElementType() != outTy.getElementType()) { // Use output element type as computation type. auto dstElemTy = outTy.getElementType(); - input = rewriter.create(op->getLoc(), input, dstElemTy); + input = + rewriter.create(op->getLoc(), input, dstElemTy); inputTy = input.getType().dyn_cast(); } auto inputElemTy = inputTy.getElementType(); @@ -370,7 +381,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenSumOp to MHLO"); + "AtenSumOp to StableHLO"); } SmallVector dims; @@ -379,13 +390,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } Value initValue = createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) return failure(); + if (!initValue) + return failure(); llvm::sort(dims.begin(), dims.end()); - auto mhloReduceOp = rewriter.create( + auto stablehloReduceOp = rewriter.create( op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); - Block &block = mhloReduceOp.getBody().emplaceBlock(); + Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); block.addArgument(blockArgumentTy, op->getLoc()); @@ -397,13 +409,13 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value addResult = rewriter.create( + Value addResult = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), addResult); + rewriter.create(op->getLoc(), addResult); } rewriter.replaceOpWithNewOp(op, outTy, - mhloReduceOp.getResults()); + stablehloReduceOp.getResults()); return success(); } } // namespace @@ -417,7 +429,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( Value input = adaptor.getSelf(); auto inputTy = input.getType().dyn_cast(); if (!inputTy) { - return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); } auto inputElemTy = inputTy.getElementType(); if (!inputElemTy.isIntOrFloat()) { @@ -429,7 +442,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenMaxOp to MHLO"); + "AtenMaxOp to StableHLO"); } SmallVector dims; @@ -439,12 +452,13 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( Value initValue = createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) return failure(); + if (!initValue) + return failure(); llvm::sort(dims.begin(), dims.end()); - auto mhloReduceOp = rewriter.create( + auto stablehloReduceOp = rewriter.create( op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); - Block &block = mhloReduceOp.getBody().emplaceBlock(); + Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); block.addArgument(blockArgumentTy, op->getLoc()); @@ -456,14 +470,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value maxResult = rewriter.create( + Value maxResult = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), maxResult); + rewriter.create(op->getLoc(), maxResult); } rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), - mhloReduceOp.getResults()); + stablehloReduceOp.getResults()); return success(); } } // namespace @@ -480,12 +494,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( ->convertType(op.getType()) .template dyn_cast(); if (!inputTy) { - return rewriter.notifyMatchFailure(op, "only Tensor types supported in MHLO"); + return rewriter.notifyMatchFailure( + op, "only Tensor types supported in StableHLO"); } if (inputTy.getElementType() != outTy.getElementType()) { // Use output element type as computation type. auto dstElemTy = outTy.getElementType(); - input = rewriter.create(op->getLoc(), input, dstElemTy); + input = + rewriter.create(op->getLoc(), input, dstElemTy); inputTy = input.getType().dyn_cast(); } auto inputElemTy = inputTy.getElementType(); @@ -499,7 +515,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( inputElemTy.getIntOrFloatBitWidth() == 8) { return rewriter.notifyMatchFailure( op, "IntegerType with bitwidth 8 unsupported in convertion from " - "AtenSumDimIntListOp to MHLO"); + "AtenSumDimIntListOp to StableHLO"); } SmallVector inputDims; @@ -525,13 +541,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } Value initValue = createInitialValueForReduceOp(op, inputTy.getElementType(), rewriter); - if (!initValue) return failure(); + if (!initValue) + return failure(); llvm::sort(dims.begin(), dims.end()); - auto mhloReduceOp = rewriter.create( + auto stablehloReduceOp = rewriter.create( op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); - Region ®ion = mhloReduceOp.getBody(); + Region ®ion = stablehloReduceOp.getBody(); Block &block = region.emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -544,15 +561,15 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( { OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - Value addResult = rewriter.create( + Value addResult = rewriter.create( op->getLoc(), blockArgumentTy, *firstArgument, *secondArgument); - rewriter.create(op->getLoc(), addResult); + rewriter.create(op->getLoc(), addResult); } if (keepDim) { const auto &options = getOptions(); - auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, - options.dimSizeIndexBits); + auto outShapeInfo = + hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -567,26 +584,27 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), - mhloReduceOp.getResult(0), outShapeTensor); + stablehloReduceOp.getResult(0), outShapeTensor); return success(); } rewriter.replaceOpWithNewOp(op, outTy, - mhloReduceOp.getResults()); + stablehloReduceOp.getResults()); return success(); } } // namespace // AtenFrobeniusNormDimOp -// aten.frobenius_norm.dim => mhlo.reduce(calculate square sum along given dims) -// + mhlo.sqrt +// aten.frobenius_norm.dim => stablehlo.reduce(calculate square sum along given +// dims) +// + stablehlo.sqrt namespace { template <> LogicalResult ConvertAtenReductionOp::matchAndRewrite( AtenFrobeniusNormDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { - const TorchToMhloOptions &options = getOptions(); + const TorchToStablehloOptions &options = getOptions(); Value input = adaptor.getSelf(); auto inputType = input.getType().dyn_cast(); @@ -614,7 +632,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } - // Sort the dims in ascending order, making the conversion + // Sort the dims in ascending order, making the conversion // stable with unordered dims. std::sort(dims.begin(), dims.end()); @@ -624,14 +642,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( op, "non-const bool `keepdim` is not supported"); } - auto squareOp = rewriter.create(op->getLoc(), input, input); + auto squareOp = rewriter.create(op->getLoc(), input, input); auto initValue = createInitialValueForReduceOp(op, inputElemType, rewriter); if (!initValue) { return failure(); } - auto reduceOp = rewriter.create( + auto reduceOp = rewriter.create( op->getLoc(), squareOp.getResult(), initValue, rewriter.getI64TensorAttr(dims)); @@ -649,30 +667,32 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( OpBuilder::InsertionGuard guard(rewriter); rewriter.setInsertionPointToStart(&block); - auto addResult = rewriter.create(op->getLoc(), firstArgument, - secondArgument); - rewriter.create(op->getLoc(), addResult.getResult()); + auto addResult = rewriter.create( + op->getLoc(), firstArgument, secondArgument); + rewriter.create(op->getLoc(), addResult.getResult()); } auto output = - rewriter.create(op->getLoc(), reduceOp.getResult(0)); + rewriter.create(op->getLoc(), reduceOp.getResult(0)); if (keepDim) { - auto outShapeInfo = mhlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto outShapeInfo = + hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); if (failed(outShapeInfo)) { return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); } auto outShapeVec = *outShapeInfo; auto one = rewriter.create( - op->getLoc(), rewriter.getIntegerAttr( - rewriter.getIntegerType(options.dimSizeIndexBits), 1)); + op->getLoc(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(options.dimSizeIndexBits), 1)); for (int64_t i : dims) { outShapeVec[i] = one; } auto outShapeTensor = rewriter.create( op->getLoc(), outShapeVec); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), output, outShapeTensor); return success(); @@ -682,9 +702,9 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } // namespace -void mlir::torch::torch_to_mhlo::populateReductionOpPatternsAndLegality( +void mlir::torch::torch_to_stablehlo::populateReductionOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToMhloOptions &options) { + ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATEN_REDUCTION_OP_PATTERN(AtenOp) \ target.addIllegalOp(); \ diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp similarity index 84% rename from lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp rename to lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index b9fb00affc80..dbcfba2ff306 100644 --- a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -7,11 +7,12 @@ // //===----------------------------------------------------------------------===// -#include "./MhloLegalizeUtils.h" -#include "mhlo/IR/hlo_ops.h" +#include "StablehloLegalizeUtils.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include @@ -21,27 +22,27 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; namespace mlir { -namespace mhlo { +namespace hlo { // Create a 32-bit float constant operator from a float -Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, - float val) { +Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, + float val) { auto const_type = RankedTensorType::get({}, rewriter.getF32Type()); auto const_attr = DenseElementsAttr::get(const_type, val); - auto const_op = - rewriter.create(op->getLoc(), const_type, const_attr); + auto const_op = rewriter.create( + op->getLoc(), const_type, const_attr); return const_op.getResult(); } // Create a 64-bit float constant operator from a double -Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, - double val) { +Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, + double val) { auto const_type = RankedTensorType::get({}, rewriter.getF64Type()); auto const_attr = DenseElementsAttr::get(const_type, val); - auto const_op = - rewriter.create(op->getLoc(), const_type, const_attr); + auto const_op = rewriter.create( + op->getLoc(), const_type, const_attr); return const_op.getResult(); } @@ -65,8 +66,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, RankedTensorType::get(shape, rewriter.getIntegerType(sizeof(T) * 8)); auto const_attr = DenseElementsAttr::get(const_type, vec); - auto const_op = - rewriter.create(op->getLoc(), const_type, const_attr); + auto const_op = rewriter.create( + op->getLoc(), const_type, const_attr); return const_op.getResult(); } @@ -88,8 +89,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, shape, rewriter.getIntegerType(vec[0].getBitWidth())); auto const_attr = DenseElementsAttr::get(const_type, vec); - auto const_op = - rewriter.create(op->getLoc(), const_type, const_attr); + auto const_op = rewriter.create( + op->getLoc(), const_type, const_attr); return const_op.getResult(); } @@ -111,8 +112,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_type = RankedTensorType::get(shape, rewriter.getF32Type()); auto const_attr = DenseElementsAttr::get(const_type, vec); - auto const_op = - rewriter.create(op->getLoc(), const_type, const_attr); + auto const_op = rewriter.create( + op->getLoc(), const_type, const_attr); return const_op.getResult(); } @@ -133,8 +134,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, auto const_type = RankedTensorType::get(shape, rewriter.getF64Type()); auto const_attr = DenseElementsAttr::get(const_type, vec); - auto const_op = - rewriter.create(op->getLoc(), const_type, const_attr); + auto const_op = rewriter.create( + op->getLoc(), const_type, const_attr); return const_op.getResult(); } @@ -169,18 +170,18 @@ Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op, T val, Type dtype, llvm::ArrayRef dshape) { auto const_type = RankedTensorType::get(dshape, dtype); auto const_attr = SplatElementsAttr::get(const_type, val); - auto const_op = - rewriter.create(op->getLoc(), const_type, const_attr); + auto const_op = rewriter.create( + op->getLoc(), const_type, const_attr); return const_op.getResult(); } -Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op, - Value scalarValue, Type dtype) { +Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, + Operation *op, Value scalarValue, Type dtype) { auto tensor = rewriter.create( op->getLoc(), ArrayRef{scalarValue}); auto dtype_tensor = - rewriter.create(op->getLoc(), tensor, dtype); - return rewriter.create( + rewriter.create(op->getLoc(), tensor, dtype); + return rewriter.create( op->getLoc(), RankedTensorType::get(mlir::ArrayRef{}, dtype), dtype_tensor); } @@ -192,7 +193,8 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) { if (in_type.getElementType() != outType.getElementType()) { TensorType promotedType = in_type.cloneWith(in_type.getShape(), outType.getElementType()); - return rewriter.create(op->getLoc(), promotedType, input); + return rewriter.create(op->getLoc(), promotedType, + input); } return input; } @@ -210,8 +212,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, if (in_type.getElementType() != outType.getElementType()) { TensorType promoted_type = in_type.cloneWith(in_type.getShape(), outType.getElementType()); - input = - rewriter.create(op->getLoc(), promoted_type, input); + input = rewriter.create(op->getLoc(), promoted_type, + input); } ArrayRef inShape = in_type.getShape(); @@ -245,8 +247,8 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, RankedTensorType::get({static_cast(bcastDims.size())}, rewriter.getI64Type()), bcastDims); - auto bcast_op = rewriter.create(op->getLoc(), outType, - input, bcast_attr); + auto bcast_op = rewriter.create( + op->getLoc(), outType, input, bcast_attr); return bcast_op.getResult(); } @@ -348,8 +350,8 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, } auto outTy = RankedTensorType::get(newShape, rankTy.getElementType()); - auto mhloShape = rewriter.create(loc, newDimSizes); - return rewriter.create(loc, outTy, tensor, mhloShape) + auto shape = rewriter.create(loc, newDimSizes); + return rewriter.create(loc, outTy, tensor, shape) .getResult(); } @@ -357,11 +359,11 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc, const APFloat &constant, Value shape, TensorType outType) { auto constAttr = rewriter.getFloatAttr(outType.getElementType(), constant); - auto constTensor = rewriter.create(loc, constAttr); + auto constTensor = rewriter.create(loc, constAttr); return rewriter - .create(loc, outType, constTensor, shape, - rewriter.getI64TensorAttr({})) + .create( + loc, outType, constTensor, shape, rewriter.getI64TensorAttr({})) .getResult(); } -} // namespace mhlo +} // namespace hlo } // namespace mlir diff --git a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h similarity index 79% rename from lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h rename to lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h index dc7daa42d346..6d31d267ac0b 100644 --- a/lib/Conversion/TorchToMhlo/MhloLegalizeUtils.h +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#ifndef TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H -#define TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H +#ifndef TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H +#define TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -18,22 +18,22 @@ #include "mlir/Transforms/DialectConversion.h" namespace mlir { -namespace mhlo { +namespace hlo { using mlir::ConversionPatternRewriter; // Create a 32-bit float constant operator from a float -Value getMhloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, - float val); +Value getStablehloConstTensorSingleF32(PatternRewriter &rewriter, Operation *op, + float val); // Create a 64-bit float constant operator from a double -Value getMhloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, - double val); +Value getStablehloConstTensorSingleF64(PatternRewriter &rewriter, Operation *op, + double val); // Templated function to create a constant op for given type and shape. // T: storage C type. // Default template creates a constant tensor in T. -// To create INT48 MHLO constant, need to pass in llvm::APInt instead. +// To create INT48 StableHLO constant, need to pass in llvm::APInt instead. template std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, ArrayRef vec, ArrayRef shape); @@ -42,8 +42,8 @@ template Value getSplatConstTensor(ConversionPatternRewriter &rewriter, Operation *op, T val, Type dtype, llvm::ArrayRef dshape); -Value scalarToMhloTensor(ConversionPatternRewriter &rewriter, Operation *op, - Value scalarValue, Type dtype); +Value scalarToStablehloTensor(ConversionPatternRewriter &rewriter, + Operation *op, Value scalarValue, Type dtype); Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType); @@ -71,7 +71,7 @@ FailureOr unsqueezeTensor(PatternRewriter &rewriter, Operation *op, Value getConstantOfShape(PatternRewriter &rewriter, Location loc, const APFloat &constant, Value shape, TensorType outType); -} // namespace mhlo +} // namespace hlo } // namespace mlir -#endif // TORCHMLIR_CONVERSION_TORCHTOMHLO_MHLOLEGALIZEUTILS_H +#endif // TORCHMLIR_CONVERSION_TORCHTOSTABLEHLO_STABLEHLOLEGALIZEUTILS_H diff --git a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp similarity index 58% rename from lib/Conversion/TorchToMhlo/TorchToMhlo.cpp rename to lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp index f81afd9ca92b..ba08384846cc 100644 --- a/lib/Conversion/TorchToMhlo/TorchToMhlo.cpp +++ b/lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp @@ -7,17 +7,18 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "../PassDetail.h" -#include "./PopulatePatterns.h" -#include "mhlo/IR/hlo_ops.h" +#include "PopulatePatterns.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -30,17 +31,18 @@ using namespace mlir::torch::Torch; namespace { -class ConvertTorchToMhlo : public ConvertTorchToMhloBase { +class ConvertTorchToStablehlo + : public ConvertTorchToStablehloBase { public: - ConvertTorchToMhlo() = default; - ConvertTorchToMhlo(bool enableStaticShape, bool enableI32Index) { + ConvertTorchToStablehlo() = default; + ConvertTorchToStablehlo(bool enableStaticShape, bool enableI32Index) { this->enableStaticShape = enableStaticShape; this->enableI32Index = enableI32Index; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); - registry.insert(); + registry.insert(); registry.insert(); registry.insert(); TorchConversion::getBackendTypeConversionDependentDialects(registry); @@ -48,7 +50,7 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase { void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target(*context); - target.addLegalDialect(); TypeConverter typeConverter; @@ -57,20 +59,20 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase { RewritePatternSet patterns(context); - torch_to_mhlo::TorchToMhloOptions options{enableStaticShape, - enableI32Index ? 32u : 64u}; - torch_to_mhlo::populateBasicOpPatternsAndLegality(typeConverter, patterns, - target, options); - torch_to_mhlo::populateViewLikeOpPatternsAndLegality( + torch_to_stablehlo::TorchToStablehloOptions options{ + enableStaticShape, enableI32Index ? 32u : 64u}; + torch_to_stablehlo::populateBasicOpPatternsAndLegality( + typeConverter, patterns, target, options); + torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( + typeConverter, patterns, target, options); + torch_to_stablehlo::populateGatherOpPatternsAndLegality( + typeConverter, patterns, target, options); + torch_to_stablehlo::populateReductionOpPatternsAndLegality( + typeConverter, patterns, target, options); + torch_to_stablehlo::populateLinearOpPatternsAndLegality( typeConverter, patterns, target, options); - torch_to_mhlo::populateGatherOpPatternsAndLegality(typeConverter, patterns, - target, options); - torch_to_mhlo::populateReductionOpPatternsAndLegality( + torch_to_stablehlo::populatePoolingOpPatternsAndLegality( typeConverter, patterns, target, options); - torch_to_mhlo::populateLinearOpPatternsAndLegality(typeConverter, patterns, - target, options); - torch_to_mhlo::populatePoolingOpPatternsAndLegality(typeConverter, patterns, - target, options); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) { @@ -82,13 +84,13 @@ class ConvertTorchToMhlo : public ConvertTorchToMhloBase { } // namespace std::unique_ptr> -mlir::torch::createConvertTorchToMhloPass() { - return std::make_unique(false, false); +mlir::torch::createConvertTorchToStablehloPass() { + return std::make_unique(false, false); } std::unique_ptr> -mlir::torch::createConvertTorchToMhloPass(bool enableStaticShape, - bool enableI32Index) { - return std::make_unique(enableStaticShape, - enableI32Index); +mlir::torch::createConvertTorchToStablehloPass(bool enableStaticShape, + bool enableI32Index) { + return std::make_unique(enableStaticShape, + enableI32Index); } diff --git a/lib/Conversion/TorchToMhlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp similarity index 88% rename from lib/Conversion/TorchToMhlo/ViewLike.cpp rename to lib/Conversion/TorchToStablehlo/ViewLike.cpp index 29284d50e3d9..b6511c384068 100644 --- a/lib/Conversion/TorchToMhlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -7,14 +7,15 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #include "../PassDetail.h" -#include "./MhloLegalizeUtils.h" -#include "./PopulatePatterns.h" -#include "mhlo/IR/hlo_ops.h" +#include "PopulatePatterns.h" +#include "StablehloLegalizeUtils.h" + #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -28,7 +29,7 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; using namespace mlir::torch::TorchConversion; -using namespace mlir::torch::torch_to_mhlo; +using namespace mlir::torch::torch_to_stablehlo; namespace { // A dimension index from torch.dialect might outside the range [0, dimSize]. @@ -100,7 +101,7 @@ Value getDynamicSliceInternal(PatternRewriter &rewriter, Operation *op, auto stridesTensor = rewriter.create(loc, strides).getResult(); - return rewriter.create( + return rewriter.create( loc, outTy, input, startTensor, endTensor, stridesTensor); } @@ -144,7 +145,7 @@ FailureOr getDynamicSlice(PatternRewriter &rewriter, Operation *op, step = rewriter.create(loc, intType, step); } FailureOr> dimSizesInfo = - mhlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits); + hlo::getDimSizesOfTensor(rewriter, op, input, dimSizeIndexBits); if (failed(dimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); @@ -179,7 +180,7 @@ class ConvertAtenViewOp : public ConvertAtenOp { auto loc = op.getLoc(); auto newRank = dimSizes.size(); if (newRank == 0 || rankType.getRank() == 0) { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), @@ -214,17 +215,18 @@ class ConvertAtenViewOp : public ConvertAtenOp { numel); if (dimSizes.size() == 0) { - rewriter.replaceOpWithNewOp( - op, - OpConversionPattern::getTypeConverter()->convertType( - op.getType()), - adaptor.getSelf()); + rewriter.replaceOpWithNewOp( + op, + OpConversionPattern::getTypeConverter()->convertType( + op.getType()), + adaptor.getSelf()); return success(); } - Value mhloShape = rewriter.create(loc, dimSizes); - Value computedShape = rewriter.create( - loc, mhloShape.getType(), numel, mhloShape); - rewriter.replaceOpWithNewOp( + Value stablehloShape = + rewriter.create(loc, dimSizes); + Value computedShape = rewriter.create( + loc, stablehloShape.getType(), numel, stablehloShape); + rewriter.replaceOpWithNewOp( op, OpConversionPattern::getTypeConverter()->convertType( op.getType()), @@ -315,21 +317,21 @@ LogicalResult ConvertAtenOp::matchAndRewrite( dims.push_back(r); } if (dims.size() == 0) { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); } - auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims, - options.dimSizeIndexBits); + auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims, + options.dimSizeIndexBits); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); auto newDimSizes = *newDimSizesInfo; - auto mhloShape = + auto stablehloShape = rewriter.create(op.getLoc(), newDimSizes); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), self, mhloShape); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), self, stablehloShape); return success(); } @@ -365,20 +367,20 @@ LogicalResult ConvertAtenOp::matchAndRewrite( std::iota(dims.begin(), dims.end(), 0); dims.erase(dims.begin() + dim); if (dims.size() == 0) { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), self); return success(); } - auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims, - options.dimSizeIndexBits); + auto newDimSizesInfo = hlo::getDimSizesOfTensor(rewriter, op, self, dims, + options.dimSizeIndexBits); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( op, "failed to get dimension sizes of the input"); auto newDimSizes = *newDimSizesInfo; - auto mhloShape = + auto stablehloShape = rewriter.create(op.getLoc(), newDimSizes); - rewriter.replaceOpWithNewOp( - op, getTypeConverter()->convertType(op.getType()), self, mhloShape); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), self, stablehloShape); return success(); } @@ -395,8 +397,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("dim must be a Scalar constant"); - auto unsqzTensorInfo = mhlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), - {dim}, options.dimSizeIndexBits); + auto unsqzTensorInfo = hlo::unsqueezeTensor(rewriter, op, adaptor.getSelf(), + {dim}, options.dimSizeIndexBits); if (failed(unsqzTensorInfo)) return rewriter.notifyMatchFailure(op, "failed to create unsqueezed tensor"); @@ -405,9 +407,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } -void mlir::torch::torch_to_mhlo::populateViewLikeOpPatternsAndLegality( +void mlir::torch::torch_to_stablehlo::populateViewLikeOpPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, - ConversionTarget &target, const TorchToMhloOptions &options) { + ConversionTarget &target, const TorchToStablehloOptions &options) { MLIRContext *context = patterns.getContext(); #define INSERT_ATENOP_PATTERN(AtenOp) \ diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt index eaa15b00eb14..a5d5f9b7072c 100644 --- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt +++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt @@ -11,7 +11,7 @@ set(LinkedLibs MLIRIR TorchMLIRTorchConversionToMLProgram MLIRMemRefTransforms) -if(TORCH_MLIR_ENABLE_MHLO) +if(TORCH_MLIR_ENABLE_STABLEHLO) list(APPEND LinkedLibs ChloPasses) endif() @@ -21,7 +21,7 @@ add_mlir_library(TorchMLIRTorchConversionPasses Passes.cpp VerifyLinalgOnTensorsBackendContract.cpp VerifyTosaBackendContract.cpp - VerifyMhloBackendContract.cpp + VerifyStablehloBackendContract.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Dialect/TorchConversion/Transforms diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index ffffce2449b6..14d8f360bfe1 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -21,9 +21,8 @@ #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" -#ifdef TORCH_MLIR_ENABLE_MHLO -#include "mhlo/transforms/passes.h" -#include "torch-mlir/Conversion/TorchToMhlo/TorchToMhlo.h" +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #endif #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -53,12 +52,13 @@ void mlir::torch::registerTorchConversionPasses() { "Pipeline lowering torch backend contract to TOSA backend " "contract.", TorchConversion::createTorchBackendToTosaBackendPipeline); -#ifdef TORCH_MLIR_ENABLE_MHLO - mlir::PassPipelineRegistration( - "torch-backend-to-mhlo-backend-pipeline", - "Pipeline lowering torch backend contract to MHLO backend " +#ifdef TORCH_MLIR_ENABLE_STABLEHLO + mlir::PassPipelineRegistration< + TorchConversion::StablehloBackendPipelineOptions>( + "torch-backend-to-stablehlo-backend-pipeline", + "Pipeline lowering torch backend contract to StableHLO backend " "contract.", - TorchConversion::createTorchBackendToMhloBackendPipeline); + TorchConversion::createTorchBackendToStablehloBackendPipeline); #endif } @@ -121,11 +121,12 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline( pm.addPass(TorchConversion::createVerifyTosaBackendContractPass()); } -#ifdef TORCH_MLIR_ENABLE_MHLO -void TorchConversion::createTorchBackendToMhloBackendPipeline( +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +void TorchConversion::createTorchBackendToStablehloBackendPipeline( OpPassManager &pm, - const TorchConversion::MhloBackendPipelineOptions &options) { - pm.addNestedPass(createConvertTorchToMhloPass( + const TorchConversion::StablehloBackendPipelineOptions &options) { + // Generate Stablehlo ops. + pm.addNestedPass(createConvertTorchToStablehloPass( options.enableStaticShape, options.enableI32Index)); // Clean up any non-canonical code introduced above.. @@ -133,21 +134,13 @@ void TorchConversion::createTorchBackendToMhloBackendPipeline( // The resolution of `dim` ops tends to create identical ops. CSE them. pm.addNestedPass(createCSEPass()); - // Convert CHLO ops to MHLO ops - pm.addNestedPass(mhlo::createChloLegalizeToHloPass()); - // Clean up any non-canonical code introduced above.. - pm.addNestedPass(createCanonicalizerPass()); - // The resolution of `dim` ops tends to create identical ops. CSE them. - pm.addNestedPass(createCSEPass()); - // Finish the type conversion from `torch` types to the types of the - // MHLO backend contract. + // StableHLO backend contract. pm.addPass(TorchConversion::createFuncBackendTypeConversionPass()); pm.addNestedPass( TorchConversion::createFinalizingBackendTypeConversionPass()); - // Verify that we have lowered to the form that MHLO backends - // expect. This fails compilation (signalPassFailure) if the IR is not in the - // correct form. - pm.addPass(TorchConversion::createVerifyMhloBackendContractPass()); + + // Verify that we have lowered to Stablehlo and Chlo ops. + pm.addPass(TorchConversion::createVerifyStablehloBackendContractPass()); } #endif diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp similarity index 66% rename from lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp rename to lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp index aebf27599978..888f29adedb2 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyMhloBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp @@ -6,10 +6,9 @@ // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// -#ifdef TORCH_MLIR_ENABLE_MHLO +#ifdef TORCH_MLIR_ENABLE_STABLEHLO #include "PassDetail.h" -#include "mhlo/IR/hlo_ops.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Shape/IR/Shape.h" @@ -18,6 +17,7 @@ #include "mlir/IR/OpDefinition.h" #include "mlir/Transforms/DialectConversion.h" #include "stablehlo/dialect/ChloOps.h" +#include "stablehlo/dialect/StablehloOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; @@ -25,17 +25,15 @@ using namespace mlir::torch; using namespace mlir::torch::TorchConversion; namespace { -class VerifyMhloBackendContractPass - : public VerifyMhloBackendContractBase { +class VerifyStablehloBackendContractPass + : public VerifyStablehloBackendContractBase< + VerifyStablehloBackendContractPass> { void runOnOperation() override { - MLIRContext *context = &getContext(); - auto module = getOperation(); TypeConverter converter; converter.addConversion([](Type type) -> Type { auto elemTy = type; - if (isa(type)) { + if (isa(type)) elemTy = type.cast().getElementType(); - } if (BaseMemRefType::isValidElementType(elemTy)) return type; return nullptr; @@ -43,6 +41,7 @@ class VerifyMhloBackendContractPass auto opHasLegalTypes = [&](Operation *op) { return converter.isLegal(op); }; + MLIRContext *context = &getContext(); ConversionTarget target(*context); // Structural operations. @@ -50,26 +49,16 @@ class VerifyMhloBackendContractPass // Shape operations. target.addDynamicallyLegalOp(opHasLegalTypes); - target.addLegalDialect(); target.addLegalDialect(); + target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); - - RewritePatternSet patterns(context); - if (failed(applyFullConversion(module, target, std::move(patterns)))) { - // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics - // doesn't unnecessarily spew out the entire module. - emitError(module.getLoc()) - << "Module does not conform to the MHLO backend contract. " - "See dialect conversion legality information above."; - return signalPassFailure(); - } } }; } // namespace std::unique_ptr> -mlir::torch::TorchConversion::createVerifyMhloBackendContractPass() { - return std::make_unique(); +mlir::torch::TorchConversion::createVerifyStablehloBackendContractPass() { + return std::make_unique(); } -#endif // TORCH_MLIR_ENABLE_MHLO +#endif // TORCH_MLIR_ENABLE_STABLEHLO diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index b01d62152b25..87a2b8f3962c 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -20,6 +20,10 @@ #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/RefBackend/Passes.h" +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +#include "mhlo/transforms/passes.h" +#endif + void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) { registry.insert(); registry.insert(); @@ -34,4 +38,11 @@ void mlir::torch::registerAllPasses() { mlir::torch::registerConversionPasses(); mlir::torch::RefBackend::registerRefBackendPasses(); mlir::torch::TMTensor::registerPasses(); + +#ifdef TORCH_MLIR_ENABLE_STABLEHLO + mlir::mhlo::registerSymbolicShapeOptimizationPass(); + mlir::mhlo::registerStablehloLegalizeToHloPass(); + mlir::mhlo::registerChloLegalizeToHloPass(); + mlir::mhlo::registerHloLegalizeToLinalgPass(); +#endif // TORCH_MLIR_ENABLE_STABLEHLO } diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 3f08bb17365d..443512a6d54a 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -44,9 +44,9 @@ class OutputType(Enum): # as taking the `TORCH` output type and lowering it to TOSA. TOSA = "tosa" - # This output type consists of `mhlo` dialect ops. It can be thought of - # as taking the `TORCH` output type and lowering it to MHLO. - MHLO = "mhlo" + # This output type consists of `stablehlo` dialect ops. It can be thought of + # as taking the `TORCH` output type and lowering it to StableHLO. + STABLEHLO = "stablehlo" # Raw output of the JIT IR importer. This is not expected to be useful # for end-users, but can be convenient for development or reporting bugs. @@ -242,7 +242,7 @@ def _get_for_tracing( BACKEND_LEGAL_OPS = { OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'], OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', ], - OutputType.MHLO: [], + OutputType.STABLEHLO: [], } @@ -290,7 +290,7 @@ def compile(model: torch.nn.Module, # We only allow `backend_legal_ops` to be specified for the `"torch"` # output type because the other output types actually invoke their - # respective backends (Linalg, TOSA, or MHLO), and those backends have + # respective backends (Linalg, TOSA, or STABLEHLO), and those backends have # very specific requirements about the ops which are legal. # See `BACKEND_LEGAL_OPS` for more details. if backend_legal_ops is not None: @@ -404,14 +404,14 @@ def compile(model: torch.nn.Module, print(mb.module) return mb.module - elif output_type == OutputType.MHLO: + elif output_type == OutputType.STABLEHLO: run_pipeline_with_repro_report( mb.module, - "builtin.module(torch-backend-to-mhlo-backend-pipeline)", - "Lowering Torch Backend IR -> MHLO Backend IR") + "builtin.module(torch-backend-to-stablehlo-backend-pipeline)", + "Lowering Torch Backend IR -> StableHLO Backend IR") if verbose: print("\n====================") - print("MHLO Backend IR") + print("StableHLO Backend IR") print(mb.module) return mb.module raise Exception(f"Unknown OutputType: {output_type}") diff --git a/python/torch_mlir_e2e_test/configs/__init__.py b/python/torch_mlir_e2e_test/configs/__init__.py index 36fab40bd4be..4ca4c3dce803 100644 --- a/python/torch_mlir_e2e_test/configs/__init__.py +++ b/python/torch_mlir_e2e_test/configs/__init__.py @@ -7,6 +7,6 @@ from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig from .native_torch import NativeTorchTestConfig from .torchscript import TorchScriptTestConfig -from .mhlo_backend import MhloBackendTestConfig +from .stablehlo_backend import StablehloBackendTestConfig from .tosa_backend import TosaBackendTestConfig from .torchdynamo import TorchDynamoTestConfig diff --git a/python/torch_mlir_e2e_test/configs/mhlo_backend.py b/python/torch_mlir_e2e_test/configs/stablehlo_backend.py similarity index 74% rename from python/torch_mlir_e2e_test/configs/mhlo_backend.py rename to python/torch_mlir_e2e_test/configs/stablehlo_backend.py index 0b7b3253499a..45f32bb0b3fe 100644 --- a/python/torch_mlir_e2e_test/configs/mhlo_backend.py +++ b/python/torch_mlir_e2e_test/configs/stablehlo_backend.py @@ -8,12 +8,8 @@ import torch import torch_mlir -from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend -from torch_mlir_e2e_test.framework import ( - TestConfig, - Trace, - TraceItem -) +from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend +from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders from .utils import ( recursively_convert_to_numpy, @@ -21,20 +17,20 @@ ) -class MhloBackendTestConfig(TestConfig): +class StablehloBackendTestConfig(TestConfig): """Base class for TestConfig's that are implemented with linalg-on-tensors. This class handles all the common lowering that torch-mlir does before reaching the linalg-on-tensors abstraction level. """ - def __init__(self, backend: MhloBackend): + + def __init__(self, backend: StablehloBackend): super().__init__() self.backend = backend def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) - module = torch_mlir.compile( - program, example_args, output_type="mhlo") + module = torch_mlir.compile(program, example_args, output_type="stablehlo") return self.backend.compile(module) @@ -46,7 +42,6 @@ def run(self, artifact: Any, trace: Trace) -> Trace: outputs = getattr(backend_module, item.symbol)(*numpy_inputs) output = recursively_convert_from_numpy(outputs) result.append( - TraceItem(symbol=item.symbol, - inputs=item.inputs, - output=output)) + TraceItem(symbol=item.symbol, inputs=item.inputs, output=output) + ) return result diff --git a/python/torch_mlir_e2e_test/mhlo_backends/__init__.py b/python/torch_mlir_e2e_test/stablehlo_backends/__init__.py similarity index 100% rename from python/torch_mlir_e2e_test/mhlo_backends/__init__.py rename to python/torch_mlir_e2e_test/stablehlo_backends/__init__.py diff --git a/python/torch_mlir_e2e_test/mhlo_backends/abc.py b/python/torch_mlir_e2e_test/stablehlo_backends/abc.py similarity index 76% rename from python/torch_mlir_e2e_test/mhlo_backends/abc.py rename to python/torch_mlir_e2e_test/stablehlo_backends/abc.py index 8fc51ac00f7a..dbecbcc266d5 100644 --- a/python/torch_mlir_e2e_test/mhlo_backends/abc.py +++ b/python/torch_mlir_e2e_test/stablehlo_backends/abc.py @@ -10,29 +10,30 @@ from torch_mlir.ir import Module -# A type shared between the result of `MhloBackend.compile` and the -# input to `MhloBackend.load`. Each backend will likely have a +# A type shared between the result of `StablehloBackend.compile` and the +# input to `StablehloBackend.load`. Each backend will likely have a # different definition of this type. -CompiledArtifact = TypeVar('CompiledArtifact') +CompiledArtifact = TypeVar("CompiledArtifact") # A wrapper around a backend-specific loaded program representation # that uniformly translates the `x.method(...)` interface expected of # Torch modules into appropriate lower-level operations. -Invoker = TypeVar('Invoker') +Invoker = TypeVar("Invoker") -class MhloBackend(abc.ABC): - """The interface to an MHLO backend. +class StablehloBackend(abc.ABC): + """The interface to an StableHLO backend. Backends are recommended to raise meaningful exceptions in case of error, ideally with easy reproduction instructions. """ + @abc.abstractmethod def compile(self, module: Module) -> CompiledArtifact: """Compile the provided MLIR module into a compiled artifact. - The module adheres to the MHLO backend contract - (see the VerifyMhloBackendContract pass). + The module adheres to the StableHLO backend contract + (see the VerifyStablehloBackendContract pass). The compiled artifact can be any type, but must be correctly interpreted by the `load` method. diff --git a/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py similarity index 66% rename from python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py rename to python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py index 3ac1d6cd6675..b285c46b8f96 100644 --- a/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py +++ b/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py @@ -7,28 +7,32 @@ from torch_mlir.passmanager import * from torch_mlir.compiler_utils import run_pipeline_with_repro_report -from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( + RefBackendLinalgOnTensorsBackend, +) -from .abc import MhloBackend +from .abc import StablehloBackend __all__ = [ - "LinalgOnTensorsMhloBackend", + "LinalgOnTensorsStablehloBackend", ] -class LinalgOnTensorsMhloBackend(MhloBackend): - """Main entry-point for the linalg-on-tensors based MHLO backend. + +class LinalgOnTensorsStablehloBackend(StablehloBackend): + """Main entry-point for the linalg-on-tensors based StableHLO backend. This currently uses the linalg-on-tensors RefBackend for actual execution. """ + def __init__(self): super().__init__() self.refbackend = RefBackendLinalgOnTensorsBackend() def compile(self, imported_module: Module): - """Compiles an imported module that satisfied the MHLO backend contract. + """Compiles an imported module that satisfied the StableHLO backend contract. Args: - imported_module: The MLIR module consisting of funcs in the MHLO + imported_module: The MLIR module consisting of funcs in the StableHLO dialect. Returns: An opaque, backend specific compiled artifact object that can be @@ -36,8 +40,9 @@ def compile(self, imported_module: Module): """ run_pipeline_with_repro_report( imported_module, - "builtin.module(func.func(symbolic-shape-optimization),func.func(hlo-legalize-to-linalg),func.func(canonicalize))", - "Lowering MLIR-HLO to Linalg-on-Tensors") + "builtin.module(func.func(chlo-legalize-to-hlo),stablehlo-legalize-to-hlo,func.func(canonicalize,cse,symbolic-shape-optimization,hlo-legalize-to-linalg,canonicalize))", + "Lowering StableHLO to Linalg-on-Tensors", + ) return self.refbackend.compile(imported_module) def load(self, module): diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 8b444bd1dfbc..51407b488ddf 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -1,7 +1,7 @@ llvm_canonicalize_cmake_booleans( MLIR_ENABLE_BINDINGS_PYTHON TORCH_MLIR_ENABLE_JIT_IR_IMPORTER - TORCH_MLIR_ENABLE_MHLO + TORCH_MLIR_ENABLE_STABLEHLO ) configure_lit_site_cfg( diff --git a/test/Conversion/TorchToMhlo/basic.mlir b/test/Conversion/TorchToMhlo/basic.mlir index bea58bd409fd..aae5c91e7120 100644 --- a/test/Conversion/TorchToMhlo/basic.mlir +++ b/test/Conversion/TorchToMhlo/basic.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s // ----- @@ -7,7 +7,7 @@ // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T1:.*]] = mhlo.copy %[[T0]] : tensor +// CHECK: %[[T1:.*]] = stablehlo.convert %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -19,7 +19,7 @@ func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt // ----- // CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { -// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<0.000000e+00> : tensor // CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor -> !torch.vtensor<[],f32> // CHECK: return %[[VAL_1]] : !torch.vtensor<[],f32> func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { @@ -30,7 +30,7 @@ func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { // ----- // CHECK-LABEL: func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { -// CHECK: %[[VAL_0:.*]] = mhlo.constant dense<1> : tensor<2xi64> +// CHECK: %[[VAL_0:.*]] = stablehlo.constant dense<1> : tensor<2xi64> // CHECK: %[[VAL_1:.*]] = torch_c.from_builtin_tensor %[[VAL_0]] : tensor<2xi64> -> !torch.vtensor<[2],si64> // CHECK: return %[[VAL_1]] : !torch.vtensor<[2],si64> func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { @@ -45,8 +45,8 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]] // CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = mhlo.convert %[[T1]] : tensor<1xi64> -// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.convert %[[T1]] : tensor<1xi64> +// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[],si64> // CHECK: return %[[T4]] : !torch.vtensor<[],si64> func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> { @@ -75,7 +75,7 @@ func.func @torch.aten.contiguous(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vt // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = "chlo.constant_like"(%[[VAL_1]]) {value = 1.000000e+00 : f32} : (tensor) -> tensor -// CHECK: %[[VAL_3:.*]] = mhlo.divide %[[VAL_2]], %[[VAL_1]] : tensor +// CHECK: %[[VAL_3:.*]] = stablehlo.divide %[[VAL_2]], %[[VAL_1]] : tensor // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?,?],f32> func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -91,7 +91,7 @@ func.func @torch.aten.reciprocal(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.v // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,3],f32> -> tensor<4x3xf32> // CHECK: %[[VAL_2:.*]] = torch.constant.int 0 // CHECK: %[[VAL_3:.*]] = torch.constant.int 1 -// CHECK: %[[VAL_4:.*]] = "mhlo.transpose"(%[[VAL_1]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x3xf32>) -> tensor<3x4xf32> +// CHECK: %[[VAL_4:.*]] = stablehlo.transpose %[[VAL_1]], dims = [1, 0] : (tensor<4x3xf32>) -> tensor<3x4xf32> // CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<3x4xf32> -> !torch.vtensor<[3,4],f32> // CHECK: return %[[VAL_5]] : !torch.vtensor<[3,4],f32> func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,4],f32> { @@ -118,7 +118,7 @@ func.func @torch.aten.transpose$basic(%arg0: !torch.vtensor<[4,3],f32>) -> !torc // CHECK: %[[VAL_7:.*]] = arith.constant 1 : index // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1:.*]], %[[VAL_7]] : tensor // CHECK: %[[VAL_9:.*]] = tensor.from_elements %[[VAL_4]], %[[VAL_6]], %[[VAL_8]] : tensor<3xindex> -// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_1]], %[[VAL_9]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor, tensor<3xindex>) -> tensor<8x4x?xf32> +// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_1]], %[[VAL_9]], dims = [1, 2] : (tensor, tensor<3xindex>) -> tensor<8x4x?xf32> // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_10]] : tensor<8x4x?xf32> -> !torch.vtensor<[8,4,?],f32> // CHECK: return %[[VAL_11]] : !torch.vtensor<[8,4,?],f32> func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[8,4,?],f32> { @@ -135,15 +135,15 @@ func.func @torch.aten.broadcast_to$dynamic_implicit(%arg0: !torch.vtensor<[?,?], // CHECK-LABEL: func.func @torch.aten.batch_norm$training( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor -// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> -// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> +// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32> +// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32> // CHECK: %true = torch.constant.bool true // CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01 // CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05 // CHECK: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor // CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex> -// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>) +// CHECK: %[[VAL_7:.*]], %[[VAL_8:.*]], %[[VAL_9:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_3]], %[[VAL_2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>) // CHECK: %[[VAL_8:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[VAL_8]] : !torch.vtensor<[?,3,?,?],f32> func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { @@ -161,8 +161,8 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) // CHECK-LABEL: func.func @torch.aten.batch_norm$inference( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> -// CHECK: %[[T2:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> +// CHECK: %[[T1:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32> +// CHECK: %[[T2:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32> // CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: %[[FLOAT1:.*]].000000e-01 = torch.constant.float 1.000000e-01 @@ -171,7 +171,7 @@ func.func @torch.aten.batch_norm$training(%arg0: !torch.vtensor<[?,3,?,?],f32>) // CHECK: %[[T3:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor // CHECK: %[[T4:.*]] = tensor.from_elements %[[T3]] : tensor<1xindex> // CHECK: %[[T5:.*]] = tensor.cast %[[T0]] : tensor to tensor -// CHECK: %[[T6:.*]] = "mhlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor +// CHECK: %[[T6:.*]] = "stablehlo.batch_norm_inference"(%[[T5]], %[[T2]], %[[T1]], %[[T1]], %[[T2]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor // CHECK: %[[T7:.*]] = tensor.cast %[[T6]] : tensor to tensor // CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor -> !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[T8]] : !torch.vtensor<[?,3,?,?],f32> @@ -192,19 +192,19 @@ func.func @torch.aten.batch_norm$inference(%arg0: !torch.vtensor<[?,3,?,?],f32>) // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,3,?,?],f32> -> tensor // CHECK: %none = torch.constant.none -// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<3xf32> -// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<3xf32> +// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<3xf32> +// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<3xf32> // CHECK: %true = torch.constant.bool true // CHECK: %float1.000000e-01 = torch.constant.float 1.000000e-01 // CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05 // CHECK: %[[VAL_4:.*]] = arith.constant 1 : index // CHECK: %[[VAL_5:.*]] = tensor.dim %[[VAL_1]], %[[VAL_4]] : tensor // CHECK: %[[VAL_6:.*]] = tensor.from_elements %[[VAL_5]] : tensor<1xindex> -// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor -// CHECK: %[[VAL_8:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor<3xf32> -// CHECK: %[[VAL_9:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[VAL_10:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_9]], %[[VAL_6]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<1xindex>) -> tensor<3xf32> -// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "mhlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>) +// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor +// CHECK: %[[VAL_8:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_6]], dims = [] : (tensor, tensor<1xindex>) -> tensor<3xf32> +// CHECK: %[[VAL_9:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[VAL_10:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_9]], %[[VAL_6]], dims = [] : (tensor, tensor<1xindex>) -> tensor<3xf32> +// CHECK: %[[VAL_11:.*]], %[[VAL_12:.*]], %[[VAL_13:.*]] = "stablehlo.batch_norm_training"(%[[VAL_1]], %[[VAL_8]], %[[VAL_10]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor, tensor<3xf32>, tensor<3xf32>) -> (tensor, tensor<3xf32>, tensor<3xf32>) // CHECK: %[[VAL_14:.*]] = torch_c.from_builtin_tensor %[[VAL_11]] : tensor -> !torch.vtensor<[?,3,?,?],f32> // CHECK: return %[[VAL_14]] : !torch.vtensor<[?,3,?,?],f32> func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?],f32>) -> !torch.vtensor<[?,3,?,?],f32> { @@ -222,28 +222,28 @@ func.func @torch.aten.batch_norm$no_bias_weight(%arg0: !torch.vtensor<[?,3,?,?], // CHECK-LABEL: func @torch.aten.native_layer_norm( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,7,4,5],f32> -> tensor<3x7x4x5xf32> -// CHECK: %[[VAL_2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x5xf32> -// CHECK: %[[VAL_3:.*]] = mhlo.constant dense<1.000000e+00> : tensor<4x5xf32> +// CHECK: %[[VAL_2:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<4x5xf32> +// CHECK: %[[VAL_3:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<4x5xf32> // CHECK: %int4 = torch.constant.int 4 // CHECK: %int5 = torch.constant.int 5 // CHECK: %float1.000000e-05 = torch.constant.float 1.000000e-05 // CHECK: %true = torch.constant.bool true // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int4, %int5 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<[1, 21, 20]> : tensor<3xi64> -// CHECK: %[[VAL_6:.*]] = mhlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32> -// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor<21xf32> -// CHECK: %[[VAL_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor<21xf32> -// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "mhlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -// CHECK: %[[VAL_12:.*]] = mhlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64> -// CHECK: %[[VAL_13:.*]] = mhlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32> -// CHECK: %[[VAL_14:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64> -// CHECK: %[[VAL_15:.*]] = mhlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32> -// CHECK: %[[VAL_16:.*]] = mhlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64> -// CHECK: %[[VAL_17:.*]] = mhlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32> -// CHECK: %[[VAL_18:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_3]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32> -// CHECK: %[[VAL_19:.*]] = "mhlo.broadcast_in_dim"(%[[VAL_2]]) {broadcast_dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32> -// CHECK: %[[VAL_20:.*]] = mhlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32> -// CHECK: %[[VAL_21:.*]] = mhlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32> +// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<[1, 21, 20]> : tensor<3xi64> +// CHECK: %[[VAL_6:.*]] = stablehlo.dynamic_reshape %[[VAL_1]], %[[VAL_5]] : (tensor<3x7x4x5xf32>, tensor<3xi64>) -> tensor<1x21x20xf32> +// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<21xf32> +// CHECK: %[[VAL_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor<21xf32> +// CHECK: %[[VAL_9:.*]], %[[VAL_10:.*]], %[[VAL_11:.*]] = "stablehlo.batch_norm_training"(%[[VAL_6]], %[[VAL_7]], %[[VAL_8]]) {epsilon = 9.99999974E-6 : f32, feature_index = 1 : i64} : (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) -> (tensor<1x21x20xf32>, tensor<21xf32>, tensor<21xf32>) +// CHECK: %[[VAL_12:.*]] = stablehlo.constant dense<[3, 7, 4, 5]> : tensor<4xi64> +// CHECK: %[[VAL_13:.*]] = stablehlo.dynamic_reshape %[[VAL_9]], %[[VAL_12]] : (tensor<1x21x20xf32>, tensor<4xi64>) -> tensor<3x7x4x5xf32> +// CHECK: %[[VAL_14:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64> +// CHECK: %[[VAL_15:.*]] = stablehlo.dynamic_reshape %[[VAL_10]], %[[VAL_14]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32> +// CHECK: %[[VAL_16:.*]] = stablehlo.constant dense<[3, 7, 1, 1]> : tensor<4xi64> +// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_reshape %[[VAL_11]], %[[VAL_16]] : (tensor<21xf32>, tensor<4xi64>) -> tensor<3x7x1x1xf32> +// CHECK: %[[VAL_18:.*]] = stablehlo.broadcast_in_dim %[[VAL_3]], dims = [2, 3] : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32> +// CHECK: %[[VAL_19:.*]] = stablehlo.broadcast_in_dim %[[VAL_2]], dims = [2, 3] : (tensor<4x5xf32>) -> tensor<3x7x4x5xf32> +// CHECK: %[[VAL_20:.*]] = stablehlo.multiply %[[VAL_13]], %[[VAL_18]] : tensor<3x7x4x5xf32> +// CHECK: %[[VAL_21:.*]] = stablehlo.add %[[VAL_20]], %[[VAL_19]] : tensor<3x7x4x5xf32> // CHECK: %[[VAL_22:.*]] = torch_c.from_builtin_tensor %[[VAL_21:.*]] : tensor<3x7x4x5xf32> -> !torch.vtensor<[3,7,4,5],f32> // CHECK: return %[[VAL_22]] : !torch.vtensor<[3,7,4,5],f32> func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) -> !torch.vtensor<[3,7,4,5],f32> { @@ -267,8 +267,8 @@ func.func @torch.aten.native_layer_norm(%arg0: !torch.vtensor<[3,7,4,5],f32>) -> // CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor -// CHECK: %[[T3:.*]] = mhlo.convert %[[T2]] : (tensor) -> tensor -// CHECK: %[[T4:.*]] = "mhlo.concatenate"(%[[T1]], %[[T3]]) {dimension = 0 : i64} : (tensor, tensor) -> tensor +// CHECK: %[[T3:.*]] = stablehlo.convert %[[T2]] : (tensor) -> tensor +// CHECK: %[[T4:.*]] = stablehlo.concatenate %[[T1]], %[[T3]], dim = 0 : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> { @@ -287,7 +287,7 @@ func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torc // CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[VAL_3:.*]] = "mhlo.concatenate"(%[[VAL_1]], %[[VAL_2]]) {dimension = 0 : i64} : (tensor, tensor) -> tensor +// CHECK: %[[VAL_3:.*]] = stablehlo.concatenate %[[VAL_1]], %[[VAL_2]], dim = 0 : (tensor, tensor) -> tensor // CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { diff --git a/test/Conversion/TorchToMhlo/elementwise.mlir b/test/Conversion/TorchToMhlo/elementwise.mlir index 6b3faace05f3..b1d560e4fd7e 100644 --- a/test/Conversion/TorchToMhlo/elementwise.mlir +++ b/test/Conversion/TorchToMhlo/elementwise.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func @torch.aten.gelu( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -7,12 +7,12 @@ // CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 1.000000e+00 : f32} : (tensor) -> tensor // CHECK: %[[T2:.*]] = "chlo.constant_like"(%[[T0]]) {value = 2.000000e+00 : f32} : (tensor) -> tensor // CHECK: %[[T3:.*]] = "chlo.constant_like"(%[[T0]]) {value = 5.000000e-01 : f32} : (tensor) -> tensor -// CHECK: %[[T4:.*]] = mhlo.rsqrt %[[T2]] : tensor -// CHECK: %[[T5:.*]] = mhlo.multiply %[[T0]], %[[T4]] : tensor +// CHECK: %[[T4:.*]] = stablehlo.rsqrt %[[T2]] : tensor +// CHECK: %[[T5:.*]] = stablehlo.multiply %[[T0]], %[[T4]] : tensor // CHECK: %[[T6:.*]] = chlo.erf %[[T5]] : tensor -> tensor -// CHECK: %[[T7:.*]] = mhlo.add %[[T6]], %[[T1]] : tensor -// CHECK: %[[T8:.*]] = mhlo.multiply %[[T7]], %[[T3]] : tensor -// CHECK: %[[T9:.*]] = mhlo.multiply %[[T0]], %[[T8]] : tensor +// CHECK: %[[T7:.*]] = stablehlo.add %[[T6]], %[[T1]] : tensor +// CHECK: %[[T8:.*]] = stablehlo.multiply %[[T7]], %[[T3]] : tensor +// CHECK: %[[T9:.*]] = stablehlo.multiply %[[T0]], %[[T8]] : tensor // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -26,7 +26,7 @@ func.func @torch.aten.gelu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[ // CHECK-LABEL: func.func @torch.aten.tanh$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = mhlo.tanh %[[T0]] : tensor +// CHECK: %[[T1:.*]] = stablehlo.tanh %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -39,7 +39,7 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte // CHECK-LABEL: func.func @torch.aten.log$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = mhlo.log %[[T0]] : tensor +// CHECK: %[[T1:.*]] = stablehlo.log %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -52,7 +52,7 @@ func.func @torch.aten.log$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK-LABEL: func.func @torch.aten.exp$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = mhlo.exponential %[[T0]] : tensor +// CHECK: %[[T1:.*]] = stablehlo.exponential %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -65,7 +65,7 @@ func.func @torch.aten.exp$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK-LABEL: func.func @torch.aten.neg$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = mhlo.negate %[[T0]] : tensor +// CHECK: %[[T1:.*]] = stablehlo.negate %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -78,7 +78,7 @@ func.func @torch.aten.neg$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vten // CHECK-LABEL: func.func @torch.aten.rsqrt$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = mhlo.rsqrt %[[T0]] : tensor +// CHECK: %[[T1:.*]] = stablehlo.rsqrt %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -91,7 +91,7 @@ func.func @torch.aten.rsqrt$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vt // CHECK-LABEL: func.func @torch.aten.sigmoid$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[T1:.*]] = mhlo.logistic %[[T0]] : tensor +// CHECK: %[[T1:.*]] = stablehlo.logistic %[[T0]] : tensor // CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -108,8 +108,8 @@ func.func @torch.aten.sigmoid$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch. // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_add %[[T0]], %[[T3]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -130,11 +130,11 @@ func.func @torch.aten.addscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor // CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> -// CHECK: %[[T5:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T6:.*]] = mhlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T5:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T6:.*]] = stablehlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor // CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor, tensor) -> tensor // CHECK: %[[T8:.*]] = chlo.broadcast_add %[[T0]], %[[T7]] : (tensor, tensor) -> tensor // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,?],f32> @@ -171,8 +171,8 @@ func.func @torch.aten.addtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> -// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor // CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor, tensor) -> tensor // CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor, tensor) -> tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?],f32> @@ -190,7 +190,7 @@ func.func @torch.aten.addtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor // CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor) -> tensor // CHECK: %[[T3:.*]] = chlo.broadcast_add %[[T2]], %[[T1]] : (tensor, tensor) -> tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?],si64> // CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64> @@ -209,8 +209,8 @@ func.func @torch.aten.addtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1 // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T0]], %[[T3]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -230,8 +230,8 @@ func.func @torch.aten.subscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torc // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_subtract %[[T3]], %[[T0]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -252,11 +252,11 @@ func.func @torch.aten.rsubscalar$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !tor // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor // CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> -// CHECK: %[[T5:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T6:.*]] = mhlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T5:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T6:.*]] = stablehlo.reshape %[[T5]] : (tensor<1xf32>) -> tensor // CHECK: %[[T7:.*]] = chlo.broadcast_multiply %[[T4]], %[[T6]] : (tensor, tensor) -> tensor // CHECK: %[[T8:.*]] = chlo.broadcast_subtract %[[T0]], %[[T7]] : (tensor, tensor) -> tensor // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,?],f32> @@ -293,8 +293,8 @@ func.func @torch.aten.subtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[T2:.*]] = torch_c.to_i64 %[[INT2]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xi64> -// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor // CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor, tensor) -> tensor // CHECK: %[[T6:.*]] = chlo.broadcast_subtract %[[T0]], %[[T5]] : (tensor, tensor) -> tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?],f32> @@ -312,7 +312,7 @@ func.func @torch.aten.subtensor$alpha(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],si32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si64> -> tensor // CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[T2:.*]] = mhlo.convert %[[T0]] : (tensor) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.convert %[[T0]] : (tensor) -> tensor // CHECK: %[[T3:.*]] = chlo.broadcast_subtract %[[T2]], %[[T1]] : (tensor, tensor) -> tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?],si64> // CHECK: return %[[T4]] : !torch.vtensor<[?,?],si64> @@ -330,8 +330,8 @@ func.func @torch.aten.subtensor$promote(%arg0: !torch.vtensor<[?,?],si32>, %arg1 // CHECK: %[[INT9:.*]] = torch.constant.int 9 // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -363,8 +363,8 @@ func.func @torch.aten.multensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[INT9:.*]] = torch.constant.int 9 // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT9]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -396,8 +396,8 @@ func.func @torch.aten.divtensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK: %[[INT3:.*]] = torch.constant.int 3 // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[INT3]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1> @@ -471,7 +471,7 @@ func.func @torch.aten.ne.tensor(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch. // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T2:.*]] = "mhlo.transpose"(%[[T0]]) {permutation = dense<[1, 0]> : tensor<2xi64>} : (tensor<4x64xf32>) -> tensor<64x4xf32> +// CHECK: %[[T2:.*]] = stablehlo.transpose %[[T0]], dims = [1, 0] : (tensor<4x64xf32>) -> tensor<64x4xf32> // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<64x4xf32> -> !torch.vtensor<[64,4],f32> // CHECK: return %[[T3]] : !torch.vtensor<[64,4],f32> func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch.vtensor<[64,4],f32> { @@ -488,7 +488,7 @@ func.func @torch.aten.permute$basic(%arg0: !torch.vtensor<[4,64],f32>) -> !torch // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = "chlo.constant_like"(%[[T0]]) {value = 0.000000e+00 : f32} : (tensor) -> tensor -// CHECK: %[[T2:.*]] = mhlo.maximum %[[T0]], %[[T1]] : tensor +// CHECK: %[[T2:.*]] = stablehlo.maximum %[[T0]], %[[T1]] : tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T3]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -503,11 +503,11 @@ func.func @torch.aten.relu(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[ // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_f64 %[[ARG1]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64> -// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[FROM_ELEMENTS_0:.*]] = tensor.from_elements %[[T1]] : tensor<1xf64> -// CHECK: %[[T4:.*]] = mhlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xf64>) -> tensor<1xf32> -// CHECK: %[[T5:.*]] = mhlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T4:.*]] = stablehlo.convert %[[FROM_ELEMENTS_0]] : (tensor<1xf64>) -> tensor<1xf32> +// CHECK: %[[T5:.*]] = stablehlo.reshape %[[T4]] : (tensor<1xf32>) -> tensor // CHECK: %[[T6:.*]] = chlo.broadcast_multiply %[[T3]], %[[T5]] : (tensor, tensor) -> tensor // CHECK: %[[T7:.*]] = chlo.broadcast_add %[[T0]], %[[T6]] : (tensor, tensor) -> tensor // CHECK: %[[T8:.*]] = torch_c.from_builtin_tensor %[[T7]] : tensor -> !torch.vtensor<[?,?],f32> @@ -525,8 +525,8 @@ func.func @torch.aten.addscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T2:.*]] = torch_c.to_f64 %[[ARG2]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]] : tensor<1xf64> -// CHECK: %[[T3:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32> -// CHECK: %[[T4:.*]] = mhlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T3:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xf64>) -> tensor<1xf32> +// CHECK: %[[T4:.*]] = stablehlo.reshape %[[T3]] : (tensor<1xf32>) -> tensor // CHECK: %[[T5:.*]] = chlo.broadcast_multiply %[[T1]], %[[T4]] : (tensor, tensor) -> tensor // CHECK: %[[T6:.*]] = chlo.broadcast_add %[[T0]], %[[T5]] : (tensor, tensor) -> tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?],f32> @@ -543,8 +543,8 @@ func.func @torch.aten.addtensor$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_multiply %[[T0]], %[[T3]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -560,8 +560,8 @@ func.func @torch.aten.mulscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_divide %[[T0]], %[[T3]] : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> @@ -577,8 +577,8 @@ func.func @torch.aten.divscalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_i64 %[[ARG1]] // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = mhlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> -// CHECK: %[[T3:.*]] = mhlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.convert %[[FROM_ELEMENTS]] : (tensor<1xi64>) -> tensor<1xf32> +// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xf32>) -> tensor // CHECK: %[[T4:.*]] = chlo.broadcast_compare %[[T0]], %[[T3]] {compare_type = #chlo, comparison_direction = #chlo} : (tensor, tensor) -> tensor // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],i1> // CHECK: return %[[T5]] : !torch.vtensor<[?,?],i1> @@ -595,10 +595,10 @@ func.func @torch.aten.gt.scalar$variable(%arg0: !torch.vtensor<[?,?],f32>, %arg1 // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[STR:.*]] = torch.constant.str "trunc" // CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor, tensor) -> tensor -// CHECK: %[[T3:.*]] = mhlo.sign %[[T2]] : tensor -// CHECK: %[[T4:.*]] = mhlo.abs %[[T2]] : tensor -// CHECK: %[[T5:.*]] = mhlo.floor %[[T4]] : tensor -// CHECK: %[[T6:.*]] = mhlo.multiply %[[T3]], %[[T5]] : tensor +// CHECK: %[[T3:.*]] = stablehlo.sign %[[T2]] : tensor +// CHECK: %[[T4:.*]] = stablehlo.abs %[[T2]] : tensor +// CHECK: %[[T5:.*]] = stablehlo.floor %[[T4]] : tensor +// CHECK: %[[T6:.*]] = stablehlo.multiply %[[T3]], %[[T5]] : tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T7]] : !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { @@ -615,7 +615,7 @@ func.func @torch.aten.div.Tensor_mode$trunc(%arg0: !torch.vtensor<[?,?,?,?],f32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[STR:.*]] = torch.constant.str "floor" // CHECK: %[[T2:.*]] = chlo.broadcast_divide %[[T0]], %[[T1]] : (tensor, tensor) -> tensor -// CHECK: %[[T3:.*]] = mhlo.floor %[[T2]] : tensor +// CHECK: %[[T3:.*]] = stablehlo.floor %[[T2]] : tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.div.Tensor_mode$floor(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { diff --git a/test/Conversion/TorchToMhlo/gather.mlir b/test/Conversion/TorchToMhlo/gather.mlir index a20b32d4994d..ea4ca9b8272e 100644 --- a/test/Conversion/TorchToMhlo/gather.mlir +++ b/test/Conversion/TorchToMhlo/gather.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func @torch.aten.index_select$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,4],f32>, %[[ARG1:.*]]: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> { @@ -10,8 +10,8 @@ // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32> -// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor<2x4xf32> +// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false} : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<2x4xf32> +// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor<2x4xf32> // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor<2x4xf32> -> !torch.vtensor<[2,4],f32> // CHECK: return %[[T7]] : !torch.vtensor<[2,4],f32> func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,4],f32> { @@ -31,8 +31,8 @@ func.func @torch.aten.index_select$basic(%arg0: !torch.vtensor<[?,4],f32>, %arg1 // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor -// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor +// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T7]] : !torch.vtensor<[?,?],f32> func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?], si64>) -> !torch.vtensor<[?,?],f32> { @@ -53,8 +53,8 @@ func.func @torch.aten.embedding$basic(%weight: !torch.vtensor<[?,?],f32>, %indic // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 // CHECK: %[[T4:.*]] = tensor.from_elements %[[C1_I64]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T5:.*]] = "mhlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #mhlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor -// CHECK: %[[T6:.*]] = mhlo.convert %[[T5]] : tensor +// CHECK: %[[T5:.*]] = "stablehlo.dynamic_gather"(%[[T0]], %[[T1]], %[[T4]]) {dimension_numbers = #stablehlo.gather, indices_are_sorted = false} : (tensor, tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T6:.*]] = stablehlo.convert %[[T5]] : tensor // CHECK: %[[T7:.*]] = torch_c.from_builtin_tensor %[[T6]] : tensor -> !torch.vtensor<[?,1,?],f32> // CHECK: return %[[T7]] : !torch.vtensor<[?,1,?],f32> func.func @torch.aten.embedding$rank_two_indices(%weight: !torch.vtensor<[?,?],f32>, %indices: !torch.vtensor<[?,1], si64>) -> !torch.vtensor<[?,1,?],f32> { diff --git a/test/Conversion/TorchToMhlo/linear.mlir b/test/Conversion/TorchToMhlo/linear.mlir index 165c874ea061..628969956684 100644 --- a/test/Conversion/TorchToMhlo/linear.mlir +++ b/test/Conversion/TorchToMhlo/linear.mlir @@ -1,10 +1,10 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func @torch.aten.mm$basic$static( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[2,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,3],f32>) -> !torch.vtensor<[2,3],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[2,3],f32> -> tensor<2x3xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,3],f32> -> tensor<3x3xf32> -// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32> +// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<2x3xf32>, tensor<3x3xf32>) -> tensor<2x3xf32> // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor<2x3xf32> to tensor<2x3xf32> // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor<2x3xf32> -> !torch.vtensor<[2,3],f32> // CHECK: return %[[T4]] : !torch.vtensor<[2,3],f32> @@ -19,7 +19,7 @@ func.func @torch.aten.mm$basic$static(%arg0: !torch.vtensor<[2,3],f32>, %arg1: ! // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,3],f32>, %[[ARG1:.*]]: !torch.vtensor<[3,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,3],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[3,?],f32> -> tensor<3x?xf32> -// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor, tensor<3x?xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor, tensor<3x?xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?,?],f32> @@ -44,8 +44,8 @@ func.func @torch.aten.mm$basic$dynamic(%arg0: !torch.vtensor<[?,3],f32>, %arg1: // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor<10x4x5xf32> // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32> -// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor<10x4x5xf32>, tensor<3xi64>) -> tensor<10x4x5xf32> +// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor<10x3x4xf32>, tensor<10x4x5xf32>) -> tensor<10x3x5xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<10x3x5xf32> to tensor<10x3x5xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<10x3x5xf32> -> !torch.vtensor<[10,3,5],f32> // CHECK: return %[[T12]] : !torch.vtensor<[10,3,5],f32> @@ -70,8 +70,8 @@ func.func @torch.aten.bmm$basic$static(%arg0: !torch.vtensor<[10,3,4],f32>, %arg // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C2]] : tensor // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[0, 1, 2]> : tensor<3xi64>} : (tensor, tensor<3xi64>) -> tensor -// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [0, 1, 2] : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor, tensor) -> tensor // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor to tensor // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T12]] : !torch.vtensor<[?,?,?],f32> @@ -96,8 +96,8 @@ func.func @torch.aten.bmm$basic$dynamic(%arg0: !torch.vtensor<[?,?,4],f32>, %arg // CHECK: %[[T6:.*]] = tensor.dim %[[T0]], %[[C1]] : tensor<256x120xf32> // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32> -// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32> +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T8]], dims = [1, 2] : (tensor<256x120xf32>, tensor<3xi64>) -> tensor<4x256x120xf32> +// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T9]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor<4x256x120xf32>, tensor<4x120x256xf32>) -> tensor<4x256x256xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x256x256xf32> to tensor<4x256x256xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x256x256xf32> -> !torch.vtensor<[4,256,256],f32> // CHECK: return %[[T12]] : !torch.vtensor<[4,256,256],f32> @@ -122,8 +122,8 @@ func.func @torch.aten.matmul$basic$static(%arg0: !torch.vtensor<[256,120],f32>, // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x?xf32> // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32> -// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32> +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x?xf32>, tensor<3xi64>) -> tensor<4x256x?xf32> +// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor<4x?x256xf32>, tensor<4x256x?xf32>) -> tensor<4x?x?xf32> // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor<4x?x?xf32> to tensor<4x?x?xf32> // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor<4x?x?xf32> -> !torch.vtensor<[4,?,?],f32> // CHECK: return %[[T12]] : !torch.vtensor<[4,?,?],f32> @@ -145,8 +145,8 @@ func.func @torch.aten.matmul$basic$dynamic(%arg0: !torch.vtensor<[4,?,256],f32>, // CHECK: %[[T4:.*]] = tensor.dim %[[T1]], %[[C0_0]] : tensor<256xf32> // CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> -// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> -// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #mhlo.dot} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32> +// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor<1x256xf32> +// CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T7]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor<1x?x256xf32>, tensor<1x256xf32>) -> tensor<1x?xf32> // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor<1x?xf32> to tensor<1x?xf32> // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor<1x?xf32> -> !torch.vtensor<[1,?],f32> // CHECK: return %[[T10]] : !torch.vtensor<[1,?],f32> @@ -168,8 +168,8 @@ func.func @torch.aten.matmul$3dx1d(%arg0: !torch.vtensor<[1,?,256],f32>, %arg1: // CHECK: %[[T4:.*]] = tensor.dim %[[T0]], %[[C0_0]] : tensor<256xf32> // CHECK: %[[T5:.*]] = arith.index_cast %[[T4]] : index to i64 // CHECK: %[[T6:.*]] = tensor.from_elements %[[T3]], %[[T5]] : tensor<2xi64> -// CHECK: %[[T7:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T0]], %[[T6]]) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<256xf32>, tensor<2xi64>) -> tensor -// CHECK: %[[T8:.*]] = "mhlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T7:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T0]], %[[T6]], dims = [1] : (tensor<256xf32>, tensor<2xi64>) -> tensor +// CHECK: %[[T8:.*]] = "stablehlo.dot_general"(%[[T7]], %[[T1]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor, tensor) -> tensor // CHECK: %[[T9:.*]] = tensor.cast %[[T8]] : tensor to tensor // CHECK: %[[T10:.*]] = torch_c.from_builtin_tensor %[[T9]] : tensor -> !torch.vtensor<[?,?],f32> // CHECK: return %[[T10]] : !torch.vtensor<[?,?],f32> @@ -184,7 +184,7 @@ func.func @torch.aten.matmul$1dx3d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> -// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor, tensor<256xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor, tensor<256xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?],f32> @@ -199,7 +199,7 @@ func.func @torch.aten.matmul$2dx1d(%arg0: !torch.vtensor<[?,256],f32>, %arg1: !t // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256,?],f32>) -> !torch.vtensor<[?],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256,?],f32> -> tensor<256x?xf32> -// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256x?xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256x?xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?],f32> @@ -214,7 +214,7 @@ func.func @torch.aten.matmul$1dx2d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[256],f32>, %[[ARG1:.*]]: !torch.vtensor<[256],f32>) -> !torch.vtensor<[],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[256],f32> -> tensor<256xf32> // CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[256],f32> -> tensor<256xf32> -// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor<256xf32>, tensor<256xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor<256xf32>, tensor<256xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[],f32> // CHECK: return %[[T4]] : !torch.vtensor<[],f32> @@ -228,7 +228,7 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK-LABEL: func.func @torch.aten.matmul$proj( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,256],f32>) -> !torch.vtensor<[?,?,256],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,256],f32> -> tensor -// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> +// CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32> // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[T2:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor // CHECK: %[[T3:.*]] = arith.index_cast %[[T2]] : index to i64 @@ -239,8 +239,8 @@ func.func @torch.aten.matmul$1dx1d(%arg0: !torch.vtensor<[256],f32>, %arg1: !tor // CHECK: %[[T6:.*]] = tensor.dim %[[T1]], %[[C1]] : tensor<256x256xf32> // CHECK: %[[T7:.*]] = arith.index_cast %[[T6]] : index to i64 // CHECK: %[[T8:.*]] = tensor.from_elements %[[T3]], %[[T5]], %[[T7]] : tensor<3xi64> -// CHECK: %[[T9:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[T1]], %[[T8]]) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<256x256xf32>, tensor<3xi64>) -> tensor -// CHECK: %[[T10:.*]] = "mhlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #mhlo.dot} : (tensor, tensor) -> tensor +// CHECK: %[[T9:.*]] = stablehlo.dynamic_broadcast_in_dim %[[T1]], %[[T8]], dims = [1, 2] : (tensor<256x256xf32>, tensor<3xi64>) -> tensor +// CHECK: %[[T10:.*]] = "stablehlo.dot_general"(%[[T0]], %[[T9]]) {dot_dimension_numbers = #stablehlo.dot} : (tensor, tensor) -> tensor // CHECK: %[[T11:.*]] = tensor.cast %[[T10]] : tensor to tensor // CHECK: %[[T12:.*]] = torch_c.from_builtin_tensor %[[T11]] : tensor -> !torch.vtensor<[?,?,256],f32> // CHECK: return %[[T12]] : !torch.vtensor<[?,?,256],f32> @@ -255,8 +255,8 @@ func.func @torch.aten.matmul$proj(%arg0: !torch.vtensor<[?,?,256],f32>) -> !torc // CHECK-LABEL: func.func @torch.aten.mm$proj( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,256],f32>) -> !torch.vtensor<[?,256],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,256],f32> -> tensor -// CHECK: %[[T1:.*]] = mhlo.constant dense<1.000000e+00> : tensor<256x256xf32> -// CHECK: %[[T2:.*]] = "mhlo.dot"(%[[T0]], %[[T1]]) : (tensor, tensor<256x256xf32>) -> tensor +// CHECK: %[[T1:.*]] = stablehlo.constant dense<1.000000e+00> : tensor<256x256xf32> +// CHECK: %[[T2:.*]] = stablehlo.dot %[[T0]], %[[T1]] : (tensor, tensor<256x256xf32>) -> tensor // CHECK: %[[T3:.*]] = tensor.cast %[[T2]] : tensor to tensor // CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[?,256],f32> // CHECK: return %[[T4]] : !torch.vtensor<[?,256],f32> @@ -284,7 +284,7 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten // CHECK: %[[T_11:.*]] = torch.prim.ListConstruct %[[T_7]], %[[T_5]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_12:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %[[T_13:.*]] = torch.constant.bool false -// CHECK: %[[T_14:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]]) +// CHECK: %[[T_14:.*]] = stablehlo.convolution(%[[T_0]], %[[T_1]]) // CHECK-SAME{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor, tensor) -> tensor // CHECK: %[[T_15:.*]] = torch_c.from_builtin_tensor %[[T_14]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T_15]] : !torch.vtensor<[?,?,?,?],f32> @@ -321,14 +321,14 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: ! // CHECK: %[[T_6:.*]] = torch.prim.ListConstruct %int3, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_7:.*]] = torch.prim.ListConstruct : () -> !torch.list // CHECK: %false = torch.constant.bool false -// CHECK: %[[T_8:.*]] = mhlo.convolution(%[[T_0]], %[[T_1]]) +// CHECK: %[[T_8:.*]] = stablehlo.convolution(%[[T_0]], %[[T_1]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {stride = [2, 1], pad = [[4, 4], [2, 2]], rhs_dilate = [3, 1]} {batch_group_count = 1 : i64, feature_group_count = 3 : i64} : (tensor, tensor) -> tensor // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[T_9:.*]] = tensor.dim %[[T_2]], %[[IDX_0]] : tensor // CHECK: %[[T_10:.*]] = arith.index_cast %[[T_9]] : index to i64 // CHECK: %[[VAL_0:.*]] = arith.constant 1 : i64 // CHECK: %[[T_11:.*]] = tensor.from_elements %[[T_10]], %[[VAL_0]], %[[VAL_0]] : tensor<3xi64> -// CHECK: %[[T_12:.*]] = mhlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[T_12:.*]] = stablehlo.dynamic_reshape %[[T_2]], %[[T_11]] : (tensor, tensor<3xi64>) -> tensor // CHECK: %[[T_13:.*]] = chlo.broadcast_add %[[T_8]], %[[T_12]] : (tensor, tensor) -> tensor // CHECK: %[[T_14:.*]] = torch_c.from_builtin_tensor %[[T_13]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T_14]] : !torch.vtensor<[?,?,?,?],f32> @@ -360,8 +360,8 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar // CHECK: %[[T_2:.*]] = torch_c.to_i64 %int1 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_5:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> -// CHECK: %[[T_6:.*]] = mhlo.convolution(%[[T_0]], %[[T_5]]) +// CHECK: %[[T_5:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32> +// CHECK: %[[T_6:.*]] = stablehlo.convolution(%[[T_0]], %[[T_5]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x9x9xf32> // CHECK: %[[T_7:.*]] = torch_c.from_builtin_tensor %[[T_6]] : tensor<1x4x9x9xf32> -> !torch.vtensor<[1,4,9,9],f32> // CHECK: return %[[T_7]] : !torch.vtensor<[1,4,9,9],f32> @@ -392,8 +392,8 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7, // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> -// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]]) +// CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x4x3x3xf32> +// CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32> // CHECK: %[[T_8:.*]] = torch_c.from_builtin_tensor %[[T_7]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> // CHECK: return %[[T_8]] : !torch.vtensor<[1,4,15,15],f32> @@ -426,11 +426,11 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7 // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%[[T_1]]) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x4x3x3xf32>) -> tensor<2x4x3x3xf32> -// CHECK: %[[T_7:.*]] = mhlo.convolution(%[[T_0]], %[[T_6]]) +// CHECK: %[[T_6:.*]] = stablehlo.reverse %[[T_1]], dims = [2, 3] : tensor<2x4x3x3xf32> +// CHECK: %[[T_7:.*]] = stablehlo.convolution(%[[T_0]], %[[T_6]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x2x7x7xf32>, tensor<2x4x3x3xf32>) -> tensor<1x4x15x15xf32> -// CHECK: %[[T_8:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[T_9:.*]] = "mhlo.pad"(%[[T_7]], %[[T_8]]) {edge_padding_high = dense<[0, 0, 1, 1]> : vector<4xi64>, edge_padding_low = dense<0> : vector<4xi64>, interior_padding = dense<0> : vector<4xi64>} : (tensor<1x4x15x15xf32>, tensor) -> tensor<1x4x16x16xf32> +// CHECK: %[[T_8:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[T_9:.*]] = stablehlo.pad %[[T_7]], %[[T_8]], low = [0, 0, 0, 0], high = [0, 0, 1, 1], interior = [0, 0, 0, 0] : (tensor<1x4x15x15xf32>, tensor) -> tensor<1x4x16x16xf32> // CHECK: %[[T_10:.*]] = torch_c.from_builtin_tensor %[[T_9:.*]] : tensor<1x4x16x16xf32> -> !torch.vtensor<[1,4,16,16],f32> // CHECK: return %[[T_10]] : !torch.vtensor<[1,4,16,16],f32> func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor<[1,2,7,7],f32>, %arg1: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> { @@ -462,7 +462,7 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // CHECK: %[[T_3:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T_5:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T_6:.*]] = "mhlo.reverse"(%1) {dimensions = dense<[2, 3]> : tensor<2xi64>} : (tensor<2x2x3x3xf32>) -> tensor<2x2x3x3xf32> +// CHECK: %[[T_6:.*]] = stablehlo.reverse %1, dims = [2, 3] : tensor<2x2x3x3xf32> // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[T_7:.*]] = tensor.dim %[[T_6]], %[[IDX_0]] : tensor<2x2x3x3xf32> // CHECK: %[[T_8:.*]] = arith.index_cast %[[T_7]] : index to i64 @@ -479,11 +479,11 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // CHECK: %[[T_15:.*]] = arith.divsi %[[T_8]], %[[T_24]] : i64 // CHECK: %[[T_16:.*]] = arith.muli %[[T_10]], %[[T_24]] : i64 // CHECK: %[[T_17:.*]] = tensor.from_elements %[[T_24]], %[[T_15]], %[[T_10]], %[[T_12]], %[[T_14]] : tensor<5xi64> -// CHECK: %[[T_18:.*]] = mhlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32> -// CHECK: %[[T_19:.*]] = "mhlo.transpose"(%[[T_18]]) {permutation = dense<[1, 0, 2, 3, 4]> : tensor<5xi64>} : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32> +// CHECK: %[[T_18:.*]] = stablehlo.dynamic_reshape %[[T_6]], %[[T_17]] : (tensor<2x2x3x3xf32>, tensor<5xi64>) -> tensor<2x1x2x3x3xf32> +// CHECK: %[[T_19:.*]] = stablehlo.transpose %[[T_18]], dims = [1, 0, 2, 3, 4] : (tensor<2x1x2x3x3xf32>) -> tensor<1x2x2x3x3xf32> // CHECK: %[[T_20:.*]] = tensor.from_elements %[[T_15]], %[[T_16]], %[[T_12]], %[[T_14]] : tensor<4xi64> -// CHECK: %[[T_21:.*]] = mhlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32> -// CHECK: %[[T_22:.*]] = mhlo.convolution(%[[T_0]], %[[T_21]]) +// CHECK: %[[T_21:.*]] = stablehlo.dynamic_reshape %[[T_19]], %[[T_20]] : (tensor<1x2x2x3x3xf32>, tensor<4xi64>) -> tensor<1x4x3x3xf32> +// CHECK: %[[T_22:.*]] = stablehlo.convolution(%[[T_0]], %[[T_21]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[i, o, 0, 1]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<1x4x3x3xf32>) -> tensor<1x4x15x15xf32> // CHECK: %[[T_23:.*]] = torch_c.from_builtin_tensor %[[T_22]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> // CHECK: return %[[T_23]] : !torch.vtensor<[1,4,15,15],f32> diff --git a/test/Conversion/TorchToMhlo/lit.local.cfg b/test/Conversion/TorchToMhlo/lit.local.cfg index 829a5662f6e6..d4f752cd7104 100644 --- a/test/Conversion/TorchToMhlo/lit.local.cfg +++ b/test/Conversion/TorchToMhlo/lit.local.cfg @@ -1,2 +1,2 @@ -if not config.enable_mhlo: +if not config.enable_stablehlo: config.unsupported = True diff --git a/test/Conversion/TorchToMhlo/pooling.mlir b/test/Conversion/TorchToMhlo/pooling.mlir index 684eb7828864..98805bdd8c29 100644 --- a/test/Conversion/TorchToMhlo/pooling.mlir +++ b/test/Conversion/TorchToMhlo/pooling.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s // ----- @@ -13,11 +13,11 @@ // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int0, %int0 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_5:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_6:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor -// CHECK: %[[VAL_7:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({ +// CHECK: %[[VAL_6:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor +// CHECK: %[[VAL_7:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_6]]) ({ // CHECK: ^bb0(%[[VAL_8:.*]]: tensor, %[[VAL_9:.*]]: tensor): -// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor -// CHECK: mhlo.return %[[VAL_10]] : tensor +// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor +// CHECK: stablehlo.return %[[VAL_10]] : tensor // CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor, tensor) -> tensor // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32> @@ -45,11 +45,11 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int2, %int1 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor -// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ +// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor +// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ // CHECK: ^bb0(%[[VAL_8:.*]]: tensor, %[[VAL_9:.*]]: tensor): -// CHECK: %[[VAL_10:.*]] = mhlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor -// CHECK: mhlo.return %[[VAL_10]] : tensor +// CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor +// CHECK: stablehlo.return %[[VAL_10]] : tensor // CHECK: }) // CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?,?,?],f32> @@ -80,7 +80,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) - // CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T5:.*]] = mhlo.constant dense<-3.40282347E+38> : tensor +// CHECK: %[[T5:.*]] = stablehlo.constant dense<-3.40282347E+38> : tensor // CHECK: %[[C0:.*]] = arith.constant 0 : index // CHECK: %[[DIM:.*]] = tensor.dim %[[T0]], %[[C0]] : tensor // CHECK: %[[T6:.*]] = arith.index_cast %[[DIM]] : index to i64 @@ -93,18 +93,18 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) - // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T6]], %[[T7]], %[[T8]] : tensor<3xi64> // CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T7]] : i64 // CHECK: %[[FROM_ELEMENTS_2:.*]] = tensor.from_elements %[[T6]], %[[T9]] : tensor<2xi64> -// CHECK: %[[T10:.*]] = "mhlo.dynamic_iota"(%[[FROM_ELEMENTS_2]]) {iota_dimension = 1 : i64} : (tensor<2xi64>) -> tensor -// CHECK: %[[T11:.*]] = mhlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor, tensor<3xi64>) -> tensor -// CHECK: %[[T12:.*]] = mhlo.constant dense<0> : tensor -// CHECK: %[[T13:.*]]:2 = "mhlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({ +// CHECK: %[[T10:.*]] = stablehlo.dynamic_iota %[[FROM_ELEMENTS_2]], dim = 1 : (tensor<2xi64>) -> tensor +// CHECK: %[[T11:.*]] = stablehlo.dynamic_reshape %[[T10]], %[[FROM_ELEMENTS]] : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[T12:.*]] = stablehlo.constant dense<0> : tensor +// CHECK: %[[T13:.*]]:2 = "stablehlo.reduce_window"(%[[T0]], %[[T11]], %[[T5]], %[[T12]]) ({ // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor, %[[ARG3:.*]]: tensor, %[[ARG4:.*]]: tensor): -// CHECK: %[[T16:.*]] = mhlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor, tensor) -> tensor -// CHECK: %[[T17:.*]] = mhlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor, tensor -// CHECK: %[[T18:.*]] = mhlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor, tensor) -> tensor -// CHECK: %[[T19:.*]] = mhlo.minimum %[[ARG2]], %[[ARG4]] : tensor -// CHECK: %[[T20:.*]] = mhlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor, tensor -// CHECK: %[[T21:.*]] = mhlo.select %[[T18]], %[[T19]], %[[T20]] : tensor, tensor -// CHECK: mhlo.return %[[T17]], %[[T21]] : tensor, tensor +// CHECK: %[[T16:.*]] = stablehlo.compare GE, %[[ARG1]], %[[ARG3]], FLOAT : (tensor, tensor) -> tensor +// CHECK: %[[T17:.*]] = stablehlo.select %[[T16]], %[[ARG1]], %[[ARG3]] : tensor, tensor +// CHECK: %[[T18:.*]] = stablehlo.compare EQ, %[[ARG1]], %[[ARG3]], FLOAT : (tensor, tensor) -> tensor +// CHECK: %[[T19:.*]] = stablehlo.minimum %[[ARG2]], %[[ARG4]] : tensor +// CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor, tensor +// CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor, tensor +// CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor, tensor // CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor, tensor, tensor, tensor) -> (tensor, tensor) // CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor -> !torch.vtensor<[?,?,?],si64> @@ -136,13 +136,13 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32> // CHECK: %[[VAL_2:.*]] = torch.prim.ListConstruct %int3, %int3 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_3:.*]] = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_4:.*]] = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[VAL_5:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[VAL_6:.*]] = "mhlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ +// CHECK: %[[VAL_5:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[VAL_6:.*]] = "stablehlo.reduce_window"(%[[VAL_1]], %[[VAL_5]]) ({ // CHECK: ^bb0(%[[IVAL_0:.*]]: tensor, %[[IVAL_1:.*]]: tensor): -// CHECK: %[[IVAL_2:.*]] = mhlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor -// CHECK: mhlo.return %[[IVAL_2]] : tensor +// CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor +// CHECK: stablehlo.return %[[IVAL_2]] : tensor // CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_7:.*]] = mhlo.constant dense<1.000000e+00> : tensor +// CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor // CHECK: %[[VAL_9:.*]] = arith.index_cast %[[VAL_8]] : index to i64 @@ -156,14 +156,14 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32> // CHECK: %[[VAL_14:.*]] = tensor.dim %[[VAL_1]], %[[IDX_3]] : tensor // CHECK: %[[VAL_15:.*]] = arith.index_cast %[[VAL_14]] : index to i64 // CHECK: %[[VAL_16:.*]] = tensor.from_elements %[[VAL_9]], %[[VAL_11]], %[[VAL_13]], %[[VAL_15]] : tensor<4xi64> -// CHECK: %[[VAL_17:.*]] = "mhlo.dynamic_broadcast_in_dim"(%[[VAL_7]], %[[VAL_16]]) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor, tensor<4xi64>) -> tensor -// CHECK: %[[VAL_18:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[VAL_19:.*]] = "mhlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({ +// CHECK: %[[VAL_17:.*]] = stablehlo.dynamic_broadcast_in_dim %[[VAL_7]], %[[VAL_16]], dims = [] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[VAL_18:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[VAL_19:.*]] = "stablehlo.reduce_window"(%[[VAL_17]], %[[VAL_18]]) ({ // CHECK: ^bb0(%[[IVAL_3:.*]]: tensor, %[[IVAL_4:.*]]: tensor): -// CHECK: %[[IVAL_5:.*]] = mhlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor -// CHECK: mhlo.return %[[IVAL_5]] : tensor +// CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor +// CHECK: stablehlo.return %[[IVAL_5]] : tensor // CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor, tensor) -> tensor -// CHECK: %[[VAL_20:.*]] = mhlo.divide %[[VAL_6]], %[[VAL_19]] : tensor +// CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor // CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { @@ -193,14 +193,14 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T2:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[T3:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1]] : (!torch.int, !torch.int) -> !torch.list -// CHECK: %[[T4:.*]] = mhlo.constant dense<0.000000e+00> : tensor -// CHECK: %[[T5:.*]] = "mhlo.reduce_window"(%[[T0]], %[[T4]]) ({ +// CHECK: %[[T4:.*]] = stablehlo.constant dense<0.000000e+00> : tensor +// CHECK: %[[T5:.*]] = "stablehlo.reduce_window"(%[[T0]], %[[T4]]) ({ // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): -// CHECK: %[[T10:.*]] = mhlo.add %[[ARG1]], %[[ARG2]] : tensor -// CHECK: mhlo.return %[[T10]] : tensor +// CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor +// CHECK: stablehlo.return %[[T10]] : tensor // CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor, tensor) -> tensor -// CHECK: %[[T6:.*]] = mhlo.constant dense<9> : tensor -// CHECK: %[[T7:.*]] = mhlo.convert %[[T6]] : (tensor) -> tensor +// CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor +// CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor) -> tensor // CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor, tensor) -> tensor // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[T9]] : !torch.vtensor<[?,?,?,?],f32> diff --git a/test/Conversion/TorchToMhlo/view_like.mlir b/test/Conversion/TorchToMhlo/view_like.mlir index 70f3570d800a..8a6ec8d7266a 100644 --- a/test/Conversion/TorchToMhlo/view_like.mlir +++ b/test/Conversion/TorchToMhlo/view_like.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-to-mhlo -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func.func @torch.aten.slice.strided$slice_like( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -42,7 +42,7 @@ // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T23]] : !torch.vtensor<[?,?,?],f32> func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -97,7 +97,7 @@ func.func @torch.aten.slice.strided$slice_like(%arg0: !torch.vtensor<[?,?,?],f32 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T10]], %[[C0_I64_5]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T21]], %[[T18]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[T2]], %[[C1_I64]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32> +// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<2x65x256xf32> // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<2x65x256xf32> -> !torch.vtensor<[2,65,256],f32> // CHECK: return %[[T23]] : !torch.vtensor<[2,65,256],f32> func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[2,65,256],f32> { @@ -152,7 +152,7 @@ func.func @torch.aten.slice.strided.static$slice_like(%arg0: !torch.vtensor<[4,6 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor -> !torch.vtensor<[?,1,?],f32> // CHECK: return %[[T23]] : !torch.vtensor<[?,1,?],f32> func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,1,?],f32> { @@ -207,7 +207,7 @@ func.func @torch.aten.slice.last$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_5]], %[[T10]], %[[C0_I64_5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[T17]], %[[T21]], %[[T19]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_7:.*]] = tensor.from_elements %[[C1_I64]], %[[T2]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T22:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32> +// CHECK: %[[T22:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_6, %[[FROM_ELEMENTS]]_7 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x1x256xf32> // CHECK: %[[T23:.*]] = torch_c.from_builtin_tensor %[[T22]] : tensor<4x1x256xf32> -> !torch.vtensor<[4,1,256],f32> // CHECK: return %[[T23]] : !torch.vtensor<[4,1,256],f32> func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,1,256],f32> { @@ -247,7 +247,7 @@ func.func @torch.aten.slice.last.static$slice_like(%arg0: !torch.vtensor<[4,65,2 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T8:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor +// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: return %[[T9]] : !torch.vtensor<[?,?,?],f32> func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) -> !torch.vtensor<[?,?,?],f32> { @@ -287,7 +287,7 @@ func.func @torch.aten.slice.none$slice_like(%arg0: !torch.vtensor<[?,?,?],f32>) // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C0_I64_4]], %[[C0_I64]], %[[C0_I64_4]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_5:.*]] = tensor.from_elements %[[T3]], %[[T7]], %[[T5]] : tensor<3xi64> // CHECK: %[[FROM_ELEMENTS_6:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[C1_I64]] : tensor<3xi64> -// CHECK: %[[T8:.*]] = mhlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32> +// CHECK: %[[T8:.*]] = stablehlo.real_dynamic_slice %[[T0]], %[[FROM_ELEMENTS]], %[[FROM_ELEMENTS]]_5, %[[FROM_ELEMENTS]]_6 : (tensor<4x65x256xf32>, tensor<3xi64>, tensor<3xi64>, tensor<3xi64>) -> tensor<4x33x256xf32> // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor<4x33x256xf32> -> !torch.vtensor<[4,33,256],f32> // CHECK: return %[[T9]] : !torch.vtensor<[4,33,256],f32> func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,33,256],f32> { @@ -313,8 +313,8 @@ func.func @torch.aten.slice.none.static$slice_like(%arg0: !torch.vtensor<[4,65,2 // CHECK: %[[T5:.*]] = arith.muli %[[T4]], %[[T3]] : i64 // CHECK: %[[T6:.*]] = arith.index_cast %[[T5]] : i64 to index // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]] : tensor<2xi64> -// CHECK: %[[T7:.*]] = mhlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64> -// CHECK: %[[T8:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor, tensor<2xi64>) -> tensor +// CHECK: %[[T7:.*]] = stablehlo.compute_reshape_shape %[[T6]], %[[FROM_ELEMENTS]] : (index, tensor<2xi64>) -> tensor<2xi64> +// CHECK: %[[T8:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T7]] : (tensor, tensor<2xi64>) -> tensor // CHECK: %[[T9:.*]] = torch_c.from_builtin_tensor %[[T8]] : tensor -> !torch.vtensor<[?,224],f32> // CHECK: return %[[T9]] : !torch.vtensor<[?,224],f32> func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,224],f32> { @@ -346,8 +346,8 @@ func.func @torch.aten.view$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: %[[T9:.*]] = arith.muli %[[T8]], %[[T5]] : i64 // CHECK: %[[T10:.*]] = arith.index_cast %[[T9]] : i64 to index // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T2]], %[[T3]], %[[T4]], %[[T5]] : tensor<4xi64> -// CHECK: %[[T11:.*]] = mhlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64> -// CHECK: %[[T12:.*]] = mhlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[T11:.*]] = stablehlo.compute_reshape_shape %[[T10]], %[[FROM_ELEMENTS]] : (index, tensor<4xi64>) -> tensor<4xi64> +// CHECK: %[[T12:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[T11]] : (tensor, tensor<4xi64>) -> tensor // CHECK: %[[T13:.*]] = torch_c.from_builtin_tensor %[[T12]] : tensor -> !torch.vtensor<[?,120,4,64],f32> // CHECK: return %[[T13]] : !torch.vtensor<[?,120,4,64],f32> func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,120,4,64],f32> { @@ -367,7 +367,7 @@ func.func @torch.aten.reshape$basic(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> ! // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[],f32> -> tensor // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[T1:.*]] = torch.prim.ListConstruct %[[INT1]] : (!torch.int) -> !torch.list -// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor) -> tensor<1xf32> +// CHECK: %[[T2:.*]] = stablehlo.reshape %[[T0]] : (tensor) -> tensor<1xf32> // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor<1xf32> -> !torch.vtensor<[1],f32> // CHECK: return %[[T3]] : !torch.vtensor<[1],f32> func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[1],f32> { @@ -383,7 +383,7 @@ func.func @torch.aten.view$to_rank1(%arg0: !torch.vtensor<[],f32>) -> !torch.vte // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> { // CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[1],f32> -> tensor<1xf32> // CHECK: %[[T1:.*]] = torch.prim.ListConstruct : () -> !torch.list -// CHECK: %[[T2:.*]] = mhlo.reshape %[[T0]] : (tensor<1xf32>) -> tensor +// CHECK: %[[T2:.*]] = stablehlo.reshape %[[T0]] : (tensor<1xf32>) -> tensor // CHECK: %[[T3:.*]] = torch_c.from_builtin_tensor %[[T2]] : tensor -> !torch.vtensor<[],f32> // CHECK: return %[[T3]] : !torch.vtensor<[],f32> func.func @torch.aten.view$to_rank0(%arg0: !torch.vtensor<[1],f32>) -> !torch.vtensor<[],f32> { @@ -425,7 +425,7 @@ func.func @torch.aten.squeeze.dim$0$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32 // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64> -// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xi64>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,?,1,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,?,1,?],f32> func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,?,1,?],f32> { @@ -453,7 +453,7 @@ func.func @torch.aten.squeeze.dim$1(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> ! // CHECK: %[[DIM_2:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<4xi64> -// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xi64>) -> tensor +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<4xi64>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,1,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?],f32> func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32>) -> !torch.vtensor<[?,1,?,?],f32> { @@ -477,7 +477,7 @@ func.func @torch.aten.squeeze.dim$from_end(%arg0: !torch.vtensor<[?,1,?,1,?],f32 // CHECK: %[[DIM_1:.*]] = tensor.dim %[[T0]], %[[C4]] : tensor<2x1x2x1x2xf32> // CHECK: %[[T3:.*]] = arith.index_cast %[[DIM_1]] : index to i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]] : tensor<3xi64> -// CHECK: %[[T4:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32> +// CHECK: %[[T4:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor<2x1x2x1x2xf32>, tensor<3xi64>) -> tensor<2x2x2xf32> // CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor<2x2x2xf32> -> !torch.vtensor<[2,2,2],f32> // CHECK: return %[[T5]] : !torch.vtensor<[2,2,2],f32> func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> !torch.vtensor<[2,2,2],f32> { @@ -505,7 +505,7 @@ func.func @torch.aten.squeeze$static(%arg0: !torch.vtensor<[2,1,2,1,2],f32>) -> // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[C1_I64]], %[[T1]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64> -// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor<1x?x?x?x?xf32> +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor<1x?x?x?x?xf32> // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor<1x?x?x?x?xf32> -> !torch.vtensor<[1,?,?,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[1,?,?,?,?],f32> func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1,?,?,?,?],f32> { @@ -534,7 +534,7 @@ func.func @torch.aten.unsqueeze$dim$0(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[C1_I64]], %[[T2]], %[[T3]], %[[T4]] : tensor<5xi64> -// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,1,?,?,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,1,?,?,?],f32> func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,1,?,?,?],f32> { @@ -563,7 +563,7 @@ func.func @torch.aten.unsqueeze$dim$1(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // CHECK: %[[T4:.*]] = arith.index_cast %[[DIM_2]] : index to i64 // CHECK: %[[C1_I64:.*]] = arith.constant 1 : i64 // CHECK: %[[FROM_ELEMENTS:.*]] = tensor.from_elements %[[T1]], %[[T2]], %[[T3]], %[[C1_I64]], %[[T4]] : tensor<5xi64> -// CHECK: %[[T5:.*]] = mhlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor +// CHECK: %[[T5:.*]] = stablehlo.dynamic_reshape %[[T0]], %[[FROM_ELEMENTS]] : (tensor, tensor<5xi64>) -> tensor // CHECK: %[[T6:.*]] = torch_c.from_builtin_tensor %[[T5]] : tensor -> !torch.vtensor<[?,?,?,1,?],f32> // CHECK: return %[[T6]] : !torch.vtensor<[?,?,?,1,?],f32> func.func @torch.aten.unsqueeze$from_end(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,1,?],f32> { diff --git a/test/lit.site.cfg.py.in b/test/lit.site.cfg.py.in index 339975f3e636..f7ac86747277 100644 --- a/test/lit.site.cfg.py.in +++ b/test/lit.site.cfg.py.in @@ -17,7 +17,7 @@ config.llvm_exe_ext = "@EXEEXT@" config.lit_tools_dir = "@LLVM_LIT_TOOLS_DIR@" config.python_executable = "@Python3_EXECUTABLE@" config.enable_jit_ir_importer = @TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@ -config.enable_mhlo = @TORCH_MLIR_ENABLE_MHLO@ +config.enable_stablehlo = @TORCH_MLIR_ENABLE_STABLEHLO@ import lit.llvm lit.llvm.initialize(lit_config, config) diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index 86b4060b8491..d284e27b42b7 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -268,7 +268,7 @@ gentbl_cc_library( ( [ "-gen-pass-decls", - "-DTORCH_MLIR_ENABLE_MHLO", + "-DTORCH_MLIR_ENABLE_STABLEHLO", ], "include/torch-mlir/Conversion/Passes.h.inc", ), @@ -434,13 +434,13 @@ cc_library( ) cc_library( - name = "TorchMLIRTorchToMhlo", + name = "TorchMLIRTorchToStablehlo", srcs = glob([ "lib/Conversion/*.h", - "lib/Conversion/TorchToMhlo/*.h", - "lib/Conversion/TorchToMhlo/*.cpp", + "lib/Conversion/TorchToStablehlo/*.h", + "lib/Conversion/TorchToStablehlo/*.cpp", ]), - hdrs = glob(["include/torch-mlir/Conversion/TorchToMhlo/*.h"]), + hdrs = glob(["include/torch-mlir/Conversion/TorchToStablehlo/*.h"]), strip_include_prefix = "include", deps = [ ":TorchMLIRConversionPassesIncGen", @@ -465,8 +465,8 @@ cc_library( ":TorchMLIRTorchConversionToMLProgram", ":TorchMLIRTorchToArith", ":TorchMLIRTorchToLinalg", - ":TorchMLIRTorchToMhlo", ":TorchMLIRTorchToSCF", + ":TorchMLIRTorchToStablehlo", ":TorchMLIRTorchToTMTensor", ":TorchMLIRTorchToTosa", ], @@ -489,8 +489,8 @@ cc_library( ":TorchMLIRTorchPasses", ":TorchMLIRTorchToArith", ":TorchMLIRTorchToLinalg", - ":TorchMLIRTorchToMhlo", ":TorchMLIRTorchToSCF", + ":TorchMLIRTorchToStablehlo", ":TorchMLIRTorchToTMTensor", ":TorchMLIRTorchToTosa", "@llvm-project//mlir:ConversionPasses", diff --git a/utils/bazel/torch-mlir-overlay/test/BUILD.bazel b/utils/bazel/torch-mlir-overlay/test/BUILD.bazel index 2db2a775116c..d29391305cc0 100644 --- a/utils/bazel/torch-mlir-overlay/test/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/test/BUILD.bazel @@ -23,7 +23,7 @@ expand_template( # All disabled, but required to substituted because they are not in quotes. "@MLIR_ENABLE_BINDINGS_PYTHON@": "0", "@TORCH_MLIR_ENABLE_JIT_IR_IMPORTER@": "0", - "@TORCH_MLIR_ENABLE_MHLO@": "0", + "@TORCH_MLIR_ENABLE_STABLEHLO@": "0", }, template = "lit.site.cfg.py.in", )