From 378860f51b50abacef27fb3dac96010c80488ae5 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 2 May 2023 13:29:00 +0000 Subject: [PATCH] [MLIR][TORCH] Add E2E support for aten.topk op This commit adds the decomposition for the aten.topk op. Signed-Off By: Vivek Khandelwal --- .../Transforms/AbstractInterpLibrary.cpp | 6 +++ .../Torch/Transforms/DecomposeComplexOps.cpp | 44 +++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + .../build_tools/abstract_interp_lib_gen.py | 4 ++ .../torch_mlir_e2e_test/test_suite/basic.py | 35 +++++++++++++++ 5 files changed, 90 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index abcbf2ac4298..b68d62c8b2ae 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7231,6 +7231,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.topk(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.tuple, list>\n" " return %0 : !torch.tuple, list>\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.topk\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv2d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 9c245458cdd3..f5de87671fb0 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -4320,6 +4320,49 @@ class DecomposeAtenVarMeanDimOp : public OpRewritePattern { }; } // namespace +namespace { +// Decompose `aten.topk` op into `aten.sort` and `aten.slice.Tensor` op. +class DecomposeAtenTopkOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTopkOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto context = op.getContext(); + + bool sorted; + if (!matchPattern(op.getSorted(), m_TorchConstantBool(&sorted))) + return rewriter.notifyMatchFailure( + op, "Expected a constant boolean value for sorted"); + if (!sorted) + return rewriter.notifyMatchFailure( + op, "unimplemented: sorted value arg must be set to True"); + + Value self = op.getSelf(); + Value dim = op.getDim(); + auto selfType = self.getType().cast(); + auto sortIndicesType = selfType.getWithSizesAndDtype( + selfType.getOptionalSizes(), + IntegerType::get(context, 64, IntegerType::Signed)); + auto sortOpResult = rewriter.create( + loc, self.getType(), sortIndicesType, self, dim, + /*descending=*/op.getLargest()); + Value start = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value step = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value resultValue = rewriter.create( + loc, op->getResultTypes()[0], sortOpResult->getResult(0), dim, start, + /*end=*/op.getK(), step); + Value resultIndices = rewriter.create( + loc, op->getResultTypes()[1], sortOpResult->getResult(1), dim, start, + /*end=*/op.getK(), step); + rewriter.replaceOp(op, {resultValue, resultIndices}); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -4483,6 +4526,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); GreedyRewriteConfig config; config.useTopDownTraversal = true; diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 2c712ffb4ba7..ac077ca2f831 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -477,6 +477,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); for (auto &opName : backendLegalOpsSet) { target.addLegalOp( OperationName(kTorchOpPrefix + opName.first().str(), context)); diff --git a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py index 75bc466ab501..90c6f30d66f9 100644 --- a/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py +++ b/python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py @@ -822,6 +822,10 @@ def aten〇addcdiv〡shape(self: List[int], tensor1: List[int], tensor2: List[in def aten〇topk〡shape(self: List[int], k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[List[int], List[int]]: return upstream_shape_functions.topk(self, k, dim) +def aten〇topk〡dtype(self_rank_dtype: Tuple[int, int], k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[int, int]: + _, self_dtype = self_rank_dtype + return self_dtype, torch.int64 + def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) diff --git a/python/torch_mlir_e2e_test/test_suite/basic.py b/python/torch_mlir_e2e_test/test_suite/basic.py index 5ea748fe79d5..e312ebcb6787 100644 --- a/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/python/torch_mlir_e2e_test/test_suite/basic.py @@ -3723,3 +3723,38 @@ def forward(self): @register_test_case(module_factory=lambda: ConstantBoolParameterModule()) def ConstantBoolParameterModule_basic(module, tu: TestUtils): module.forward() + + +# ============================================================================== + + +class AtenTopKModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.topk(x, k=50, dim=-1, largest=True, sorted=True) + + +@register_test_case(module_factory=lambda: AtenTopKModule()) +def AtenTopKModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 100)) + + +class AtenTopKSmallestModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([None, ([-1, -1, -1], torch.float32, True)]) + def forward(self, x): + return torch.ops.aten.topk(x, k=20, dim=1, largest=False, sorted=True) + + +@register_test_case(module_factory=lambda: AtenTopKSmallestModule()) +def AtenTopKSmallestModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 40, 50))