Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MHLO] Init end-to-end unit tests #1223

Merged
merged 1 commit into from
Aug 23, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .github/workflows/buildAndTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,12 @@ jobs:
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
python -m e2e_testing.torchscript.main --config=eager_mode -v
- name: Run mhlo e2e integration tests
if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
run: |
export PYTHONPATH="$GITHUB_WORKSPACE/build/tools/torch-mlir/python_packages/torch_mlir"
python -m e2e_testing.torchscript.main --config=mhlo -v
- name: Run tosa e2e integration tests
if: ${{ matrix.os-arch == 'ubuntu-x86_64' && matrix.llvm-build == 'in-tree' }}
run: |
Expand Down
17 changes: 14 additions & 3 deletions e2e_testing/torchscript/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,27 +15,35 @@

# Available test configs.
from torch_mlir_e2e_test.torchscript.configs import (
LazyTensorCoreTestConfig, LinalgOnTensorsBackendTestConfig, NativeTorchTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, EagerModeTestConfig
LazyTensorCoreTestConfig,
LinalgOnTensorsBackendTestConfig,
MhloBackendTestConfig,
NativeTorchTestConfig,
TorchScriptTestConfig,
TosaBackendTestConfig,
EagerModeTestConfig
)

from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend
from torch_mlir_e2e_test.mhlo_backends.linalg_on_tensors import LinalgOnTensorsMhloBackend
from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend

from .xfail_sets import REFBACKEND_XFAIL_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET
from .xfail_sets import REFBACKEND_XFAIL_SET, MHLO_PASS_SET, TOSA_PASS_SET, EAGER_MODE_XFAIL_SET, LTC_XFAIL_SET

# Import tests to register them in the global registry.
from torch_mlir_e2e_test.test_suite import register_all_tests
register_all_tests()

