Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit

Permalink
Decompose torch.slice_scatter (llvm#1622)
Browse files Browse the repository at this point in the history
* Decompose torch.slice_scatter

* fix compilation error

* update file check

* fix ci

* fix i64 torch.tensor dtype
  • Loading branch information
Tanyo Kwok committed Dec 5, 2022
1 parent 408c136 commit 7aff051
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 11 deletions.
124 changes: 124 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2949,6 +2949,128 @@ class DecomposeAtenSelectScatterOp
};
} // namespace

namespace {
// def slice_scatter(self, values, dim, start, end, step):
// size = self.size(dim)
// indices = torch.arange(size)
// shift_indices = indices - start
// mask = shift_indices % step == 0
// start_mask = shift_indices >= 0
// end_mask = shift_indices < end
// mask = mask * start_mask
// mask = mask * end_mask
// sizes = list(self.size())
// rank = len(sizes)
// shape = [1] * rank
// shape[dim] = size
// mask = mask.view(shape)
// return torch.where(mask, values, self)
//
class DecomposeAtenSliceScatterOp
: public OpRewritePattern<AtenSliceScatterOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenSliceScatterOp op,
PatternRewriter &rewriter) const override {
int64_t inputRank = getTensorRank(op.self());
int64_t dimInt = 0;
if (matchPattern(op.dim(), m_TorchConstantInt(&dimInt))) {
dimInt = toPositiveDim(dimInt, inputRank);
if (!isValidDim(dimInt, inputRank))
return rewriter.notifyMatchFailure(op, "dim is not a valid dim");
} else {
return rewriter.notifyMatchFailure(op, "dim must be constant");
}

auto getOptionalVal = [&](Value val, Value defVal) -> Value {
if (val.getType().isa<Torch::NoneType>()) {
return defVal;
} else {
return val;
}
};

Value one = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getI64IntegerAttr(1));
Value zero = rewriter.create<Torch::ConstantIntOp>(
op.getLoc(), rewriter.getI64IntegerAttr(0));
Value none = rewriter.create<ConstantNoneOp>(op.getLoc());
Value dimSize =
rewriter.create<AtenSizeIntOp>(op.getLoc(), op.self(), op.dim());

Value start = getOptionalVal(op.start(), zero);
Value end = getOptionalVal(op.end(), dimSize);
Value step = getOptionalVal(op.step(), one);
// Step 0. create indices
Type indicesType = ValueTensorType::get(
op.getContext(), ArrayRef<int64_t>{ShapedType::kDynamicSize},
IntegerType::get(op.getContext(), 64, IntegerType::Signed));
Value indices = rewriter.create<AtenArangeOp>(
op.getLoc(), indicesType, dimSize, none, none, none, none);

// Step 1. make indices broadcastable to self's shape
SmallVector<int64_t> newIndicesShapeInt(inputRank, 1);
SmallVector<Value> newIndicesShape(inputRank, one);
newIndicesShape[dimInt] = dimSize;
newIndicesShapeInt[dimInt] = ShapedType::kDynamicSize;
Value newIndicesSizeList = rewriter.create<PrimListConstructOp>(
op.getLoc(), ListType::get(IntType::get(op.getContext())),
newIndicesShape);
Type indicesDtype = indices.getType().cast<ValueTensorType>().getDtype();
Type newIndicesType = ValueTensorType::get(
op.getContext(), llvm::makeArrayRef(newIndicesShapeInt), indicesDtype);
indices = rewriter.create<AtenViewOp>(op.getLoc(), newIndicesType,
indices, newIndicesSizeList);

// Step 2. calculate scatter indices mask
Type maskType = ValueTensorType::get(
op.getContext(), newIndicesType.cast<ValueTensorType>().getSizes(),
IntegerType::get(op.getContext(), 1));
auto shiftIndices = rewriter.create<AtenSubScalarOp>(
op.getLoc(), indices.getType(), indices, start, one);
auto stepRemainder = rewriter.create<AtenRemainderScalarOp>(
op.getLoc(), indices.getType(), shiftIndices, step);
Value mask = rewriter.create<AtenEqScalarOp>(op.getLoc(), maskType,
stepRemainder, zero);
auto maskStart = rewriter.create<AtenGeScalarOp>(op.getLoc(), maskType,
shiftIndices, zero);
auto maskEnd =
rewriter.create<AtenLtScalarOp>(op.getLoc(), maskType, indices, end);
mask = rewriter.create<AtenBitwiseAndTensorOp>(op.getLoc(), maskType, mask,
maskStart);
mask = rewriter.create<AtenBitwiseAndTensorOp>(op.getLoc(), maskType, mask,
maskEnd);

