diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
index bf606b22d30a..960ba274e6dd 100644
--- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
+++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -834,6 +834,68 @@ class DecomposeAtenRepeatOp : public OpRewritePattern<AtenRepeatOp> {
 };
 } // namespace
 
+// Decompose aten.flatten.using_ints into aten.view op.
+namespace {
+class DecomposeAtenFlattenUsingIntsOp
+    : public OpRewritePattern<AtenFlattenUsingIntsOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenFlattenUsingIntsOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+    Value self = op.self();
+    MLIRContext *context = op.getContext();
+    int64_t rank = getTensorRank(self);
+    if (rank < 0)
+      return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor");
+
+    int64_t start, end;
+    if (!matchPattern(op.start_dim(), m_TorchConstantInt(&start)) ||
+        !matchPattern(op.end_dim(), m_TorchConstantInt(&end))) {
+      return rewriter.notifyMatchFailure(
+          op, "unimplemented: requires start and end dims to be constants");
+    }
+
+    SmallVector<Value> newSizes;
+    if (rank == 0) {
+      Value one = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(1));
+      newSizes.push_back(one);
+    } else {
+      start = toPositiveDim(start, rank);
+      end = toPositiveDim(end, rank);
+
+      if (start > end) {
+        return rewriter.notifyMatchFailure(
+            op, "expected end dim larger than start dim");
+      }
+
+      newSizes.reserve(rank - end + start);
+      for (size_t k = 0; k < start; ++k) {
+        Value dim = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(k));
+        newSizes.push_back(
+            rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dim));
+      }
+      Value flattenDimSize = rewriter.create<Torch::ConstantIntOp>(
+          loc, rewriter.getI64IntegerAttr(-1));
+      newSizes.push_back(flattenDimSize);
+      for (size_t k = end + 1; k < rank; ++k) {
+        Value dim = rewriter.create<Torch::ConstantIntOp>(
+            loc, rewriter.getI64IntegerAttr(k));
+        newSizes.push_back(
+            rewriter.create<AtenSizeIntOp>(loc, self, /*dim=*/dim));
+      }
+    }
+    Value newSizeList = rewriter.create<PrimListConstructOp>(
+        loc, ListType::get(IntType::get(context)), newSizes);
+    rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), op.self(),
+                                            newSizeList);
+    return success();
+  }
+};
+} // namespace
+
 // Decompose aten.expand into aten.broadcast_to op.
 namespace {
 class DecomposeAtenExpandOp : public OpRewritePattern<AtenExpandOp> {
@@ -2497,6 +2559,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp();
     patterns.add(context);
     target.addIllegalOp();
+    patterns.add<DecomposeAtenFlattenUsingIntsOp>(context);
+    target.addIllegalOp<AtenFlattenUsingIntsOp>();
     patterns.add(context);
     target.addIllegalOp();
     patterns.add(context);
diff --git a/python/torch_mlir/__init__.py b/python/torch_mlir/__init__.py
index e6bbd7e0b4fd..2a716d3927ce 100644
--- a/python/torch_mlir/__init__.py
+++ b/python/torch_mlir/__init__.py
@@ -126,8 +126,8 @@ def like(tensor: torch.Tensor, dynamic_axes: List[int] = None):
 # ops in the backend contract, and move these lists somewhere deeper in the
 # compiler where each backend can "own" its set of legal ops.
 BACKEND_LEGAL_OPS = {
-    OutputType.TOSA: [],
-    OutputType.LINALG_ON_TENSORS: [],
+    OutputType.TOSA: ['torch.aten.flatten.using_ints',],
+    OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints',],
     OutputType.MHLO: [],
 }
 
diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir
index 8c37005db913..60d5589ec24d 100644
--- a/test/Dialect/Torch/decompose-complex-ops.mlir
+++ b/test/Dialect/Torch/decompose-complex-ops.mlir
@@ -202,8 +202,10 @@ func.func @torch.aten.argmax(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor
 // CHECK: %[[FALSE:.*]] = torch.constant.bool false
 // CHECK: %[[CST0:.*]] = torch.constant.int 0
 // CHECK: %[[CST1:.*]] = torch.constant.int 1
-// CHECK: %[[FLATTEN:.*]] = torch.aten.flatten.using_ints %[[INP]], %[[CST0]], %[[CST1]] :
-// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
+// CHECK: %[[CST:.*]]-1 = torch.constant.int -1
+// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[CST]]-1 : (!torch.int) -> !torch.list<int>
+// CHECK: %[[FLATTEN:.*]] = torch.aten.view %[[INP]], %[[T0]] :
+// CHECK-SAME: !torch.vtensor<[?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
 // CHECK: %[[VAL:.*]], %[[IND:.*]] = torch.aten.max.dim %[[FLATTEN]], %[[CST0]], %[[FALSE]] :
 // CHECK-SAME: !torch.vtensor<[?],f32>, !torch.int, !torch.bool -> !torch.vtensor<[],f32>, !torch.vtensor<[],si64>
 // CHECK: return %[[IND]] : !torch.vtensor<[],si64>
@@ -1332,3 +1334,19 @@ func.func @torch.aten.std.dim(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vten
   %0 = torch.aten.std.dim %arg0, %dims, %unbiased, %keepdim: !torch.vtensor<[3,4,5],f32>, !torch.list<int>, !torch.bool, !torch.bool -> !torch.vtensor<[3,4,1],f32>
   return %0 : !torch.vtensor<[3,4,1],f32>
 }
+
+// -----
+// CHECK-LABEL: func.func @torch.aten.flatten.using_ints(
+// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
+// CHECK: %[[INT0:.*]] = torch.constant.int 0
+// CHECK: %[[INT3:.*]] = torch.constant.int 3
+// CHECK: %[[INT:.*]]-1 = torch.constant.int -1
+// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[INT]]-1 : (!torch.int) -> !torch.list<int>
+// CHECK: %[[T1:.*]] = torch.aten.view %[[ARG0]], %[[T0]] : !torch.vtensor<[?,?,?,?],f32>, !torch.list<int> -> !torch.vtensor<[?],f32>
+// CHECK: return %[[T1]] : !torch.vtensor<[?],f32>
+func.func @torch.aten.flatten.using_ints(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?],f32> {
+  %int0 = torch.constant.int 0
+  %int3 = torch.constant.int 3
+  %1 = torch.aten.flatten.using_ints %arg0, %int0, %int3: !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int -> !torch.vtensor<[?],f32>
+  return %1 : !torch.vtensor<[?],f32>
+}
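
Usage note (not part of the patch): with torch.aten.flatten.using_ints listed in BACKEND_LEGAL_OPS for TOSA and Linalg-on-Tensors, those pipelines keep the op in the backend contract and lower it with their existing conversions, while backends with an empty list (MHLO here) have it rewritten into torch.aten.view by the new decomposition. The sketch below shows how that is expected to look from the Python API; it assumes the standard torch_mlir.compile entry point, and the Flatten module is purely illustrative.

import torch
import torch_mlir

class Flatten(torch.nn.Module):
    def forward(self, x):
        # Emits torch.aten.flatten.using_ints (start_dim=0, end_dim=3) in the Torch dialect.
        return torch.flatten(x, 0, 3)

example = torch.ones(2, 3, 4, 5)

# TOSA / Linalg-on-Tensors: flatten is backend-legal, so the new decomposition
# is skipped and the backend's own flatten lowering is used.
tosa_module = torch_mlir.compile(Flatten(), example,
                                 output_type=torch_mlir.OutputType.TOSA)

# MHLO: no backend-legal ops are listed, so flatten is decomposed into
# torch.aten.view before conversion.
mhlo_module = torch_mlir.compile(Flatten(), example,
                                 output_type=torch_mlir.OutputType.MHLO)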