Skip to content

Commit

Permalink
[StableHLO] Support for slice_scatter (llvm#1960)
Browse files Browse the repository at this point in the history
Co-authored-by: zhekun.zhang <[email protected]>
  • Loading branch information
zhekunz2 and zhekunz2 authored Mar 22, 2023
1 parent 544b5f2 commit 5758a0b
Show file tree
Hide file tree
Showing 9 changed files with 172 additions and 31 deletions.
8 changes: 8 additions & 0 deletions e2e_testing/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,8 @@
"RsubIntModule_basic",
"RsubIntModule_noalpha_basic",
"RsubInt0d_NumToTensor_Module_basic",
"SelectScattertModule_basic",
"SelectScattertStaticModule_basic",
"SliceStaticModule_basic",
"SliceModule_basic",
"SliceNegIdxModule_basic",
Expand All @@ -342,6 +344,12 @@
"SliceStartEqEndModule_basic",
"SliceSizeTwoStepModule_basic",
"SliceWholeTensorModule_basic",
"SliceScatterModule_basic",
"SliceScatterNegativeDimModule_basic",
"SliceScatterNegativeEndModule_basic",
"SliceScatterStaticModule_basic",
"SliceScatterStepVariationModule_basic",
"SliceScatterZeroDimModule_basic",
"SqueezeDimModule_static",
"SqueezeDimModule_identity",
"SqueezeModule_broadcast",
Expand Down
3 changes: 3 additions & 0 deletions include/torch-mlir/Conversion/Utils/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,9 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef<int64_t> shape,
Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
std::optional<Type> srcOriginalDtype = std::nullopt);

Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
Value torchOptionalInt, Value builtinInt,
Value defaultValue, Value dimSize);
} // namespace Torch
} // namespace torch
} // namespace mlir
Expand Down
25 changes: 0 additions & 25 deletions lib/Conversion/TorchToLinalg/DataMovement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,31 +33,6 @@ using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

// Converts an optional torch int (e.g. the start/end of a slice) into a valid
// index into a dimension of runtime size `dimSize`:
//   - None              -> defaultValue (caller-chosen, already index-typed)
//   - other values      -> resolved relative to dimSize via
//                          toPositiveDimDynamic, then clamped to [0, dimSize]
// Returns an `index`-typed Value.
static Value toPositiveValidDim(ConversionPatternRewriter &rewriter,
                                Location loc, Value torchOptionalInt,
                                Value builtinInt, Value defaultValue,
                                Value dimSize) {
  if (torchOptionalInt.getType().isa<Torch::NoneType>())
    return defaultValue;
  // Clamping is done in i64 arithmetic; convert back to index at the end.
  auto dimSizeAsInt = castIndexToInt64(rewriter, loc, dimSize);
  Value positiveDim =
      toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeAsInt);
  // positiveDim < 0 ? 0 : positiveDim
  Value cst0 = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(dimSizeAsInt.getType()));
  Value predDimSltZero = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::slt, positiveDim, cst0);
  Value atLeastZero =
      rewriter.create<arith::SelectOp>(loc, predDimSltZero, cst0, positiveDim);
  // atLeastZero > dimSizeAsInt ? dimSizeAsInt : atLeastZero
  Value sgtDimSize = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::sgt, atLeastZero, dimSizeAsInt);
  Value boundedByDimSize = rewriter.create<arith::SelectOp>(
      loc, sgtDimSize, dimSizeAsInt, atLeastZero);

  return castIntToIndex(rewriter, loc, boundedByDimSize);
}

template <typename OpTy, typename OpAdaptor>
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter,
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToStablehlo/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo
TorchToStablehlo.cpp
StablehloLegalizeUtils.cpp
Basic.cpp
Gather.cpp
GatherScatter.cpp
Linear.cpp
ViewLike.cpp
Reduction.cpp
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,75 @@ Value gatherTensorAlongSingleAxis(PatternRewriter &rewriter, Operation *op,
sliceSizesTensor, dimsAttr)
.getResult();
}

