From 6b2c214929c37d7691429df1e13fb57f3e0e0ad3 Mon Sep 17 00:00:00 2001
From: JakopinA
Date: Mon, 21 Nov 2022 07:16:35 -0600
Subject: [PATCH] Add E2E support for aten.as_strided

Register `aten.as_strided` via torch_ods_gen.py (with a folder for the
rank-0 identity case), give it a shape function, propagate its dtype in
RefineTypes, and decompose it to the existing value-semantic
`aten.as_strided_copy` op. An e2e test exercises the new path.

---
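Note on the op's semantics, for context: `as_strided` reinterprets the
input's underlying storage, so for a contiguous 5x5 input with size
(2, 2) and stride (5, 3), element (i, j) of the result reads
storage[storage_offset + 5*i + 3*j]. An illustrative eager-mode sketch
of those semantics (not part of the patch):

    import torch

    x = torch.arange(25.0).reshape(5, 5)  # contiguous storage holds 0..24
    y = torch.ops.aten.as_strided(x, (2, 2), (5, 3))
    # y[i, j] == x.flatten()[5 * i + 3 * j]
    assert torch.equal(y, torch.tensor([[0., 3.], [5., 8.]]))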
 e2e_testing/xfail_sets.py                     |  3 +-
 .../Dialect/Torch/IR/GeneratedTorchOps.td     | 26 ++++++++++++++++++
 lib/Dialect/Torch/IR/TorchOps.cpp             | 14 +++++++++
 .../Torch/Transforms/DecomposeComplexOps.cpp  | 18 ++++++++++++
 lib/Dialect/Torch/Transforms/RefineTypes.cpp  |  2 +-
 lib/Dialect/Torch/Transforms/ShapeLibrary.cpp |  3 +++
 .../jit_ir/build_tools/shape_lib_gen.py       |  3 +++
 .../jit_ir/build_tools/torch_ods_gen.py       |  1 +
 .../torch_mlir_e2e_test/test_suite/basic.py   | 19 ++++++++++++
 9 files changed, 87 insertions(+), 2 deletions(-)

diff --git a/e2e_testing/xfail_sets.py b/e2e_testing/xfail_sets.py
index c95722946993..3853757278c6 100644
--- a/e2e_testing/xfail_sets.py
+++ b/e2e_testing/xfail_sets.py
@@ -607,6 +607,7 @@
     "ViewCollapseDynamicWithAtenSizeIntModule_basic",
     "AtenEmbeddingBagSumExample_basic",
     "Aten_EmbeddingBagExample_basic",
+    "AsStridedModule_basic",
     "ElementwiseRemainderScalarModule_Int_Float_basic",
     "ElementwiseRemainderScalarModule_Float_basic",
     "ElementwiseRemainderScalarModule_Int_basic",
@@ -627,5 +628,5 @@
     "ConvolutionBackwardModule2DPadded_basic",
     "ConvolutionBackwardModule3D_basic",
     "VarMeanCorrectionModule_basic",
-    "VarMeanCorrectionNoneModule_basic"
+    "VarMeanCorrectionNoneModule_basic",
 }
diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index 87e96fda879a..aa770b416767 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -7466,6 +7466,32 @@ def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [
   }];
 }
 
+def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
+    AllowsTypeRefinement,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`";
+  let arguments = (ins
+    AnyTorchTensorType:$self,
+    AnyTorchListOfTorchIntType:$size,
+    AnyTorchListOfTorchIntType:$stride,
+    AnyTorchOptionalIntType:$storage_offset
+  );
+  let results = (outs
+    AnyTorchTensorType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 4, 1);
+    }
+    void AtenAsStridedOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 4, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [
   AllowsTypeRefinement,
   HasValueSemantics,
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index 45d617dc5fa7..0775c0135e21 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -2132,6 +2132,20 @@ OpFoldResult AtenSqrtIntOp::fold(ArrayRef<Attribute> operands) {
   return nullptr;
 }
 
+//===----------------------------------------------------------------------===//
+// AtenAsStridedOp
+//===----------------------------------------------------------------------===//
+
+// `as_strided` on a rank-0 tensor that yields the same rank-0 type is an
+// identity: there is only one element to view.
+OpFoldResult AtenAsStridedOp::fold(ArrayRef<Attribute> operands) {
+  auto tensorType = getOperand(0).getType().dyn_cast<BaseTensorType>();
+  if (tensorType && tensorType.hasSizes() && tensorType.getSizes().empty() &&
+      getOperand(0).getType() == getResult().getType())
+    return getOperand(0);
+  return nullptr;
+}
+
 //===----------------------------------------------------------------------===//
 // PrimDtypeOp
 //===----------------------------------------------------------------------===//
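The folder above only fires for the rank-0 identity case, where the view
cannot rearrange anything. An illustrative eager-mode analogue of the
folded case (not part of the patch):

    import torch

    t = torch.tensor(7.0)                     # rank-0 tensor
    v = torch.ops.aten.as_strided(t, [], [])  # still rank-0, same element
    assert v.item() == 7.0 and v.shape == torch.Size([])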
diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index 2115d3aece4e..a93c0214be51 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -2948,6 +2948,22 @@ class DecomposeAtenSelectScatterOp
 };
 } // namespace
 
+namespace {
+// `aten.as_strided` is a view op; decompose it to its value-semantic
+// variant `aten.as_strided_copy`, for which lowerings already exist.
+class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenAsStridedOp op,
+                                PatternRewriter &rewriter) const override {
+    rewriter.replaceOpWithNewOp<AtenAsStridedCopyOp>(
+        op, op.getType(), op.self(), op.size(), op.stride(),
+        op.storage_offset());
+    return success();
+  }
+};
+} // namespace
+
 namespace {
 class DecomposeAten_EmbeddingBagOp
     : public OpRewritePattern<Aten_EmbeddingBagOp> {
@@ -3328,4 +3344,6 @@ class DecomposeComplexOpsPass
     patterns.add<DecomposeAtenSelectScatterOp>(context);
     target.addIllegalOp<AtenSelectScatterOp>();
+    patterns.add<DecomposeAtenAsStridedOp>(context);
+    target.addIllegalOp<AtenAsStridedOp>();
     patterns.add<DecomposeAten_EmbeddingBagOp>(context);
     target.addIllegalOp<Aten_EmbeddingBagOp>();
diff --git a/lib/Dialect/Torch/Transforms/RefineTypes.cpp b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
index 87ebacd5b157..e5c8ccceb8da 100644
--- a/lib/Dialect/Torch/Transforms/RefineTypes.cpp
+++ b/lib/Dialect/Torch/Transforms/RefineTypes.cpp
@@ -699,7 +699,7 @@ void TypeAnalysis::visitOperation(Operation *op,
       AtenCopyOp, AtenZeroOp, AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp,
       AtenFlipOp, PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp,
       AtenMaskedFillTensorOp, AtenRollOp, AtenPowTensorTensorOp,
-      AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp,
+      AtenLiftFreshCopyOp, AtenIndexTensorHackedTwinOp, AtenAsStridedOp,
       AtenUpsampleNearest2dVecOp, AtenMishOp, AtenRoundOp, AtenFillTensorOp,
       AtenUpsampleNearest2dBackwardOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
diff --git a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
index b500e3ebc48a..cc641a1d50dc 100644
--- a/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/ShapeLibrary.cpp
@@ -6601,6 +6601,9 @@ StringRef mlir::torch::Torch::getShapeLibrary() {
 "    %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list<int>) -> !torch.list<int>\n"
 "    return %0 : !torch.list<int>\n"
 "  }\n"
+"  func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
+"    return %arg1 : !torch.list<int>\n"
+"  }\n"
 "  func.func @\"__torch_mlir_shape_fn.aten.embedding_bag.padding_idx\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.bool, %arg4: !torch.int, %arg5: !torch.bool, %arg6: !torch.optional<list<int>>, %arg7: !torch.bool, %arg8: !torch.optional<int>) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>> {\n"
 "    %0 = call @__torch__._embedding_bag_helper(%arg0, %arg1, %arg2, %arg7, %arg4) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.int) -> !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
 "    return %0 : !torch.tuple<list<int>, list<int>, list<int>, list<int>>\n"
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
index 1f3ee05af288..b30eeb25f411 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/shape_lib_gen.py
@@ -1027,6 +1027,9 @@ def aten〇index_put(self: List[int], indices: List[Optional[List[int]]], values
 def aten〇index_put〇hacked_twin(self: List[int], indices: List[List[int]], values: List[int], accumulate: bool = False) -> List[int]:
     return upstream_shape_functions.unary(self)
 
+def aten〇as_strided(self: List[int], size: List[int], stride: List[int], storage_offset: Optional[int] = None) -> List[int]:
+    return size
+
 def aten〇embedding(weight: List[int], indices: List[int], padding_idx: int = -1, scale_grad_by_freq: bool = False, sparse: bool = False) -> List[int]:
     return upstream_shape_functions.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse)
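The shape function returns `size` directly: unlike elementwise ops, the
result shape of `as_strided` is exactly the requested size, independent
of the input shape. A quick eager-mode check of that invariant (not part
of the patch):

    import torch

    x = torch.zeros(5, 5)
    assert torch.ops.aten.as_strided(x, (2, 2), (5, 3)).shape == torch.Size([2, 2])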
diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
index ae13c33f7b88..144f51cb7264 100644
--- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
+++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/torch_ods_gen.py
@@ -508,6 +508,7 @@ def emit_with_mutating_variants(key, **kwargs):
 
     # Functionalization ops
     emit("aten::alias_copy : (Tensor) -> (Tensor)")
+    emit("aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)", has_folder=True)
     emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)")
     emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)")
     emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)")
diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py
index 0eee76b1ba3d..453b123c587b 100644
--- a/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -2954,6 +2954,25 @@ def Aten_EmbeddingBagExample_basic(module, tu: TestUtils):
 
 # ==============================================================================
 
+class AsStridedModule(torch.nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return torch.ops.aten.as_strided(x, (2, 2), (5, 3))
+
+
+@register_test_case(module_factory=lambda: AsStridedModule())
+def AsStridedModule_basic(module, tu: TestUtils):
+    module.forward(tu.rand(5, 5))
+
+# ==============================================================================
+
 class CumsumModule(torch.nn.Module):
 
     def __init__(self):
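With the pieces above in place, a module using `as_strided` should
compile end to end. A hypothetical usage sketch, assuming the
torch_mlir.compile API of this era (exact output-type spelling may
differ):

    import torch
    import torch_mlir

    class M(torch.nn.Module):
        def forward(self, x):
            return torch.ops.aten.as_strided(x, (2, 2), (5, 3))

    compiled = torch_mlir.compile(M(), torch.ones(5, 5),
                                  output_type="linalg-on-tensors")
    print(compiled)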