From 355720c3ed557be470a79d56fdaefeb1205252d8 Mon Sep 17 00:00:00 2001 From: Tanyo Kwok Date: Tue, 23 Aug 2022 15:04:57 +0800 Subject: [PATCH] [MHLO] Init end to end unit tests --- .github/workflows/buildAndTest.yml | 6 + e2e_testing/torchscript/main.py | 17 ++- e2e_testing/torchscript/xfail_sets.py | 132 ++++++++++++++++++ lib/Conversion/Passes.cpp | 12 +- lib/Conversion/TorchToMhlo/Basic.cpp | 36 ++++- lib/Conversion/TorchToMhlo/CMakeLists.txt | 3 + lib/Conversion/TorchToMhlo/Reduction.cpp | 3 + lib/Conversion/TorchToMhlo/ViewLike.cpp | 18 +++ python/torch_mlir/__init__.py | 2 - .../mhlo_backends/__init__.py | 0 .../torch_mlir_e2e_test/mhlo_backends/abc.py | 49 +++++++ .../mhlo_backends/linalg_on_tensors.py | 45 ++++++ .../torchscript/configs/__init__.py | 1 + .../torchscript/configs/mhlo_backend.py | 50 +++++++ 14 files changed, 367 insertions(+), 7 deletions(-) create mode 100644 python/torch_mlir_e2e_test/mhlo_backends/__init__.py create mode 100644 python/torch_mlir_e2e_test/mhlo_backends/abc.py create mode 100644 python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py create mode 100644 python/torch_mlir_e2e_test/torchscript/configs/mhlo_backend.py diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index 94cdc64186f7b..1af6cdc0fef1d 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -167,6 +167,12 @@ jobs: export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" python -m e2e_testing.torchscript.main --config=eager_mode -v + - name: Run mhlo e2e integration tests + if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }} + run: | + export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir" + python -m e2e_testing.torchscript.main --config=mhlo -v + - name: Run tosa e2e integration tests if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }} run: | diff --git a/e2e_testing/torchscript/main.py b/e2e_testing/torchscript/main.py index 82331de84dd1f..590d1e5c090e1 100644 --- a/e2e_testing/torchscript/main.py +++ b/e2e_testing/torchscript/main.py @@ -15,20 +15,27 @@ # Available test configs. from torch_mlir_e2e_test.torchscript.configs import ( - LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig + LazyTensorCoreTestConfig, + LinalgOnTensorsBackendTestConfig, + MhloBackendTestConfig, + NativeTorchTestConfig, + TorchScriptTestConfig, + TosaBackendTestConfig, + EagerModeTestConfig ) from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend +from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend -from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET +from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET # Import tests to register them in the global registry. from torch_mlir_e2e_test.test_suite import register_all_tests register_all_tests() def _get_argparse(): - config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode', 'lazy_tensor_core'] + config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core'] parser = argparse.ArgumentParser(description='Run torchscript e2e tests.') parser.add_argument('-c', '--config', choices=config_choices, @@ -36,6 +43,7 @@ def _get_argparse(): help=f''' Meaning of options: "refbackend": run through torch-mlir's RefBackend. +"mhlo": run through torch-mlir's default MHLO backend. "tosa": run through torch-mlir's default TOSA backend. "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). @@ -78,6 +86,9 @@ def main(): if args.config == 'tosa': config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) xfail_set = all_test_unique_names - TOSA_PASS_SET + if args.config == 'mhlo': + config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend()) + xfail_set = all_test_unique_names - MHLO_PASS_SET elif args.config == 'native_torch': config = NativeTorchTestConfig() xfail_set = {} diff --git a/e2e_testing/torchscript/xfail_sets.py b/e2e_testing/torchscript/xfail_sets.py index 7d495603d5351..2a58e0d121d6f 100644 --- a/e2e_testing/torchscript/xfail_sets.py +++ b/e2e_testing/torchscript/xfail_sets.py @@ -21,6 +21,138 @@ "Matmul_vecmat" } +MHLO_PASS_SET = { + "FlattenStaticModule_basic", + "FlattenRank0Module_basic", + "TensorsConcatNegativeDimModule_basic", + "NumelModule_basic", + "ReduceSumDimIntListEmptyDimModule_basic", + "SqueezeModule_allUnitDim", + "SqueezeDimModule_unitDim", + "MeanModule_basic", + "MeanDynamicSizesModule_basic", + "MeanDimEmptyDimModule_basic", + "NumToTensorFloatModule_basic", + "AtenToDeviceModule_basic", + "AvgPool2dStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Convolution2DStaticModule_basic", + "ElementwiseCloneContiguousModule_basic", + "ElementwiseCloneModule_basic", + "ElementwiseBinaryStaticShapeModule_basic", + "ReturnThreeTensorFloat32_basic", + "BoolTensorReturnFalseModule_basic", + "BoolTensorReturnTrueModule_basic", + "BoolTensorReturnMixedModule_basic", + "SqueezeModule_static", + "TModuleRank1_basic", + "TModuleRank0_basic", + "ElementwiseToDtypeIdentityModule_basic", + "View1DFoldModule_basic", + "UnsafeView1DFoldModule_basic", + "SqueezeDimModule_static", + "SqueezeDimModule_identity", + "SliceModule_basic", + "SliceNegIdxModule_basic", + "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + "SliceStartEqEndModule_basic", + "SliceSizeTwoStepModule_basic", + "SliceWholeTensorModule_basic", + "ReturnTwoTensorF32I64_basic", + "Matmul4dStatic_basic", + "Matmul_dot", + "Matmul_2d", + "Matmul_matvec", + "Matmul_vecmat", + "MaxPool2dWithIndicesStaticModule_basic", + "MmDagModule_basic", + "MmModule_basic", + "MmModule_chained", + "MaxPool2dStaticModule_basic", + "PermuteModule_basic", + "PermuteNegativeIndexModule_basic", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleFalsePinMemory_basic", + "OnesModuleDefaultDtype_basic", + "OnesModuleInt_basic", + "OnesModuleFloat_basic", + "OnesModuleFalsePinMemory_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewOnesModuleDefaultDtype_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewOnesModuleFloat2D_basic", + "NewOnesModuleFloat3D_basic", + "NewOnesModuleFalsePinMemory_basic", + "DropoutEvalIntModule_basic", + "DropoutEvalFloatModule_basic", + "ContiguousModule_basic", + "DropoutModule_basic", + "ViewCollapseModule_basic", + "ViewCollapseInferredDimModule_basic", + "ViewDynamicExpandCollapseModule_basic", + "ViewDynamicExpandModule_basic", + "ViewExpandModule_basic", + "ViewExpandOnesModule_basic", + "ViewExpandOnesBeforeAndAfterModule_basic", + "ViewExpandOnesMiddleModule_basic", + "ViewExpandCollapseModule_basic", + "ViewExpandCollapseWithOnesModule_basic", + "ViewExpandInferredDimModule_basic", + "ViewNoChangeStaticModule_basic", + "ViewNoChange1dModule_basic", + "ViewNoChange2dModule_basic", + "ViewNoChange3dModule_basic", + "UnsafeViewExpandModule_basic", + "ReduceMaxAllDims_basic", + "ReduceMaxFloatModule_basic", + "ReduceMaxSignedIntModule_basic", + "ReduceMaxUnsignedIntModule_basic", + "ReduceSumDimIntListFloatModule_basic", + "ReduceSumDimIntListIntModule_basic", + "ReduceSumFloatModule_basic", + "ReduceSumSignedIntModule_basic", + "ReduceSumUnsignedIntModule_basic", + "RepeatModule_basic", + "ReshapeAliasCollapseModule_basic", + "ReshapeAliasExpandModule_basic", + "ReshapeExpandModule_basic", + "TestMultipleTensorReturn_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", + "BaddbmmStaticModule_basic", + "BaddbmmBroadcast1DInputModule_basic", + "BaddbmmBroadcast2DInputModule_basic", + "NarrowHorizontalTest2_basic", + "NarrowHorizontalTest_basic", + "NarrowVerticalTest2_basic", + "NarrowVerticalTest_basic", + "NumToTensorIntModule_basic", + "NumpyTRank0Module_basic", + "NumpyTRank1Module_basic", + "NumpyTRank2Module_basic", + "NumpyTRankNStaticModule_basic", + "NumpyTRankNDynamicModule_basic", + "TModuleRank2_basic", + "TensorLiteralModule_basic", + "TensorsConcatModule_basic", + "TensorOpaqueLiteralModule_basic", + "TransposeIntModule_basic", + "TransposeIntNegDimsModule_basic", + "Permute0RankModule_basic", + "UnsafeViewCollapseModule_basic", + "UnsafeViewDynamicExpandModule_basic", +} + # Write the TOSA set as a "passing" set as it is very early in development # and very few tests work yet. TOSA_PASS_SET = { diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 014d6998e0d08..98f1acb75e054 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -9,6 +9,9 @@ #include "torch-mlir/Conversion/Passes.h" +#ifdef TORCH_MLIR_ENABLE_MHLO +#include "mlir-hlo/Dialect/mhlo/transforms/passes.h" +#endif // TORCH_MLIR_ENABLE_MHLO #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" @@ -25,4 +28,11 @@ namespace { #include "torch-mlir/Conversion/Passes.h.inc" } // end namespace -void mlir::torch::registerConversionPasses() { ::registerPasses(); } +void mlir::torch::registerConversionPasses() { + ::registerPasses(); +#ifdef TORCH_MLIR_ENABLE_MHLO + ::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> { + return mlir::mhlo::createLegalizeHloToLinalgPass(); + }); +#endif // TORCH_MLIR_ENABLE_MHLO +} diff --git a/lib/Conversion/TorchToMhlo/Basic.cpp b/lib/Conversion/TorchToMhlo/Basic.cpp index 2e65d3e090d37..bf3e073c79afb 100644 --- a/lib/Conversion/TorchToMhlo/Basic.cpp +++ b/lib/Conversion/TorchToMhlo/Basic.cpp @@ -977,8 +977,41 @@ LogicalResult ConvertAtenOp::matchAndRewrite( v = mhlo::promoteType(rewriter, v, outType); } + size_t posDim = toPositiveDim(dim, outType.getRank()); rewriter.replaceOpWithNewOp( - op, ValueRange(builtinTensors), static_cast(dim)); + op, ValueRange(builtinTensors), posDim); + return success(); +} +} // namespace + +// AtenNumelOp +namespace { +template <> +LogicalResult ConvertAtenOp::matchAndRewrite( + AtenNumelOp op, + OpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const { + auto self = adaptor.self(); + auto selfTy = self.getType().dyn_cast(); + size_t rank = selfTy.getRank(); + + Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits); + auto loc = op->getLoc(); + Value numel = + rewriter.create(loc, rewriter.getIntegerAttr(intType, 1)); + for (size_t d = 0 ; d < rank; ++ d) { + Value dimSize = rewriter.create( + loc, intType, rewriter.create(loc, self, d)); + numel = rewriter.create(loc, numel, dimSize); + } + + auto outTy = getTypeConverter()->convertType(op.getType()); + if (outTy != numel.getType()) { + rewriter.replaceOpWithNewOp( + op, outTy, numel); + } else { + rewriter.replaceOp(op, numel); + } return success(); } } // namespace @@ -1067,5 +1100,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality( INSERT_ATENOP_PATTERN(AtenBatchNormOp); INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp); + INSERT_ATENOP_PATTERN(AtenNumelOp); #undef INSERT_ATENOP_PATTERN } diff --git a/lib/Conversion/TorchToMhlo/CMakeLists.txt b/lib/Conversion/TorchToMhlo/CMakeLists.txt index 3e9c3c41bf0d9..b3468cb7be4d2 100644 --- a/lib/Conversion/TorchToMhlo/CMakeLists.txt +++ b/lib/Conversion/TorchToMhlo/CMakeLists.txt @@ -14,6 +14,8 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo DEPENDS MhloDialect ChloDialect + MhloToLinalg + MLIRMhloPassIncGen TorchMLIRConversionPassIncGen LINK_COMPONENTS @@ -24,6 +26,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo MLIRPass MhloDialect ChloDialect + MhloToLinalg TorchMLIRTorchDialect ) diff --git a/lib/Conversion/TorchToMhlo/Reduction.cpp b/lib/Conversion/TorchToMhlo/Reduction.cpp index 65927cc4c2036..2f99e79f73c22 100644 --- a/lib/Conversion/TorchToMhlo/Reduction.cpp +++ b/lib/Conversion/TorchToMhlo/Reduction.cpp @@ -490,6 +490,9 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( if (!matchPattern(op.dim(), m_TorchConstantIntList(inputDims))) { return rewriter.notifyMatchFailure(op, "non-int dim list unsupported"); } + if (inputDims.size() == 0) { + inputDims = llvm::to_vector<4>(llvm::seq(0, inputTy.getRank())); + } for (auto d : inputDims) { d = toPositiveDim(d, inputTy.getRank()); diff --git a/lib/Conversion/TorchToMhlo/ViewLike.cpp b/lib/Conversion/TorchToMhlo/ViewLike.cpp index 0502694cd1d64..b6cf840be699e 100644 --- a/lib/Conversion/TorchToMhlo/ViewLike.cpp +++ b/lib/Conversion/TorchToMhlo/ViewLike.cpp @@ -256,6 +256,14 @@ class ConvertAtenViewOp : public OpConversionPattern { numel = rewriter.create(loc, rewriter.getIndexType(), numel); + if (dimSizes.size() == 0) { + rewriter.replaceOpWithNewOp( + op, + OpConversionPattern::getTypeConverter()->convertType( + op.getType()), + adaptor.self()); + return success(); + } Value mhloShape = rewriter.create(loc, dimSizes); Value computedShape = rewriter.create( loc, mhloShape.getType(), numel, mhloShape); @@ -310,6 +318,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (dSize != 1) dims.push_back(r); } + if (dims.size() == 0) { + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), self); + return success(); + } auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims); if (failed(newDimSizesInfo)) @@ -354,6 +367,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector dims(rank); std::iota(dims.begin(), dims.end(), 0); dims.erase(dims.begin() + dim); + if (dims.size() == 0) { + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getType()), self); + return success(); + } auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims); if (failed(newDimSizesInfo)) return rewriter.notifyMatchFailure( diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py index 2a716d3927cea..a03fc2082e9f6 100644 --- a/python/torch_mlir/__init__.py +++ b/python/torch_mlir/__init__.py @@ -202,7 +202,6 @@ def compile(model: torch.nn.Module, scripted = torch.jit.trace(model, tuple(example_args_for_trace)) else: scripted = torch.jit.script(model) - # Convert all concrete inputs to TensorPlaceholder's, for consistency. arg_placeholders = [] for arg in example_args: @@ -240,7 +239,6 @@ def compile(model: torch.nn.Module, """) from None finally: sys.stderr = original_stderr - if output_type == OutputType.RAW: return mb.module diff --git a/python/torch_mlir_e2e_test/mhlo_backends/__init__.py b/python/torch_mlir_e2e_test/mhlo_backends/__init__.py new file mode 100644 index 0000000000000..e69de29bb2d1d diff --git a/python/torch_mlir_e2e_test/mhlo_backends/abc.py b/python/torch_mlir_e2e_test/mhlo_backends/abc.py new file mode 100644 index 0000000000000..8fc51ac00f7ae --- /dev/null +++ b/python/torch_mlir_e2e_test/mhlo_backends/abc.py @@ -0,0 +1,49 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import abc +from typing import TypeVar + +import torch + +from torch_mlir.ir import Module + +# A type shared between the result of `MhloBackend.compile` and the +# input to `MhloBackend.load`. Each backend will likely have a +# different definition of this type. +CompiledArtifact = TypeVar('CompiledArtifact') + +# A wrapper around a backend-specific loaded program representation +# that uniformly translates the `x.method(...)` interface expected of +# Torch modules into appropriate lower-level operations. +Invoker = TypeVar('Invoker') + + +class MhloBackend(abc.ABC): + """The interface to an MHLO backend. + + Backends are recommended to raise meaningful exceptions in case of error, + ideally with easy reproduction instructions. + """ + @abc.abstractmethod + def compile(self, module: Module) -> CompiledArtifact: + """Compile the provided MLIR module into a compiled artifact. + + The module adheres to the MHLO backend contract + (see the VerifyMhloBackendContract pass). + + The compiled artifact can be any type, but must be correctly + interpreted by the `load` method. + """ + + @abc.abstractmethod + def load(self, artifact: CompiledArtifact) -> Invoker: + """Load the compiled artifact into a uniformly invokable form. + + The compiled artifact is the result of a previous call to `compile`. + + See the description of `Invoker` for the requirements on the returned + type. + """ diff --git a/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py b/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py new file mode 100644 index 0000000000000..25896e0a0043d --- /dev/null +++ b/python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py @@ -0,0 +1,45 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from torch_mlir.ir import * +from torch_mlir.passmanager import * +from torch_mlir.compiler_utils import run_pipeline_with_repro_report + +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend + +from .abc import MhloBackend + +__all__ = [ + "LinalgOnTensorsMhloBackend", +] + +class LinalgOnTensorsMhloBackend(MhloBackend): + """Main entry-point for the linalg-on-tensors based MHLO backend. + + This currently uses the linalg-on-tensors RefBackend for actual execution. + """ + def __init__(self): + super().__init__() + self.refbackend = RefBackendLinalgOnTensorsBackend() + + def compile(self, imported_module: Module): + """Compiles an imported module that satisfied the MHLO backend contract. + + Args: + imported_module: The MLIR module consisting of funcs in the MHLO + dialect. + Returns: + An opaque, backend specific compiled artifact object that can be + passed to `load`. + """ + run_pipeline_with_repro_report( + imported_module, + "func.func(hlo-legalize-to-linalg)", + "Lowering MLIR-HLO to Linalg-on-Tensors") + return self.refbackend.compile(imported_module) + + def load(self, module): + """Loads a compiled artifact into the runtime.""" + return self.refbackend.load(module) diff --git a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py b/python/torch_mlir_e2e_test/torchscript/configs/__init__.py index 63d9a733940c5..a7118c0eff986 100644 --- a/python/torch_mlir_e2e_test/torchscript/configs/__init__.py +++ b/python/torch_mlir_e2e_test/torchscript/configs/__init__.py @@ -7,5 +7,6 @@ from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig from .native_torch import NativeTorchTestConfig from .torchscript import TorchScriptTestConfig +from .mhlo_backend import MhloBackendTestConfig from .tosa_backend import TosaBackendTestConfig from .eager_mode import EagerModeTestConfig diff --git a/python/torch_mlir_e2e_test/torchscript/configs/mhlo_backend.py b/python/torch_mlir_e2e_test/torchscript/configs/mhlo_backend.py new file mode 100644 index 0000000000000..112343659348b --- /dev/null +++ b/python/torch_mlir_e2e_test/torchscript/configs/mhlo_backend.py @@ -0,0 +1,50 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from typing import Any + +import torch +import torch_mlir + +from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend +from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem +from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders +from .utils import ( + recursively_convert_to_numpy, + recursively_convert_from_numpy, +) + + +class MhloBackendTestConfig(TestConfig): + """Base class for TestConfig's that are implemented with linalg-on-tensors. + + This class handles all the common lowering that torch-mlir does before + reaching the linalg-on-tensors abstraction level. + """ + def __init__(self, backend: MhloBackend): + super().__init__() + self.backend = backend + + def compile(self, program: torch.nn.Module) -> Any: + example_args = convert_annotations_to_placeholders(program.forward) + module = torch_mlir.compile( + program, example_args, output_type="mhlo") + + return self.backend.compile(module) + + + + def run(self, artifact: Any, trace: Trace) -> Trace: + backend_module = self.backend.load(artifact) + result: Trace = [] + for item in trace: + numpy_inputs = recursively_convert_to_numpy(item.inputs) + outputs = getattr(backend_module, item.symbol)(*numpy_inputs) + output = recursively_convert_from_numpy(outputs) + result.append( + TraceItem(symbol=item.symbol, + inputs=item.inputs, + output=output)) + return result