// Computes the dynamic offset/size/stride parameters of a torch slicing op so
// the slice can be materialized with tensor.insert_slice/extract_slice.
//
// On success the out-params are filled as follows:
//   - resultShape: the input sizes, with resultShape[dim] replaced by the
//     number of elements the slice covers, i.e. ceildiv(end - start, step);
//   - offsets:     all zeros, except offsets[dim] = clamped start;
//   - strides:     all ones, except strides[dim] = step.
//
// Fails when `dim` or `step` is not a constant int, `dim` is statically out
// of range, or start/end are optional-typed values.
template <typename OpTy, typename OpAdaptor>
LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor,
                                           ConversionPatternRewriter &rewriter,
                                           SmallVector<Value> &resultShape,
                                           SmallVector<Value> &offsets,
                                           SmallVector<Value> &strides) {
  Location loc = op.getLoc();
  auto input = adaptor.getSelf();
  RankedTensorType inputType =
      input.getType().template cast<RankedTensorType>();

  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value one = rewriter.create<arith::ConstantIndexOp>(loc, 1);

  int64_t dim;
  if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim)))
    return op->emitError("unimplemented: dim is not constant");

  int64_t inputRank = inputType.getRank();
  dim = toPositiveDim(dim, inputRank);
  if (!isValidDim(dim, inputRank))
    return rewriter.notifyMatchFailure(op, "dim is statically invalid");

  SmallVector<Value> inputShape = getTensorSizes(rewriter, loc, input);
  Value dimSize = inputShape[dim];

  Value torchTypeStart = op.getStart();
  Value torchTypeEnd = op.getEnd();
  Value builtinTypeStart = adaptor.getStart();
  Value builtinTypeEnd = adaptor.getEnd();

  if (torchTypeStart.getType().isa<OptionalType>() ||
      torchTypeEnd.getType().isa<OptionalType>())
    return rewriter.notifyMatchFailure(op, "unimplemented optional type arg");

  // A None step defaults to 1; any other non-constant step is unsupported.
  int64_t step;
  if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) {
    if (!op.getStep().getType().template isa<Torch::NoneType>())
      return op->emitError("unimplemented: step is not constant");
    step = 1;
  }

  // Normalize start/end into [0, dimSize]; None start -> 0, None end ->
  // dimSize (i.e. slice to the end of the dimension).
  Value start = toPositiveValidDim(rewriter, loc, torchTypeStart,
                                   builtinTypeStart, zero, dimSize);
  Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd,
                                 dimSize, dimSize);

  // Guard against inverted ranges: end = max(end, start), so an empty slice
  // yields size 0 rather than a negative size.
  Value endSgeStart = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::sge, end, start);
  end = rewriter.create<arith::SelectOp>(loc, endSgeStart, end, start);
  Value stepIndex = rewriter.create<arith::ConstantIndexOp>(loc, step);

  // Slice logic: resultSize = ceildiv(end - start, step)
  //                         = floordiv(end - start + step - 1, step).
  // Reuse the sizes already materialized above instead of emitting a second
  // set of tensor.dim ops via getTensorSizes.
  resultShape = inputShape;
  Value len = rewriter.create<arith::SubIOp>(loc, end, start);
  Value resultSize = rewriter.create<arith::AddIOp>(loc, len, stepIndex);
  resultSize = rewriter.create<arith::SubIOp>(loc, resultSize, one);
  resultSize = rewriter.create<arith::FloorDivSIOp>(loc, resultSize, stepIndex);
  resultShape[dim] = resultSize;

  strides.resize(inputType.getRank(), one);
  offsets.resize(inputType.getRank(), zero);

  offsets[dim] = start;
  // strides[dim] was just initialized to the constant 1, so the stride along
  // `dim` is simply the step itself — no arith.muli needed.
  strides[dim] = stepIndex;
  return success();
}
} // namespace

// Ref:
Expand Down Expand Up @@ -258,9 +327,54 @@ LogicalResult ConvertAtenOp<AtenGatherOp>::matchAndRewrite(
return success();
}

void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
// AtenSliceScatterOp
//
// Lowers aten.slice_scatter by computing the slice's offsets/sizes/strides
// (via prepareArgumentsForSlicingOp) and inserting `src` into `self` with a
// tensor.insert_slice. Note this emits tensor-dialect ops rather than
// stablehlo ops; presumably a later pass handles them — verify in the
// pipeline.
template <>
LogicalResult ConvertAtenOp<AtenSliceScatterOp>::matchAndRewrite(
    AtenSliceScatterOp op, OpAdaptor adaptor,
    ConversionPatternRewriter &rewriter) const {

  // NOTE(review): reuses the linalg-path legality check inside the StableHLO
  // conversion; presumably the type requirements coincide — confirm.
  if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
    return failure();

  Location loc = op.getLoc();
  TypeConverter *typeConverter = getTypeConverter();

  auto input = adaptor.getSelf();

  RankedTensorType resultType =
      typeConverter->convertType(op->getResult(0).getType())
          .cast<RankedTensorType>();

  // Filled in by prepareArgumentsForSlicingOp: per-dimension sizes, start
  // offsets, and strides describing the slice being scattered into.
  SmallVector<Value> resultShape;
  SmallVector<Value> offsets;
  SmallVector<Value> strides;
  if (failed(prepareArgumentsForSlicingOp<AtenSliceScatterOp,
                                          AtenSliceScatterOpAdaptor>(
          op, adaptor, rewriter, resultShape, offsets, strides))) {
    return failure();
  }

  // Cast `src` to a fully dynamic-shaped tensor of the same rank/element
  // type, so insert_slice accepts it even when static sizes differ from the
  // dynamically computed slice sizes.
  Value src = adaptor.getSrc();
  auto srcType = src.getType().cast<RankedTensorType>();
  int64_t srcRank = srcType.getRank();
  SmallVector<int64_t> srcAbstractSizes(srcRank, kUnknownSize);
  auto abstractSrcType = RankedTensorType::get(
      makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType());
  Value abstractSrc =
      rewriter.create<tensor::CastOp>(loc, abstractSrcType, src);

  Value result = rewriter.create<tensor::InsertSliceOp>(
      loc, abstractSrc, input, offsets, resultShape, strides);

  // Cast back to the converted result type (restores static dims if any).
  rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, result);

  return success();
}

