[MLIR][TORCH] Add E2E support for aten.topk op
This commit adds the decomposition for the aten.topk op.

Signed-Off By: Vivek Khandelwal <[email protected]>
vivekkhandelwal1 committed May 5, 2023
1 parent 1eceb84 commit 378860f
Showing 5 changed files with 90 additions and 0 deletions.
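For context, the decomposition implemented below rests on the equivalence between topk and sort-then-slice. A minimal eager-mode sketch of that equivalence (illustrative only, not the committed code):

```python
import torch

x = torch.rand(2, 10)
# topk == sort (descending when largest=True) + keep the first k along dim.
values, indices = torch.sort(x, dim=-1, descending=True)
values, indices = values.narrow(-1, 0, 3), indices.narrow(-1, 0, 3)
ref_values, ref_indices = torch.topk(x, k=3, dim=-1)
assert torch.equal(values, ref_values)  # indices can differ only on ties
```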
6 changes: 6 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -7231,6 +7231,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0 = call @__torch__.torch.jit._shape_functions.topk(%arg0, %arg1, %arg2) : (!torch.list<int>, !torch.int, !torch.int) -> !torch.tuple<list<int>, list<int>>\n"
" return %0 : !torch.tuple<list<int>, list<int>>\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.topk\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.bool, %arg4: !torch.bool) -> !torch.tuple<int, int> {\n"
" %int4 = torch.constant.int 4\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple<int, int>\n"
" return %1 : !torch.tuple<int, int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.conv2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<list<int>>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.list<int>, %arg6: !torch.int) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list<int>, !torch.list<int>, !torch.optional<list<int>>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.int) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
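The generated dtype function above encodes the rule that `aten.topk` returns values in the input's dtype and indices as int64 (`%int4` is torch's ScalarType code for int64). A quick eager-mode check of the same rule (illustrative, not part of the commit):

```python
import torch

values, indices = torch.topk(torch.rand(3, 5), k=2)
assert values.dtype == torch.float32   # values keep the input dtype
assert indices.dtype == torch.int64    # indices are always int64 (ScalarType code 4)
```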
44 changes: 44 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4320,6 +4320,49 @@ class DecomposeAtenVarMeanDimOp : public OpRewritePattern<AtenVarMeanDimOp> {
};
} // namespace

namespace {
// Decompose `aten.topk` op into `aten.sort` and `aten.slice.Tensor` ops.
class DecomposeAtenTopkOp : public OpRewritePattern<AtenTopkOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenTopkOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto context = op.getContext();

bool sorted;
if (!matchPattern(op.getSorted(), m_TorchConstantBool(&sorted)))
return rewriter.notifyMatchFailure(
op, "Expected a constant boolean value for sorted");
if (!sorted)
return rewriter.notifyMatchFailure(
op, "unimplemented: sorted value arg must be set to True");

Value self = op.getSelf();
Value dim = op.getDim();
auto selfType = self.getType().cast<BaseTensorType>();
auto sortIndicesType = selfType.getWithSizesAndDtype(
selfType.getOptionalSizes(),
IntegerType::get(context, 64, IntegerType::Signed));
auto sortOpResult = rewriter.create<AtenSortOp>(
loc, self.getType(), sortIndicesType, self, dim,
/*descending=*/op.getLargest());
Value start = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value step = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(1));
Value resultValue = rewriter.create<AtenSliceTensorOp>(
loc, op->getResultTypes()[0], sortOpResult->getResult(0), dim, start,
/*end=*/op.getK(), step);
Value resultIndices = rewriter.create<AtenSliceTensorOp>(
loc, op->getResultTypes()[1], sortOpResult->getResult(1), dim, start,
/*end=*/op.getK(), step);
rewriter.replaceOp(op, {resultValue, resultIndices});
return success();
}
};
} // namespace

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -4483,6 +4526,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenOneHotOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenCrossEntropyLossOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenVarMeanDimOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenTopkOp>(patterns);

GreedyRewriteConfig config;
config.useTopDownTraversal = true;
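Read alongside `DecomposeAtenTopkOp` above, the rewrite corresponds step-for-step to this eager-mode sketch (a hypothetical `decompose_topk` helper for illustration, not part of the commit):

```python
import torch

def decompose_topk(x, k, dim=-1, largest=True, sorted=True):
    # Mirrors DecomposeAtenTopkOp::matchAndRewrite above.
    if not sorted:
        # The pattern reports a match failure for sorted=False.
        raise NotImplementedError("sorted value arg must be set to True")
    # aten.sort with descending == largest; sort indices come back as int64.
    values, indices = torch.sort(x, dim=dim, descending=largest)
    # aten.slice.Tensor along `dim` with start=0, end=k, step=1.
    index = [slice(None)] * x.dim()
    index[dim] = slice(0, k, 1)
    return values[tuple(index)], indices[tuple(index)]
```

On random inputs, `decompose_topk(x, k=20, dim=1, largest=False)` agrees with `torch.topk` on the same arguments, with indices differing only on ties.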
1 change: 1 addition & 0 deletions lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp
@@ -477,6 +477,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context,
target.addIllegalOp<AtenOneHotOp>();
target.addIllegalOp<AtenCrossEntropyLossOp>();
target.addIllegalOp<AtenVarMeanDimOp>();
target.addIllegalOp<AtenTopkOp>();
for (auto &opName : backendLegalOpsSet) {
target.addLegalOp(
OperationName(kTorchOpPrefix + opName.first().str(), context));
4 changes: 4 additions & 0 deletions python/torch_mlir/dialects/torch/importer/jit_ir/build_tools/abstract_interp_lib_gen.py
@@ -822,6 +822,10 @@ def aten〇addcdiv〡shape(self: List[int], tensor1: List[int], tensor2: List[in
def aten〇topk〡shape(self: List[int], k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[List[int], List[int]]:
return upstream_shape_functions.topk(self, k, dim)

def aten〇topk〡dtype(self_rank_dtype: Tuple[int, int], k: int, dim: int = -1, largest: bool = True, sorted: bool = True) -> Tuple[int, int]:
_, self_dtype = self_rank_dtype
return self_dtype, torch.int64

def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1), padding: List[int] = (0, 0), dilation: List[int] = (1, 1), groups: int = 1) -> List[int]:
return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups)

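As a sanity check on the shape rule (`upstream_shape_functions.topk` keeps the input shape but replaces the size along `dim` with `k` for both outputs), here is an illustrative eager-mode check, not part of the commit:

```python
import torch

x = torch.rand(2, 40, 50)
values, indices = torch.topk(x, k=20, dim=1, largest=False)
# Both outputs keep the input shape except along `dim`, which becomes k.
assert values.shape == (2, 20, 50)
assert indices.shape == (2, 20, 50)
```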
35 changes: 35 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/basic.py
@@ -3723,3 +3723,38 @@ def forward(self):
@register_test_case(module_factory=lambda: ConstantBoolParameterModule())
def ConstantBoolParameterModule_basic(module, tu: TestUtils):
module.forward()


# ==============================================================================


class AtenTopKModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1, -1], torch.float32, True)])
def forward(self, x):
return torch.ops.aten.topk(x, k=50, dim=-1, largest=True, sorted=True)


@register_test_case(module_factory=lambda: AtenTopKModule())
def AtenTopKModule_basic(module, tu: TestUtils):
module.forward(tu.rand(1, 100))


class AtenTopKSmallestModule(torch.nn.Module):

def __init__(self):
super().__init__()

@export
@annotate_args([None, ([-1, -1, -1], torch.float32, True)])
def forward(self, x):
return torch.ops.aten.topk(x, k=20, dim=1, largest=False, sorted=True)


@register_test_case(module_factory=lambda: AtenTopKSmallestModule())
def AtenTopKSmallestModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 40, 50))
