Skip to content

Commit

Permalink
[MHLO] Init end-to-end unit tests
Browse files Browse the repository at this point in the history
See RFC #999

Co-authored-by: Bairen Yi [email protected]
Co-authored-by: Jiawei Wu [email protected]
Co-authored-by: Tianyou Guo [email protected]
Co-authored-by: Xu Yan [email protected]
Co-authored-by: Ziheng Jiang [email protected]
  • Loading branch information
Tanyo Kwok committed Aug 14, 2022
1 parent 41aa562 commit 88a46e7
Show file tree
Hide file tree
Showing 9 changed files with 298 additions and 4 deletions.
17 changes: 14 additions & 3 deletions e2e_testing/torchscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,35 @@

# Available test configs.
from torch_mlir_e2e_test.torchscript.configs import (
LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
LazyTensorCoreTestConfig,
LinalgOnTensorsBackendTestConfig,
MhloBackendTestConfig,
NativeTorchTestConfig,
TorchScriptTestConfig,
TosaBackendTestConfig,
EagerModeTestConfig
)

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend

from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET
from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET

# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()

def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode', 'lazy_tensor_core']
config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('-c', '--config',
choices=config_choices,
default='refbackend',
help=f'''
Meaning of options:
"refbackend": run through torch-mlir's RefBackend.
"mhlo": run through torch-mlir's default MHLO backend.
"tosa": run through torch-mlir's default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
Expand Down Expand Up @@ -78,6 +86,9 @@ def main():
if args.config == 'tosa':
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
xfail_set = all_test_unique_names - TOSA_PASS_SET
if args.config == 'mhlo':
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
xfail_set = all_test_unique_names - MHLO_PASS_SET
elif args.config == 'native_torch':
config = NativeTorchTestConfig()
xfail_set = {}
Expand Down
115 changes: 115 additions & 0 deletions e2e_testing/torchscript/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,121 @@
"Matmul_vecmat"
}

MHLO_PASS_SET = {
"AvgPool2dStaticModule_basic",
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseBinaryStaticShapeModule_basic",
"ReturnThreeTensorFloat32_basic",
"BoolTensorReturnFalseModule_basic",
"BoolTensorReturnTrueModule_basic",
"BoolTensorReturnMixedModule_basic",
"SqueezeModule_static",
"TModuleRank1_basic",
"TModuleRank0_basic",
"ElementwiseToDtypeIdentityModule_basic",
"View1DFoldModule_basic",
"UnsafeView1DFoldModule_basic",
"SqueezeDimModule_static",
"SqueezeDimModule_identity",
"SliceModule_basic",
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceSizeTwoStepModule_basic",
"SliceWholeTensorModule_basic",
"ReturnTwoTensorF32I64_basic",
"Matmul4dStatic_basic",
"Matmul_dot",
"Matmul_2d",
"Matmul_matvec",
"Matmul_vecmat",
"MaxPool2dWithIndicesStaticModule_basic",
"MmDagModule_basic",
"MmModule_basic",
"MmModule_chained",
"MaxPool2dStaticModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"ZerosModuleDefaultDtype_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
"ZerosModuleFloat2D_basic",
"ZerosModuleFloat3D_basic",
"ZerosModuleFalsePinMemory_basic",
"OnesModuleDefaultDtype_basic",
"OnesModuleInt_basic",
"OnesModuleFloat_basic",
"OnesModuleFalsePinMemory_basic",
"NewZerosModuleDefaultDtype_basic",
"NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic",
"NewZerosModuleFloat2D_basic",
"NewZerosModuleFloat3D_basic",
"NewZerosModuleFalsePinMemory_basic",
"NewOnesModuleDefaultDtype_basic",
"NewOnesModuleInt2D_basic",
"NewOnesModuleInt3D_basic",
"NewOnesModuleFloat2D_basic",
"NewOnesModuleFloat3D_basic",
"NewOnesModuleFalsePinMemory_basic",
"DropoutEvalIntModule_basic",
"DropoutEvalFloatModule_basic",
"ContiguousModule_basic",
"DropoutModule_basic",
"ViewCollapseModule_basic",
"ViewCollapseInferredDimModule_basic",
"ViewDynamicExpandCollapseModule_basic",
"ViewDynamicExpandModule_basic",
"ViewExpandModule_basic",
"ViewExpandOnesModule_basic",
"ViewExpandOnesBeforeAndAfterModule_basic",
"ViewExpandOnesMiddleModule_basic",
"ViewExpandCollapseModule_basic",
"ViewExpandCollapseWithOnesModule_basic",
"ViewExpandInferredDimModule_basic",
"ViewNoChangeStaticModule_basic",
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
"UnsafeViewExpandModule_basic",
"ReduceMaxAllDims_basic",
"ReduceMaxFloatModule_basic",
"ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic",
"ReduceSumDimIntListFloatModule_basic",
"ReduceSumDimIntListIntModule_basic",
"ReduceSumFloatModule_basic",
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
"RepeatModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeExpandModule_basic",
"TestMultipleTensorReturn_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"BaddbmmStaticModule_basic",
"BaddbmmBroadcast1DInputModule_basic",
"BaddbmmBroadcast2DInputModule_basic",
"NarrowHorizontalTest2_basic",
"NarrowHorizontalTest_basic",
"NarrowVerticalTest2_basic",
"NarrowVerticalTest_basic",
"NumToTensorIntModule_basic",
"NumpyTRank0Module_basic",
"NumpyTRank1Module_basic",
"NumpyTRank2Module_basic",
"NumpyTRankNStaticModule_basic",
"NumpyTRankNDynamicModule_basic",
"TModuleRank2_basic",
"TensorLiteralModule_basic",
"TensorOpaqueLiteralModule_basic",
"TransposeIntModule_basic",
"TransposeIntNegDimsModule_basic",
"Permute0RankModule_basic",
"UnsafeViewCollapseModule_basic",
"UnsafeViewDynamicExpandModule_basic",
}

# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
Expand Down
12 changes: 11 additions & 1 deletion lib/Conversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

#include "torch-mlir/Conversion/Passes.h"

#ifdef TORCH_MLIR_ENABLE_MHLO
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#endif // TORCH_MLIR_ENABLE_MHLO
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
Expand All @@ -25,4 +28,11 @@ namespace {
#include "torch-mlir/Conversion/Passes.h.inc"
} // end namespace

void mlir::torch::registerConversionPasses() { ::registerPasses(); }
void mlir::torch::registerConversionPasses() {
::registerPasses();
#ifdef TORCH_MLIR_ENABLE_MHLO
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::mhlo::createLegalizeHloToLinalgPass();
});
#endif // TORCH_MLIR_ENABLE_MHLO
}
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToMhlo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
DEPENDS
MhloDialect
ChloDialect
MhloToLinalg
MLIRMhloPassIncGen
TorchMLIRConversionPassIncGen

LINK_COMPONENTS
Expand All @@ -24,6 +26,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
MLIRPass
MhloDialect
ChloDialect
MhloToLinalg
TorchMLIRTorchDialect
)

Expand Down
Empty file.
49 changes: 49 additions & 0 deletions python/torch_mlir_e2e_test/mhlo_backends/abc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import abc
from typing import TypeVar

import torch

from torch_mlir.ir import Module

# A type shared between the result of `MhloBackend.compile` and the
# input to `MhloBackend.load`. Each backend will likely have a
# different definition of this type.
CompiledArtifact = TypeVar('CompiledArtifact')

# A wrapper around a backend-specific loaded program representation
# that uniformly translates the `x.method(...)` interface expected of
# Torch modules into appropriate lower-level operations.
Invoker = TypeVar('Invoker')


class MhloBackend(abc.ABC):
"""The interface to an MHLO backend.
Backends are recommended to raise meaningful exceptions in case of error,
ideally with easy reproduction instructions.
"""
@abc.abstractmethod
def compile(self, module: Module) -> CompiledArtifact:
"""Compile the provided MLIR module into a compiled artifact.
The module adheres to the MHLO backend contract
(see the VerifyMhloBackendContract pass).
The compiled artifact can be any type, but must be correctly
interpreted by the `load` method.
"""

@abc.abstractmethod
def load(self, artifact: CompiledArtifact) -> Invoker:
"""Load the compiled artifact into a uniformly invokable form.
The compiled artifact is the result of a previous call to `compile`.
See the description of `Invoker` for the requirements on the returned
type.
"""
45 changes: 45 additions & 0 deletions python/torch_mlir_e2e_test/mhlo_backends/linalg_on_tensors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from torch_mlir.ir import *
from torch_mlir.passmanager import *
from torch_mlir.compiler_utils import run_pipeline_with_repro_report

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend

from .abc import MhloBackend

__all__ = [
"LinalgOnTensorsMhloBackend",
]

class LinalgOnTensorsMhloBackend(MhloBackend):
"""Main entry-point for the linalg-on-tensors based MHLO backend.
This currently uses the linalg-on-tensors RefBackend for actual execution.
"""
def __init__(self):
super().__init__()
self.refbackend = RefBackendLinalgOnTensorsBackend()

def compile(self, imported_module: Module):
"""Compiles an imported module that satisfied the MHLO backend contract.
Args:
imported_module: The MLIR module consisting of funcs in the MHLO
dialect.
Returns:
An opaque, backend specific compiled artifact object that can be
passed to `load`.
"""
run_pipeline_with_repro_report(
imported_module,
"func.func(hlo-legalize-to-linalg)",
"Lowering MLIR-HLO to Linalg-on-Tensors")
return self.refbackend.compile(imported_module)

def load(self, module):
"""Loads a compiled artifact into the runtime."""
return self.refbackend.load(module)
1 change: 1 addition & 0 deletions python/torch_mlir_e2e_test/torchscript/configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig
from .native_torch import NativeTorchTestConfig
from .torchscript import TorchScriptTestConfig
from .mhlo_backend import MhloBackendTestConfig
from .tosa_backend import TosaBackendTestConfig
from .eager_mode import EagerModeTestConfig
60 changes: 60 additions & 0 deletions python/torch_mlir_e2e_test/torchscript/configs/mhlo_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

import sys
from typing import Any
from io import StringIO
import os
import tempfile

import numpy as np
import torch

from torch_mlir_e2e_test.mhlo_backends.abc import MhloBackend
from torch_mlir_e2e_test.torchscript.framework import TestConfig, Trace, TraceItem
from torch_mlir.compiler_utils import run_pipeline_with_repro_report
from .utils import (
recursively_convert_to_numpy,
recursively_convert_from_numpy,
convert_torchscript_module_to_torch_backend_contract_mlir,
)


class MhloBackendTestConfig(TestConfig):
"""Base class for TestConfig's that are implemented with linalg-on-tensors.
This class handles all the common lowering that torch-mlir does before
reaching the linalg-on-tensors abstraction level.
"""
def __init__(self, backend: MhloBackend):
super().__init__()
self.backend = backend

def compile(self, program: torch.nn.Module) -> Any:

module = convert_torchscript_module_to_torch_backend_contract_mlir(
program)

run_pipeline_with_repro_report(
module,
"torch-backend-to-mhlo-backend-pipeline",
"Lower Torch Backend IR -> MHLO Backend IR")

return self.backend.compile(module)



def run(self, artifact: Any, trace: Trace) -> Trace:
backend_module = self.backend.load(artifact)
result: Trace = []
for item in trace:
numpy_inputs = recursively_convert_to_numpy(item.inputs)
outputs = getattr(backend_module, item.symbol)(*numpy_inputs)
output = recursively_convert_from_numpy(outputs)
result.append(
TraceItem(symbol=item.symbol,
inputs=item.inputs,
output=output))
return result

0 comments on commit 88a46e7

Please sign in to comment.