void mlir::torch::torch_to_stablehlo::
populateGatherScatterOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options) {
MLIRContext *context = patterns.getContext();

#define INSERT_ATENOP_PATTERN(AtenOp) \
Expand All @@ -269,5 +383,6 @@ void mlir::torch::torch_to_stablehlo::populateGatherOpPatternsAndLegality(
INSERT_ATENOP_PATTERN(AtenEmbeddingOp);
INSERT_ATENOP_PATTERN(AtenIndexSelectOp);
INSERT_ATENOP_PATTERN(AtenGatherOp);
INSERT_ATENOP_PATTERN(AtenSliceScatterOp);
#undef INSERT_ATENOP_PATTERN
}
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToStablehlo/PopulatePatterns.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ void populateBasicOpPatternsAndLegality(TypeConverter &typeConverter,
void populateViewLikeOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateGatherOpPatternsAndLegality(
void populateGatherScatterOpPatternsAndLegality(
TypeConverter &typeConverter, RewritePatternSet &patterns,
ConversionTarget &target, const TorchToStablehloOptions &options);
void populateReductionOpPatternsAndLegality(
Expand Down
2 changes: 1 addition & 1 deletion lib/Conversion/TorchToStablehlo/TorchToStablehlo.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ class ConvertTorchToStablehlo
typeConverter, patterns, target, options);
torch_to_stablehlo::populateViewLikeOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateGatherOpPatternsAndLegality(
torch_to_stablehlo::populateGatherScatterOpPatternsAndLegality(
typeConverter, patterns, target, options);
torch_to_stablehlo::populateReductionOpPatternsAndLegality(
typeConverter, patterns, target, options);
Expand Down
23 changes: 23 additions & 0 deletions lib/Conversion/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,29 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype,
llvm_unreachable("convertScalarToDtype should handle all the types");
}

// Resolves an optional torch int (e.g. a slice start/end) to a valid index
// into a dimension of runtime size `dimSize`: a None value yields
// `defaultValue`; anything else is resolved relative to the dimension via
// toPositiveDimDynamic and then clamped into [0, dimSize]. The clamped value
// is returned as an `index`-typed Value.
Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc,
                         Value torchOptionalInt, Value builtinInt,
                         Value defaultValue, Value dimSize) {
  // None means "use the caller-supplied default" — nothing to compute.
  if (torchOptionalInt.getType().isa<Torch::NoneType>())
    return defaultValue;

  // All clamping happens in i64; the index cast is reapplied at the end.
  Value dimSizeInt = castIndexToInt64(rewriter, loc, dimSize);
  Value resolved =
      toPositiveDimDynamic(rewriter, loc, builtinInt, dimSizeInt);

  // Lower clamp: max(resolved, 0).
  Value zeroConst = rewriter.create<arith::ConstantOp>(
      loc, rewriter.getZeroAttr(dimSizeInt.getType()));
  Value isNegative = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::slt, resolved, zeroConst);
  Value lowClamped =
      rewriter.create<arith::SelectOp>(loc, isNegative, zeroConst, resolved);

  // Upper clamp: min(lowClamped, dimSize).
  Value exceedsDim = rewriter.create<arith::CmpIOp>(
      loc, arith::CmpIPredicate::sgt, lowClamped, dimSizeInt);
  Value clamped = rewriter.create<arith::SelectOp>(loc, exceedsDim, dimSizeInt,
                                                   lowClamped);

  return castIntToIndex(rewriter, loc, clamped);
}
} // namespace Torch
} // namespace torch
} // namespace mlir
17 changes: 17 additions & 0 deletions python/torch_mlir_e2e_test/test_suite/slice_like.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,6 +307,23 @@ def forward(self, x, src):
def SliceScatterZeroDimModule_basic(module, tu: TestUtils):
module.forward(tu.rand(6, 8), tu.rand(1, 8))

class SliceScatterNegativeEndModule(torch.nn.Module):
    """E2E test: aten.slice_scatter with a negative `end` index.

    With x of shape (6, 8), end=-1 must resolve relative to the dimension
    size (to 5), so rows 3..4 of `x` are replaced by `src` (shape (2, 8)).
    """

    def __init__(self):
        super().__init__()

    @export
    @annotate_args([
        None,
        ([-1, -1], torch.float32, True),
        ([-1, -1], torch.float32, True),
    ])
    def forward(self, x, src):
        return torch.ops.aten.slice_scatter(x, src, dim = 0, start = 3, end = -1, step = 1)


@register_test_case(module_factory=lambda: SliceScatterNegativeEndModule())
def SliceScatterNegativeEndModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(6, 8), tu.rand(2, 8))

class SliceScatterNegativeDimModule(torch.nn.Module):

Expand Down

0 comments on commit 5758a0b

Please sign in to comment.