diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py index 82331de84dd1f..590d1e5c090e1 100644 --- a/e2e_testing/torchscript/main.py +++ b/e2e_testing/torchscript/main.py @@ -15,20 +15,27 @@ # Available test configs. from torch_mlir_e2e_test.torchscript.configs import ( - LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig + LazyTensorCoreTestConfig, + LinalgOnTensorsBackendTestConfig, + MhloBackendTestConfig, + NativeTorchTestConfig, + TorchScriptTestConfig, + TosaBackendTestConfig, + EagerModeTestConfig ) from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend +from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend -from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET +from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET # Import tests to register them in the global registry. from torch_mlir_e2e_test.test_suite import register_all_tests register_all_tests() def _get_argparse(): - config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode', 'lazy_tensor_core'] + config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core'] parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') parser.add_argument('-c', '--config', choices=config_choices, @@ -36,6 +43,7 @@ def _get_argparse(): help=f''' Meaning of options: "refbackend": run through torch-mlir's RefBackend. +"mhlo": run through torch-mlir's default MHLO backend. "tosa": run through torch-mlir's default TOSA backend. "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). @@ -78,6 +86,9 @@ def main(): if args.config == 'tosa': config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) xfail_set = all_test_unique_names - TOSA_PASS_SET + if args.config == 'mhlo': + config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend()) + xfail_set = all_test_unique_names - MHLO_PASS_SET elif args.config == 'native_torch': config = NativeTorchTestConfig() xfail_set = {} diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index f26b061ca6c0f..c20d137903e06 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -21,6 +21,121 @@ "Matmul_vecmat" } +MHLO_PASS_SET = { + "AvgPool2dStaticModule_basic", + "ElementwiseCloneContiguousModule_basic", + "ElementwiseCloneModule_basic", + "ElementwiseBinaryStaticShapeModule_basic", + "ReturnThreeTensorFloat32_basic", + "BoolTensorReturnFalseModule_basic", + "BoolTensorReturnTrueModule_basic", + "BoolTensorReturnMixedModule_basic", + "SqueezeModule_static", + "TModuleRank1_basic", + "TModuleRank0_basic", + "ElementwiseToDtypeIdentityModule_basic", + "View1DFoldModule_basic", + "UnsafeView1DFoldModule_basic", + "SqueezeDimModule_static", + "SqueezeDimModule_identity", + "SliceModule_basic", + "SliceNegIdxModule_basic", + "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceSizeTwoStepModule_basic", + "SliceWholeTensorModule_basic", + "ReturnTwoTensorF32I64_basic", + "Matmul4dStatic_basic", + "Matmul_dot", + "Matmul_2d", + "Matmul_matvec", + "Matmul_vecmat", + "MaxPool2dWithIndicesStaticModule_basic", + "MmDagModule_basic", + "MmModule_basic", + "MmModule_chained", + "MaxPool2dStaticModule_basic", + "PermuteModule_basic", + "PermuteNegativeIndexModule_basic", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleFalsePinMemory_basic", + "OnesModuleDefaultDtype_basic", + "OnesModuleInt_basic", + "OnesModuleFloat_basic", + "OnesModuleFalsePinMemory_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewOnesModuleDefaultDtype_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewOnesModuleFloat2D_basic", + "NewOnesModuleFloat3D_basic", + "NewOnesModuleFalsePinMemory_basic", + "DropoutEvalIntModule_basic", + "DropoutEvalFloatModule_basic", + "ContiguousModule_basic", + "DropoutModule_basic", + "ViewCollapseModule_basic", + "ViewCollapseInferredDimModule_basic", + "ViewDynamicExpandCollapseModule_basic", + "ViewDynamicExpandModule_basic", + "ViewExpandModule_basic", + "ViewExpandOnesModule_basic", + "ViewExpandOnesBeforeAndAfterModule_basic", + "ViewExpandOnesMiddleModule_basic", + "ViewExpandCollapseModule_basic", + "ViewExpandCollapseWithOnesModule_basic", + "ViewExpandInferredDimModule_basic", + "ViewNoChangeStaticModule_basic", + "ViewNoChange1dModule_basic", + "ViewNoChange2dModule_basic", + "ViewNoChange3dModule_basic", + "UnsafeViewExpandModule_basic", + "ReduceMaxAllDims_basic", + "ReduceMaxFloatModule_basic", + "ReduceMaxSignedIntModule_basic", + "ReduceMaxUnsignedIntModule_basic", + "ReduceSumDimIntListFloatModule_basic", + "ReduceSumDimIntListIntModule_basic", + "ReduceSumFloatModule_basic", + "ReduceSumSignedIntModule_basic", + "ReduceSumUnsignedIntModule_basic", + "RepeatModule_basic", + "ReshapeAliasCollapseModule_basic", + "ReshapeAliasExpandModule_basic", + "ReshapeExpandModule_basic", + "TestMultipleTensorReturn_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "BaddbmmStaticModule_basic", + "BaddbmmBroadcast1DInputModule_basic", + "BaddbmmBroadcast2DInputModule_basic", + "NarrowHorizontalTest2_basic", + "NarrowHorizontalTest_basic", + "NarrowVerticalTest2_basic", + "NarrowVerticalTest_basic", + "NumToTensorIntModule_basic", + "NumpyTRank0Module_basic", + "NumpyTRank1Module_basic", + "NumpyTRank2Module_basic", + "NumpyTRankNStaticModule_basic", + "NumpyTRankNDynamicModule_basic", + "TModuleRank2_basic", + "TensorLiteralModule_basic", + "TensorOpaqueLiteralModule_basic", + "TransposeIntModule_basic", + "TransposeIntNegDimsModule_basic", + "Permute0RankModule_basic", + "UnsafeViewCollapseModule_basic", + "UnsafeViewDynamicExpandModule_basic", +} + # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 014d6998e0d08..98f1acb75e054 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -9,6 +9,9 @@ #include "torch-mlir/Conversion/Passes.h" +#ifdef TORCH_MLIR_ENABLE_MHLO +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#endif // TORCH_MLIR_ENABLE_MHLO #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" @@ -25,4 +28,11 @@ namespace { #include "torch-mlir/Conversion/Passes.h.inc" } // end namespace -void mlir::torch::registerConversionPasses() { ::registerPasses(); } +void mlir::torch::registerConversionPasses() { + ::registerPasses(); +#ifdef TORCH_MLIR_ENABLE_MHLO + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return mlir::mhlo::createLegalizeHloToLinalgPass(); + }); +#endif // TORCH_MLIR_ENABLE_MHLO +} diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt index 3e9c3c41bf0d9..b3468cb7be4d2 100644 --- a/lib/Conversion/TorchToMhlo/CMakeLists.txt +++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt @@ -14,6 +14,8 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo DEPENDS MhloDialect ChloDialect + MhloToLinalg + MLIRMhloPassIncGen TorchMLIRConversionPassIncGen LINK_COMPONENTS @@ -24,6 +26,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo MLIRPass MhloDialect ChloDialect + MhloToLinalg TorchMLIRTorchDialect ) diff --git a/python/torch_mlir_e2e_test/mhlo_backends/__init__.py b/python/torch_mlir_e2e_test/mhlo_backends/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/python/torch_mlir_e2e_test/mhlo_backends/abc.py b/python/torch_mlir_e2e_test/mhlo_backends/abc.py new file mode 100644 index 0000000000000..8fc51ac00f7ae --- /dev/null +++ b/python/torch_mlir_e2e_test/mhlo_backends/abc.py @@ -0,0 +1,49 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import abc +from typing import TypeVar + +import torch + +from torch_mlir.ir import Module + +# A type shared between the result of `MhloBackend.compile` and the +# input to `MhloBackend.load`. Each backend will likely have a +# different definition of this type. +CompiledArtifact = TypeVar('CompiledArtifact') + +# A wrapper around a backend-specific loaded program representation +# that uniformly translates the `x.method(...)` interface expected of +# Torch modules into appropriate lower-level operations. +Invoker = TypeVar('Invoker') + + +class MhloBackend(abc.ABC): + """The interface to an MHLO backend. + + Backends are recommended to raise meaningful exceptions in case of error, + ideally with easy reproduction instructions. + """ + @abc.abstractmethod + def compile(self, module: Module) -> CompiledArtifact: + """Compile the provided MLIR module into a compiled artifact. + + The module adheres to the MHLO backend contract + (see the VerifyMhloBackendContract pass). + + The compiled artifact can be any type, but must be correctly + interpreted by the `load` method. + """ + + @abc.abstractmethod + def load(self, artifact: CompiledArtifact) -> Invoker: + """Load the compiled artifact into a uniformly invokable form. + + The compiled artifact is the result of a previous call to `compile`. + + See the description of `Invoker` for the requirements on the returned + type. + """ diff --git a/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py new file mode 100644 index 0000000000000..25896e0a0043d --- /dev/null +++ b/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py @@ -0,0 +1,45 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from torch_mlir.ir import * +from torch_mlir.passmanager import * +from torch_mlir.compiler_utils import run_pipeline_with_repro_report + +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend + +from .abc import MhloBackend + +__all__ = [ + "LinalgOnTensorsMhloBackend", +] + +class LinalgOnTensorsMhloBackend(MhloBackend): + """Main entry-point for the linalg-on-tensors based MHLO backend. + + This currently uses the linalg-on-tensors RefBackend for actual execution. + """ + def __init__(self): + super().__init__() + self.refbackend = RefBackendLinalgOnTensorsBackend() + + def compile(self, imported_module: Module): + """Compiles an imported module that satisfied the MHLO backend contract. + + Args: + imported_module: The MLIR module consisting of funcs in the MHLO + dialect. + Returns: + An opaque, backend specific compiled artifact object that can be + passed to `load`. + """ + run_pipeline_with_repro_report( + imported_module, + "func.func(hlo-legalize-to-linalg)", + "Lowering MLIR-HLO to Linalg-on-Tensors") + return self.refbackend.compile(imported_module) + + def load(self, module): + """Loads a compiled artifact into the runtime.""" + return self.refbackend.load(module) diff --git a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py b/python/torch_mlir_e2e_test/torchscript/configs/__init__.py index 63d9a733940c5..a7118c0eff986 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py +++ b/python/torch_mlir_e2e_test/torchscript/configs/__init__.py @@ -7,5 +7,6 @@ from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig from .native_torch import NativeTorchTestConfig from .torchscript import TorchScriptTestConfig +from .mhlo_backend import MhloBackendTestConfig from .tosa_backend import TosaBackendTestConfig from .eager_mode import EagerModeTestConfig diff --git a/python/torch_mlir_e2e_test/torchscript/configs/mhlo_backend.py b/python/torch_mlir_e2e_test/torchscript/configs/mhlo_backend.py new file mode 100644 index 0000000000000..5b739222b1ba0 --- /dev/null +++ b/python/torch_mlir_e2e_test/torchscript/configs/mhlo_backend.py @@ -0,0 +1,60 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import sys +from typing import Any +from io import StringIO +import os +import tempfile + +import numpy as np +import torch + +from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend +from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem +from torch_mlir.compiler_utils import run_pipeline_with_repro_report +from .utils import ( + recursively_convert_to_numpy, + recursively_convert_from_numpy, + convert_torchscript_module_to_torch_backend_contract_mlir, +) + + +class MhloBackendTestConfig(TestConfig): + """Base class for TestConfig's that are implemented with linalg-on-tensors. + + This class handles all the common lowering that torch-mlir does before + reaching the linalg-on-tensors abstraction level. + """ + def __init__(self, backend: MhloBackend): + super().__init__() + self.backend = backend + + def compile(self, program: torch.nn.Module) -> Any: + + module = convert_torchscript_module_to_torch_backend_contract_mlir( + program) + + run_pipeline_with_repro_report( + module, + "torch-backend-to-mhlo-backend-pipeline", + "Lower Torch Backend IR -> MHLO Backend IR") + + return self.backend.compile(module) + + + + def run(self, artifact: Any, trace: Trace) -> Trace: + backend_module = self.backend.load(artifact) + result: Trace = [] + for item in trace: + numpy_inputs = recursively_convert_to_numpy(item.inputs) + outputs = getattr(backend_module, item.symbol)(*numpy_inputs) + output = recursively_convert_from_numpy(outputs) + result.append( + TraceItem(symbol=item.symbol, + inputs=item.inputs, + output=output)) + return result