Skip to content

Commit

Permalink
[MHLO] Init end to end unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Tanyo Kwok committed Aug 23, 2022
1 parent 9176b5e commit 355720c
Show file tree
Hide file tree
Showing 14 changed files with 367 additions and 7 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/buildAndTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,12 @@ jobs:
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
python -m e2e_testing.torchscript.main --config=eager_mode -v
- name: Run mhlo e2e integration tests
if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
run: |
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
python -m e2e_testing.torchscript.main --config=mhlo -v
- name: Run tosa e2e integration tests
if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
run: |
Expand Down
17 changes: 14 additions & 3 deletions e2e_testing/torchscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,35 @@

# Available test configs.
from torch_mlir_e2e_test.torchscript.configs import (
LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
LazyTensorCoreTestConfig,
LinalgOnTensorsBackendTestConfig,
MhloBackendTestConfig,
NativeTorchTestConfig,
TorchScriptTestConfig,
TosaBackendTestConfig,
EagerModeTestConfig
)

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend

from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET
from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET

# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()

def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode', 'lazy_tensor_core']
config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('-c', '--config',
choices=config_choices,
default='refbackend',
help=f'''
Meaning of options:
"refbackend": run through torch-mlir's RefBackend.
"mhlo": run through torch-mlir's default MHLO backend.
"tosa": run through torch-mlir's default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
Expand Down Expand Up @@ -78,6 +86,9 @@ def main():
if args.config == 'tosa':
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
xfail_set = all_test_unique_names - TOSA_PASS_SET
if args.config == 'mhlo':
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
xfail_set = all_test_unique_names - MHLO_PASS_SET
elif args.config == 'native_torch':
config = NativeTorchTestConfig()
xfail_set = {}
Expand Down
132 changes: 132 additions & 0 deletions e2e_testing/torchscript/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,138 @@
"Matmul_vecmat"
}

MHLO_PASS_SET = {
"FlattenStaticModule_basic",
"FlattenRank0Module_basic",
"TensorsConcatNegativeDimModule_basic",
"NumelModule_basic",
"ReduceSumDimIntListEmptyDimModule_basic",
"SqueezeModule_allUnitDim",
"SqueezeDimModule_unitDim",
"MeanModule_basic",
"MeanDynamicSizesModule_basic",
"MeanDimEmptyDimModule_basic",
"NumToTensorFloatModule_basic",
"AtenToDeviceModule_basic",
"AvgPool2dStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Convolution2DStaticModule_basic",
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseBinaryStaticShapeModule_basic",
"ReturnThreeTensorFloat32_basic",
"BoolTensorReturnFalseModule_basic",
"BoolTensorReturnTrueModule_basic",
"BoolTensorReturnMixedModule_basic",
"SqueezeModule_static",
"TModuleRank1_basic",
"TModuleRank0_basic",
"ElementwiseToDtypeIdentityModule_basic",
"View1DFoldModule_basic",
"UnsafeView1DFoldModule_basic",
"SqueezeDimModule_static",
"SqueezeDimModule_identity",
"SliceModule_basic",
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
"SliceStartEqEndModule_basic",
"SliceSizeTwoStepModule_basic",
"SliceWholeTensorModule_basic",
"ReturnTwoTensorF32I64_basic",
"Matmul4dStatic_basic",
"Matmul_dot",
"Matmul_2d",
"Matmul_matvec",
"Matmul_vecmat",
"MaxPool2dWithIndicesStaticModule_basic",
"MmDagModule_basic",
"MmModule_basic",
"MmModule_chained",
"MaxPool2dStaticModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"ZerosModuleDefaultDtype_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
"ZerosModuleFloat2D_basic",
"ZerosModuleFloat3D_basic",
"ZerosModuleFalsePinMemory_basic",
"OnesModuleDefaultDtype_basic",
"OnesModuleInt_basic",
"OnesModuleFloat_basic",
"OnesModuleFalsePinMemory_basic",
"NewZerosModuleDefaultDtype_basic",
"NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic",
"NewZerosModuleFloat2D_basic",
"NewZerosModuleFloat3D_basic",
"NewZerosModuleFalsePinMemory_basic",
"NewOnesModuleDefaultDtype_basic",
"NewOnesModuleInt2D_basic",
"NewOnesModuleInt3D_basic",
"NewOnesModuleFloat2D_basic",
"NewOnesModuleFloat3D_basic",
"NewOnesModuleFalsePinMemory_basic",
"DropoutEvalIntModule_basic",
"DropoutEvalFloatModule_basic",
"ContiguousModule_basic",
"DropoutModule_basic",
"ViewCollapseModule_basic",
"ViewCollapseInferredDimModule_basic",
"ViewDynamicExpandCollapseModule_basic",
"ViewDynamicExpandModule_basic",
"ViewExpandModule_basic",
"ViewExpandOnesModule_basic",
"ViewExpandOnesBeforeAndAfterModule_basic",
"ViewExpandOnesMiddleModule_basic",
"ViewExpandCollapseModule_basic",
"ViewExpandCollapseWithOnesModule_basic",
"ViewExpandInferredDimModule_basic",
"ViewNoChangeStaticModule_basic",
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
"UnsafeViewExpandModule_basic",
"ReduceMaxAllDims_basic",
"ReduceMaxFloatModule_basic",
"ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic",
"ReduceSumDimIntListFloatModule_basic",
"ReduceSumDimIntListIntModule_basic",
"ReduceSumFloatModule_basic",
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
"RepeatModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeExpandModule_basic",
"TestMultipleTensorReturn_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"BaddbmmStaticModule_basic",
"BaddbmmBroadcast1DInputModule_basic",
"BaddbmmBroadcast2DInputModule_basic",
"NarrowHorizontalTest2_basic",
"NarrowHorizontalTest_basic",
"NarrowVerticalTest2_basic",
"NarrowVerticalTest_basic",
"NumToTensorIntModule_basic",
"NumpyTRank0Module_basic",
"NumpyTRank1Module_basic",
"NumpyTRank2Module_basic",
"NumpyTRankNStaticModule_basic",
"NumpyTRankNDynamicModule_basic",
"TModuleRank2_basic",
"TensorLiteralModule_basic",
"TensorsConcatModule_basic",
"TensorOpaqueLiteralModule_basic",
"TransposeIntModule_basic",
"TransposeIntNegDimsModule_basic",
"Permute0RankModule_basic",
"UnsafeViewCollapseModule_basic",
"UnsafeViewDynamicExpandModule_basic",
}

# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
Expand Down
12 changes: 11 additions & 1 deletion lib/Conversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

#include "torch-mlir/Conversion/Passes.h"

#ifdef TORCH_MLIR_ENABLE_MHLO
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#endif // TORCH_MLIR_ENABLE_MHLO
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
Expand All @@ -25,4 +28,11 @@ namespace {
#include "torch-mlir/Conversion/Passes.h.inc"
} // end namespace

void mlir::torch::registerConversionPasses() { ::registerPasses(); }
void mlir::torch::registerConversionPasses() {
::registerPasses();
#ifdef TORCH_MLIR_ENABLE_MHLO
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::mhlo::createLegalizeHloToLinalgPass();
});
#endif // TORCH_MLIR_ENABLE_MHLO
}
36 changes: 35 additions & 1 deletion lib/Conversion/TorchToMhlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,8 +977,41 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
v = mhlo::promoteType(rewriter, v, outType);
}

size_t posDim = toPositiveDim(dim, outType.getRank());
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
op, ValueRange(builtinTensors), static_cast<uint64_t>(dim));
op, ValueRange(builtinTensors), posDim);
return success();
}
} // namespace

// AtenNumelOp
namespace {
template <>
LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
AtenNumelOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
auto self = adaptor.self();
auto selfTy = self.getType().dyn_cast<RankedTensorType>();
size_t rank = selfTy.getRank();

Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
auto loc = op->getLoc();
Value numel =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(intType, 1));
for (size_t d = 0 ; d < rank; ++ d) {
Value dimSize = rewriter.create<arith::IndexCastOp>(
loc, intType, rewriter.create<tensor::DimOp>(loc, self, d));
numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
}

auto outTy = getTypeConverter()->convertType(op.getType());
if (outTy != numel.getType()) {
rewriter.replaceOpWithNewOp<arith::ExtSIOp>(
op, outTy, numel);
} else {
rewriter.replaceOp(op, numel);
}
return success();
}
} // namespace
Expand Down Expand Up @@ -1067,5 +1100,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(

INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
INSERT_ATENOP_PATTERN(AtenNumelOp);
#undef INSERT_ATENOP_PATTERN
}
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToMhlo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
DEPENDS
MhloDialect
ChloDialect
MhloToLinalg
MLIRMhloPassIncGen
TorchMLIRConversionPassIncGen

LINK_COMPONENTS
Expand All @@ -24,6 +26,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
MLIRPass
MhloDialect
ChloDialect
MhloToLinalg
TorchMLIRTorchDialect
)

Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToMhlo/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,9 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
if (!matchPattern(op.dim(), m_TorchConstantIntList(inputDims))) {
return rewriter.notifyMatchFailure(op, "non-int dim list unsupported");
}
if (inputDims.size() == 0) {
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
}

for (auto d : inputDims) {
d = toPositiveDim(d, inputTy.getRank());
Expand Down
18 changes: 18 additions & 0 deletions lib/Conversion/TorchToMhlo/ViewLike.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,14 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
numel);

if (dimSizes.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.self());
return success();
}
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
loc, mhloShape.getType(), numel, mhloShape);
Expand Down Expand Up @@ -310,6 +318,11 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
if (dSize != 1)
dims.push_back(r);
}
if (dims.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
}

auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
if (failed(newDimSizesInfo))
Expand Down Expand Up @@ -354,6 +367,11 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
SmallVector<int64_t, 4> dims(rank);
std::iota(dims.begin(), dims.end(), 0);
dims.erase(dims.begin() + dim);
if (dims.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
}
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure(
Expand Down
2 changes: 0 additions & 2 deletions python/torch_mlir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def compile(model: torch.nn.Module,
scripted = torch.jit.trace(model, tuple(example_args_for_trace))
else:
scripted = torch.jit.script(model)

# Convert all concrete inputs to TensorPlaceholder's, for consistency.
arg_placeholders = []
for arg in example_args:
Expand Down Expand Up @@ -240,7 +239,6 @@ def compile(model: torch.nn.Module,
""") from None
finally:
sys.stderr = original_stderr

if output_type == OutputType.RAW:
return mb.module

Expand Down
Empty file.
Loading

0 comments on commit 355720c

Please sign in to comment.