def _get_argparse():
config_choices = ['native_torch', 'torchscript', 'refbackend', 'tosa', 'eager_mode', 'lazy_tensor_core']
config_choices = ['native_torch', 'torchscript', 'refbackend', 'mhlo', 'tosa', 'eager_mode', 'lazy_tensor_core']
parser = argparse.ArgumentParser(description='Run torchscript e2e tests.')
parser.add_argument('-c', '--config',
choices=config_choices,
default='refbackend',
help=f'''
Meaning of options:
"refbackend": run through torch-mlir's RefBackend.
"mhlo": run through torch-mlir's default MHLO backend.
"tosa": run through torch-mlir's default TOSA backend.
"native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration).
"torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly).
Expand Down Expand Up @@ -78,6 +86,9 @@ def main():
if args.config == 'tosa':
config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend())
xfail_set = all_test_unique_names - TOSA_PASS_SET
if args.config == 'mhlo':
config = MhloBackendTestConfig(LinalgOnTensorsMhloBackend())
xfail_set = all_test_unique_names - MHLO_PASS_SET
elif args.config == 'native_torch':
config = NativeTorchTestConfig()
xfail_set = {}
Expand Down
133 changes: 133 additions & 0 deletions e2e_testing/torchscript/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,139 @@
"Matmul_vecmat"
}

MHLO_PASS_SET = {
"FlattenStaticModule_basic",
"FlattenRank0Module_basic",
"TensorsConcatNegativeDimModule_basic",
"NumelModule_basic",
"ReduceSumDimIntListEmptyDimModule_basic",
"SqueezeModule_allUnitDim",
"SqueezeDimModule_unitDim",
"MeanModule_basic",
"MeanDynamicSizesModule_basic",
"MeanDimEmptyDimModule_basic",
"NumToTensorFloatModule_basic",
"AtenToDeviceModule_basic",
"AvgPool2dStaticModule_basic",
"Conv2dWithPaddingDilationStrideStaticModule_basic",
"Convolution2DStaticModule_basic",
"ElementwiseCloneContiguousModule_basic",
"ElementwiseCloneModule_basic",
"ElementwiseBinaryStaticShapeModule_basic",
"ReturnThreeTensorFloat32_basic",
"BoolTensorReturnFalseModule_basic",
"BoolTensorReturnTrueModule_basic",
"BoolTensorReturnMixedModule_basic",
"SqueezeModule_static",
"TModuleRank1_basic",
"TModuleRank0_basic",
"ElementwiseToDtypeIdentityModule_basic",
"View1DFoldModule_basic",
"UnsafeView1DFoldModule_basic",
"SqueezeDimModule_static",
"SqueezeDimModule_identity",
"SliceModule_basic",
"SliceNegIdxModule_basic",
"SliceOutOfLowerBoundStartIndexModule_basic",
"SliceOutOfUpperBoundIndexModule_basic",
"SliceStartEqEndModule_basic",
"SliceSizeTwoStepModule_basic",
"SliceWholeTensorModule_basic",
"ReturnTwoTensorF32I64_basic",
"Matmul4dStatic_basic",
"Matmul_dot",
"Matmul_2d",
"Matmul_matvec",
"Matmul_vecmat",
"MaxPool2dWithIndicesStaticModule_basic",
"MmDagModule_basic",
"MmModule_basic",
"MmModule_chained",
"MaxPool2dStaticModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"ZerosModuleDefaultDtype_basic",
"ZerosModuleInt2D_basic",
"ZerosModuleInt3D_basic",
"ZerosModuleFloat2D_basic",
"ZerosModuleFloat3D_basic",
"ZerosModuleFalsePinMemory_basic",
"OnesModuleDefaultDtype_basic",
"OnesModuleInt_basic",
"OnesModuleFloat_basic",
"OnesModuleFalsePinMemory_basic",
"NewZerosModuleDefaultDtype_basic",
"NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic",
"NewZerosModuleFloat2D_basic",
"NewZerosModuleFloat3D_basic",
"NewZerosModuleFalsePinMemory_basic",
"NewOnesModuleDefaultDtype_basic",
"NewOnesModuleInt2D_basic",
"NewOnesModuleInt3D_basic",
"NewOnesModuleFloat2D_basic",
"NewOnesModuleFloat3D_basic",
"NewOnesModuleFalsePinMemory_basic",
"DropoutEvalIntModule_basic",
"DropoutEvalFloatModule_basic",
"ContiguousModule_basic",
"DropoutModule_basic",
"ViewCollapseModule_basic",
"ViewCollapseInferredDimModule_basic",
"ViewDynamicExpandCollapseModule_basic",
"ViewDynamicExpandModule_basic",
"ViewExpandModule_basic",
"ViewExpandOnesModule_basic",
"ViewExpandOnesBeforeAndAfterModule_basic",
"ViewExpandOnesMiddleModule_basic",
"ViewExpandCollapseModule_basic",
"ViewExpandCollapseWithOnesModule_basic",
"ViewExpandInferredDimModule_basic",
"ViewNoChangeStaticModule_basic",
"ViewNoChange1dModule_basic",
"ViewNoChange2dModule_basic",
"ViewNoChange3dModule_basic",
"UnsafeViewExpandModule_basic",
"ReduceMaxAllDims_basic",
"ReduceMaxFloatModule_basic",
"ReduceMaxSignedIntModule_basic",
"ReduceMaxUnsignedIntModule_basic",
"ReduceSumDimIntListFloatModule_basic",
"ReduceSumDimIntListIntModule_basic",
"ReduceSumFloatModule_basic",
"ReduceSumSignedIntModule_basic",
"ReduceSumUnsignedIntModule_basic",
"RepeatModule_basic",
"ReshapeAliasCollapseModule_basic",
"ReshapeAliasExpandModule_basic",
"ReshapeExpandModule_basic",
"TestMultipleTensorReturn_basic",
"AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic",
"BaddbmmStaticModule_basic",
"BaddbmmBroadcast1DInputModule_basic",
"BaddbmmBroadcast2DInputModule_basic",
"NarrowHorizontalTest2_basic",
"NarrowHorizontalTest_basic",
"NarrowVerticalTest2_basic",
"NarrowVerticalTest_basic",
"NumToTensorIntModule_basic",
"NumpyTRank0Module_basic",
"NumpyTRank1Module_basic",
"NumpyTRank2Module_basic",
"NumpyTRankNStaticModule_basic",
"NumpyTRankNDynamicModule_basic",
"TModuleRank2_basic",
"TensorLiteralModule_basic",
"TensorsConcatModule_basic",
"TensorOpaqueLiteralModule_basic",
"TransposeIntModule_basic",
"TransposeIntNegDimsModule_basic",
"OnesModuleCPUDevice_basic",
"Permute0RankModule_basic",
"UnsafeViewCollapseModule_basic",
"UnsafeViewDynamicExpandModule_basic",
}

# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
Expand Down
12 changes: 11 additions & 1 deletion lib/Conversion/Passes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

#include "torch-mlir/Conversion/Passes.h"

#ifdef TORCH_MLIR_ENABLE_MHLO
#include "mlir-hlo/Dialect/mhlo/transforms/passes.h"
#endif // TORCH_MLIR_ENABLE_MHLO
#include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h"
#include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h"
#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h"
Expand All @@ -25,4 +28,11 @@ namespace {
#include "torch-mlir/Conversion/Passes.h.inc"
} // end namespace

void mlir::torch::registerConversionPasses() { ::registerPasses(); }
void mlir::torch::registerConversionPasses() {
::registerPasses();
#ifdef TORCH_MLIR_ENABLE_MHLO
::mlir::registerPass([]() -> std::unique_ptr<::mlir::Pass> {
return mlir::mhlo::createLegalizeHloToLinalgPass();
});
#endif // TORCH_MLIR_ENABLE_MHLO
}
36 changes: 35 additions & 1 deletion lib/Conversion/TorchToMhlo/Basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -977,8 +977,41 @@ LogicalResult ConvertAtenOp<AtenCatOp>::matchAndRewrite(
v = mhlo::promoteType(rewriter, v, outType);
}

size_t posDim = toPositiveDim(dim, outType.getRank());
rewriter.replaceOpWithNewOp<mhlo::ConcatenateOp>(
op, ValueRange(builtinTensors), static_cast<uint64_t>(dim));
op, ValueRange(builtinTensors), posDim);
return success();
}
} // namespace

// AtenNumelOp
namespace {
template <>
LogicalResult ConvertAtenOp<AtenNumelOp>::matchAndRewrite(
AtenNumelOp op,
OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const {
auto self = adaptor.self();
auto selfTy = self.getType().dyn_cast<RankedTensorType>();
size_t rank = selfTy.getRank();

Type intType = rewriter.getIntegerType(mhlo::kMhloDimSizeBits);
auto loc = op->getLoc();
Value numel =
rewriter.create<arith::ConstantOp>(loc, rewriter.getIntegerAttr(intType, 1));
for (size_t d = 0 ; d < rank; ++ d) {
Value dimSize = rewriter.create<arith::IndexCastOp>(
loc, intType, rewriter.create<tensor::DimOp>(loc, self, d));
numel = rewriter.create<arith::MulIOp>(loc, numel, dimSize);
}

auto outTy = getTypeConverter()->convertType(op.getType());
if (outTy != numel.getType()) {
rewriter.replaceOpWithNewOp<arith::ExtSIOp>(
op, outTy, numel);
} else {
rewriter.replaceOp(op, numel);
}
return success();
}
} // namespace
Expand Down Expand Up @@ -1067,5 +1100,6 @@ void mlir::torch::torch_to_mhlo::populateBasicOpPatternsAndLegality(

INSERT_ATENOP_PATTERN(AtenBatchNormOp);
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
INSERT_ATENOP_PATTERN(AtenNumelOp);
#undef INSERT_ATENOP_PATTERN
}
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToMhlo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
DEPENDS
MhloDialect
ChloDialect
MhloToLinalg
MLIRMhloPassIncGen
TorchMLIRConversionPassIncGen

LINK_COMPONENTS
Expand All @@ -24,6 +26,7 @@ add_mlir_conversion_library(TorchMLIRTorchToMhlo
MLIRPass
MhloDialect
ChloDialect
MhloToLinalg
TorchMLIRTorchDialect
)

Expand Down
3 changes: 3 additions & 0 deletions lib/Conversion/TorchToMhlo/Reduction.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,6 +490,9 @@ LogicalResult ConvertAtenReductionOp<AtenSumDimIntListOp>::matchAndRewrite(
if (!matchPattern(op.dim(), m_TorchConstantIntList(inputDims))) {
return rewriter.notifyMatchFailure(op, "non-int dim list unsupported");
}
if (inputDims.size() == 0) {
inputDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, inputTy.getRank()));
}

for (auto d : inputDims) {
d = toPositiveDim(d, inputTy.getRank());
Expand Down
18 changes: 18 additions & 0 deletions lib/Conversion/TorchToMhlo/ViewLike.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,14 @@ class ConvertAtenViewOp : public OpConversionPattern<AtenOpT> {
numel = rewriter.create<arith::IndexCastOp>(loc, rewriter.getIndexType(),
numel);

if (dimSizes.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
adaptor.self());
return success();
}
Value mhloShape = rewriter.create<tensor::FromElementsOp>(loc, dimSizes);
Value computedShape = rewriter.create<mhlo::ComputeReshapeShapeOp>(
loc, mhloShape.getType(), numel, mhloShape);
Expand Down Expand Up @@ -310,6 +318,11 @@ LogicalResult ConvertAtenOp<AtenSqueezeOp>::matchAndRewrite(
if (dSize != 1)
dims.push_back(r);
}
if (dims.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
}

auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
if (failed(newDimSizesInfo))
Expand Down Expand Up @@ -354,6 +367,11 @@ LogicalResult ConvertAtenOp<AtenSqueezeDimOp>::matchAndRewrite(
SmallVector<int64_t, 4> dims(rank);
std::iota(dims.begin(), dims.end(), 0);
dims.erase(dims.begin() + dim);
if (dims.size() == 0) {
rewriter.replaceOpWithNewOp<mhlo::ReshapeOp>(
op, getTypeConverter()->convertType(op.getType()), self);
return success();
}
auto newDimSizesInfo = mhlo::getDimSizesOfTensor(rewriter, op, self, dims);
if (failed(newDimSizesInfo))
return rewriter.notifyMatchFailure(
Expand Down
2 changes: 0 additions & 2 deletions python/torch_mlir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,6 @@ def compile(model: torch.nn.Module,
scripted = torch.jit.trace(model, tuple(example_args_for_trace))
else:
scripted = torch.jit.script(model)

# Convert all concrete inputs to TensorPlaceholder's, for consistency.
arg_placeholders = []
for arg in example_args:
Expand Down Expand Up @@ -240,7 +239,6 @@ def compile(model: torch.nn.Module,
""") from None
finally:
sys.stderr = original_stderr

if output_type == OutputType.RAW:
return mb.module

Expand Down
Empty file.
Loading