// Step 3. make src broadcastable to self's shape
Value src = op.src();
BaseTensorType srcTensorType = src.getType().cast<BaseTensorType>();
if (!srcTensorType.hasSizes())
return rewriter.notifyMatchFailure(op, "src tensor must have size");

ArrayRef<int64_t> srcShape = srcTensorType.getSizes();
int64_t srcRank = srcShape.size();
if (srcRank != inputRank) {
if (srcRank + 1 == inputRank) {
SmallVector<int64_t> sizes;
sizes.append(srcShape.begin(), srcShape.end());
sizes.insert(sizes.begin() + dimInt, 1);
Type srcType = srcTensorType.getWithSizesAndDtype(
llvm::makeArrayRef(sizes), srcTensorType.getDtype());
src = rewriter.create<AtenUnsqueezeOp>(op.getLoc(), srcType, src,
op.dim());
} else {
return rewriter.notifyMatchFailure(op, "src's rank doesn't match");
}
}

// Step 4. replace output = mask? src: self
rewriter.replaceOpWithNewOp<AtenWhereSelfOp>(op, op.getType(), mask,
src, op.self());
return success();
}
};
} // namespace

namespace {
class DecomposeAten_EmbeddingBagOp
: public OpRewritePattern<Aten_EmbeddingBagOp> {
Expand Down Expand Up @@ -3342,6 +3464,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenNumpyTOp>();
patterns.add<DecomposeAtenSelectScatterOp>(context);
target.addIllegalOp<AtenSelectScatterOp>();
patterns.add<DecomposeAtenSliceScatterOp>(context);
target.addIllegalOp<AtenSliceScatterOp>();
patterns.add<DecomposeAtenVarDimOp>(context);
target.addIllegalOp<AtenVarDimOp>();
patterns.add<DecomposeAtenVarCorrectionOp>(context);
Expand Down
4 changes: 2 additions & 2 deletions python/torch_mlir/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,8 +126,8 @@ def like(tensor: torch.Tensor, dynamic_axes: List[int] = None):
# ops in the backend contract, and move these lists somewhere deeper in the
# compiler where each backend can "own" its set of legal ops.
BACKEND_LEGAL_OPS = {
OutputType.TOSA: ['torch.aten.flatten.using_ints','torch.aten.native_layer_norm','torch.aten.linear'],
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints',],
OutputType.TOSA: ['torch.aten.flatten.using_ints', 'torch.aten.native_layer_norm', 'torch.aten.linear'],
OutputType.LINALG_ON_TENSORS: ['torch.aten.flatten.using_ints', 'torch.aten.slice_scatter'],
OutputType.MHLO: [],
}

Expand Down
33 changes: 24 additions & 9 deletions test/Dialect/Torch/decompose-complex-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -760,7 +760,7 @@ func.func @torch.aten.numpy_T$rank_three(%arg0: !torch.vtensor<[5,4,3],f32>) ->
}

// -----
// CHECK-LABEL: func.func @torch.aten.repeat(
// CHECK-LABEL: func @torch.aten.repeat(
// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.int, %[[ARG2:.*]]: !torch.int, %[[ARG3:.*]]: !torch.int) -> !torch.vtensor<[?,?,?],f32> {
// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG1]], %[[ARG2]], %[[ARG3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
// CHECK: %[[INT1:.*]] = torch.constant.int 1
Expand All @@ -786,14 +786,29 @@ func.func @torch.aten.repeat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.int
// -----
// CHECK-LABEL: func @torch.aten.select_scatter
// CHECK-SAME: (%[[SELF:.*]]: !torch.vtensor<[?,?],f32>, %[[SRC:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK-NEXT: %[[START:.*]] = torch.constant.int 0
// CHECK-NEXT: %[[DIM:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[STEP:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[END:.*]] = torch.aten.add.int %[[START]], %[[STEP]]
// CHECK-NEXT: %[[UNSQUEEZE_SRC:.*]] = torch.aten.unsqueeze %[[SRC]], %[[DIM]]
// CHECK-NEXT: %[[SLICE_SCATTER:.*]] = torch.aten.slice_scatter %[[SELF]], %[[UNSQUEEZE_SRC]], %[[DIM]], %[[START]], %[[END]], %[[STEP]]
// CHECK-NEXT: return %[[SLICE_SCATTER]]
// CHECK-NEXT: }
// CHECK-NEXT: %[[INT0:.*]] = torch.constant.int 0
// CHECK-NEXT: %[[INT1:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[INT1_0:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[T0:.*]] = torch.aten.add.int %[[INT0]], %[[INT1_0]] : !torch.int, !torch.int -> !torch.int
// CHECK-NEXT: %[[T1:.*]] = torch.aten.unsqueeze %[[SRC]], %[[INT1]] : !torch.vtensor<[?],f32>, !torch.int -> !torch.vtensor<[?,1],f32>
// CHECK-NEXT: %[[INT1_1:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[INT0_2:.*]] = torch.constant.int 0
// CHECK-NEXT: %[[NONE:.*]] = torch.constant.none
// CHECK-NEXT: %[[T2:.*]] = torch.aten.size.int %[[SELF]], %[[INT1]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.int
// CHECK-NEXT: %[[INT0_3:.*]] = torch.constant.int 0
// CHECK-NEXT: %[[INT1_4:.*]] = torch.constant.int 1
// CHECK-NEXT: %[[T3:.*]] = torch.aten.arange.start_step %[[INT0_3]], %[[T2]], %[[INT1_4]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?],si64>
// CHECK-NEXT: %[[T4:.*]] = torch.prim.ListConstruct %[[INT1_1]], %[[T2]] : (!torch.int, !torch.int) -> !torch.list<int>
// CHECK-NEXT: %[[T5:.*]] = torch.aten.view %[[T3]], %[[T4]] : !torch.vtensor<[?],si64>, !torch.list<int> -> !torch.vtensor<[1,?],si64>
// CHECK-NEXT: %[[T6:.*]] = torch.aten.sub.Scalar %[[T5]], %[[INT0]], %[[INT1_1]] : !torch.vtensor<[1,?],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,?],si64>
// CHECK-NEXT: %[[T7:.*]] = torch.aten.remainder.Scalar %[[T6]], %[[INT1_0]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],si64>
// CHECK-NEXT: %[[T8:.*]] = torch.aten.eq.Scalar %[[T7]], %[[INT0_2]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
// CHECK-NEXT: %[[T9:.*]] = torch.aten.ge.Scalar %[[T6]], %[[INT0_2]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
// CHECK-NEXT: %[[T10:.*]] = torch.aten.lt.Scalar %[[T5]], %[[T0]] : !torch.vtensor<[1,?],si64>, !torch.int -> !torch.vtensor<[1,?],i1>
// CHECK-NEXT: %[[T11:.*]] = torch.aten.bitwise_and.Tensor %[[T8]], %[[T9]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[1,?],i1> -> !torch.vtensor<[1,?],i1>
// CHECK-NEXT: %[[T12:.*]] = torch.aten.bitwise_and.Tensor %[[T11]], %[[T10]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[1,?],i1> -> !torch.vtensor<[1,?],i1>
// CHECK-NEXT: %[[T13:.*]] = torch.aten.where.self %[[T12]], %[[T1]], %[[SELF]] : !torch.vtensor<[1,?],i1>, !torch.vtensor<[?,1],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
// CHECK-NEXT: return %[[T13]] : !torch.vtensor<[?,?],f32>
func.func @torch.aten.select_scatter(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> {
%int0 = torch.constant.int 0
%int1 = torch.constant.int 1
Expand Down

0 comments on commit 7aff051

Please sign in to comment.