Commit 1e04260

[torch-dialect] emit aten.as_strided op, add folder and decomposition rule for it
Vremold committed Jul 5, 2023
1 parent 157e5e5 commit 1e04260
Showing 10 changed files with 283 additions and 26 deletions.
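For context, a minimal PyTorch sketch (not part of this commit) of the `aten::as_strided` semantics being modeled here: element `(i0, i1, ...)` of the result reads the input's flattened storage at `storage_offset + sum(i_d * stride[d])`.

```python
import torch

x = torch.arange(9, dtype=torch.float32).reshape(3, 3)

# result[i][j] = storage[storage_offset + i * stride[0] + j * stride[1]]
y = torch.as_strided(x, size=(2, 2), stride=(3, 1), storage_offset=0)
print(y)  # tensor([[0., 1.], [3., 4.]])

# With size == x.shape, row-major contiguous strides, and a zero offset,
# as_strided is an identity view -- the case the new folder recognizes.
z = torch.as_strided(x, size=(3, 3), stride=(3, 1), storage_offset=0)
assert torch.equal(z, x)
```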
2 changes: 2 additions & 0 deletions e2e_testing/xfail_sets.py
@@ -308,6 +308,7 @@
}

STABLEHLO_PASS_SET = {
"AsStridedStaticModule_basic",
"AliasModule_basic",
"AllBoolFalseModule_basic",
"AllBoolTrueModule_basic",
@@ -806,6 +807,7 @@
# Write the TOSA set as a "passing" set as it is very early in development
# and very few tests work yet.
TOSA_PASS_SET = {
"AsStridedStaticModule_basic",
"AliasModule_basic",
"MaxPool2dEmptyStrideStaticModule_basic",
"ConstantBoolParameterModule_basic",
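The `AsStridedStaticModule_basic` entries added above name an e2e test whose definition is not included in the shown hunks; the following is only a hypothetical sketch of such a test, using torch-mlir's usual e2e conventions (the imports, decorators, and the chosen sizes/strides are assumptions, not taken from this commit).

```python
import torch
from torch_mlir_e2e_test.framework import TestUtils
from torch_mlir_e2e_test.registry import register_test_case
from torch_mlir_e2e_test.annotations import annotate_args, export

class AsStridedStaticModule(torch.nn.Module):
    @export
    @annotate_args([None, ([5, 5], torch.float32, True)])
    def forward(self, x):
        # Static size/stride lists so the folder and decomposition see constants.
        return torch.ops.aten.as_strided(x, (2, 2), (1, 2))

@register_test_case(module_factory=lambda: AsStridedStaticModule())
def AsStridedStaticModule_basic(module, tu: TestUtils):
    module.forward(tu.rand(5, 5))
```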
26 changes: 26 additions & 0 deletions include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -8028,6 +8028,32 @@ def Torch_AtenViewOp : Torch_Op<"aten.view", [
let hasFolder = 1;
}

def Torch_AtenAsStridedOp : Torch_Op<"aten.as_strided", [
AllowsTypeRefinement,
ReadOnly
]> {
let summary = "Generated op for `aten::as_strided : (Tensor, int[], int[], int?) -> (Tensor)`";
let arguments = (ins
AnyTorchTensorType:$self,
AnyTorchListOfTorchIntType:$size,
AnyTorchListOfTorchIntType:$stride,
AnyTorchOptionalIntType:$storage_offset
);
let results = (outs
AnyTorchTensorType:$result
);
let hasCustomAssemblyFormat = 1;
let extraClassDefinition = [{
ParseResult AtenAsStridedOp::parse(OpAsmParser &parser, OperationState &result) {
return parseDefaultTorchOp(parser, result, 4, 1);
}
void AtenAsStridedOp::print(OpAsmPrinter &printer) {
printDefaultTorchOp(printer, *this, 4, 1);
}
}];
let hasFolder = 1;
}

def Torch_Aten_UnsafeViewOp : Torch_Op<"aten._unsafe_view", [
AllowsTypeRefinement,
HasValueSemantics,
100 changes: 76 additions & 24 deletions lib/Dialect/Torch/IR/TorchOps.cpp
@@ -165,8 +165,8 @@ static Value getScalarIntValue(Value input, Location loc,
//===----------------------------------------------------------------------===//

LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
auto func =
symbolTable.lookupNearestSymbolFrom<func::FuncOp>(*this, getFunctionAttr());
auto func = symbolTable.lookupNearestSymbolFrom<func::FuncOp>(
*this, getFunctionAttr());
if (!func)
return emitError() << "'@" << getFunction()
<< "' does not reference a valid function";
@@ -419,11 +419,13 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
// If the condition is constant, delete the dead branch and inline the live
// branch.
patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) {
auto constantBool = op.getCondition().getDefiningOp<Torch::ConstantBoolOp>();
auto constantBool =
op.getCondition().getDefiningOp<Torch::ConstantBoolOp>();
if (!constantBool)
return rewriter.notifyMatchFailure(op, "non-constant condition");
replaceOpWithRegion(
rewriter, op, constantBool.getValue() ? op.getThenRegion() : op.getElseRegion());
replaceOpWithRegion(rewriter, op,
constantBool.getValue() ? op.getThenRegion()
: op.getElseRegion());
return success();
});
// If the thenRegion and elseRegion yield the same Value's, then use those
@@ -481,14 +483,16 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
continue;
newResultTypes.push_back(op->getResult(i).getType());
}
auto newIf =
rewriter.create<PrimIfOp>(op->getLoc(), newResultTypes, op.getCondition());
auto newIf = rewriter.create<PrimIfOp>(op->getLoc(), newResultTypes,
op.getCondition());
rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(),
newIf.getThenRegion().end());
rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(),
newIf.getElseRegion().end());
newIf.getThenRegion().front().getTerminator()->eraseOperands(resultsToErase);
newIf.getElseRegion().front().getTerminator()->eraseOperands(resultsToErase);
newIf.getThenRegion().front().getTerminator()->eraseOperands(
resultsToErase);
newIf.getElseRegion().front().getTerminator()->eraseOperands(
resultsToErase);
SmallVector<Value> replacementValues;
for (int i = 0, e = op->getNumResults(), nextNewValue = 0; i < e; ++i) {
if (resultsToErase[i])
@@ -514,8 +518,8 @@ void RuntimeAssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
return failure();

if (value) {
rewriter.eraseOp(op);
return success();
rewriter.eraseOp(op);
return success();
}
// Even if the condition is statically false, the assert might never be
// executed.
@@ -872,10 +876,10 @@ void AtenToOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
auto rhs = op.getOther();
auto getRhsDevice = rewriter.create<PrimDeviceOp>(op.getLoc(), rhs);
auto getRhsDtype = rewriter.create<PrimDtypeOp>(op.getLoc(), rhs);
rewriter.replaceOpWithNewOp<AtenToDeviceOp>(
op, op.getType(), lhs, getRhsDevice.getResult(),
getRhsDtype.getResult(), op.getNonBlocking(),
op.getCopy(), op.getMemoryFormat());
rewriter.replaceOpWithNewOp<AtenToDeviceOp>(
op, op.getType(), lhs, getRhsDevice.getResult(),
getRhsDtype.getResult(), op.getNonBlocking(), op.getCopy(),
op.getMemoryFormat());
return success();
});
}
@@ -895,6 +899,51 @@ OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) {
return getOperand(0);
}

//===----------------------------------------------------------------------===//
// AtenAsStridedOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenAsStridedOp::fold(FoldAdaptor adaptor) {
auto inputType = getOperand(0).getType().dyn_cast<BaseTensorType>();
if (!inputType || !inputType.hasSizes())
return nullptr;

auto outType = getType().dyn_cast<BaseTensorType>();
if (!outType || !outType.hasSizes())
return nullptr;

int64_t storageOffset;
if (!getStorageOffset().getType().isa<Torch::NoneType>())
if (!matchPattern(getStorageOffset(), m_TorchConstantInt(&storageOffset)) ||
storageOffset != 0)
return nullptr;

// Check whether the input and output tensors have exactly the same shape.
ArrayRef<int64_t> inputSizes = inputType.getSizes();
ArrayRef<int64_t> outSizes = outType.getSizes();
if (inputSizes.size() != outSizes.size())
return nullptr;
for (int i = 0, e = inputSizes.size(); i < e; ++i) {
if (inputSizes[i] != outSizes[i])
return nullptr;
}

// Check whether the elements of the output tensor are fetched sequentially
// from the input tensor's storage.
SmallVector<int64_t> strides;
if (!matchPattern(getStride(), m_TorchListOfConstantInts(strides)))
return nullptr;

if (strides.size() != inputSizes.size() || strides[strides.size() - 1] != 1)
return nullptr;
for (int i = inputSizes.size() - 2; i >= 0; --i) {
if (strides[i] != inputSizes[i + 1] * strides[i + 1])
return nullptr;
}

return getOperand(0);
}
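A small Python sketch (an illustration, not code from this commit) of the condition the folder above checks: `as_strided` folds to its input only when the output shape equals the input shape, the strides are the row-major contiguous strides of that shape, and the storage offset is absent or zero.

```python
from typing import List, Optional

def as_strided_is_identity(in_sizes: List[int], out_sizes: List[int],
                           strides: List[int],
                           storage_offset: Optional[int]) -> bool:
    """Mirror of the folder's check: fold only when the view reads the whole
    storage contiguously in row-major order with no offset."""
    if storage_offset not in (None, 0):
        return False
    if in_sizes != out_sizes or len(strides) != len(in_sizes):
        return False
    # Innermost stride must be 1, and each outer stride must equal the
    # inner dimension size times the inner stride (row-major layout).
    if strides[-1] != 1:
        return False
    for d in range(len(in_sizes) - 2, -1, -1):
        if strides[d] != in_sizes[d + 1] * strides[d + 1]:
            return False
    return True

assert as_strided_is_identity([2, 3], [2, 3], [3, 1], 0)      # folds to the input
assert not as_strided_is_identity([2, 3], [2, 3], [1, 2], 0)  # transposed view, no fold
```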

//===----------------------------------------------------------------------===//
// PrimsViewOfOp
//===----------------------------------------------------------------------===//
@@ -1785,7 +1834,7 @@ void Torch::ConstantFloatOp::getAsmResultNames(
// float string representation).
SmallVector<char> buf;
getValue().toString(buf, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0,
/*TruncateZero=*/false);
/*TruncateZero=*/false);
auto isValidMLIRIdentifierChar = [](char c) {
return isalpha(c) || isdigit(c) || c == '_' || c == '$' || c == '.' ||
c == '-';
@@ -1896,7 +1945,8 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns(
// compiler treat the size as having value semantics?
// There's a small number of such ops, and they are marked as `inplace_view`
// in PyTorch's `native_functions.yaml` file.
rewriter.replaceOpWithNewOp<AtenSizeIntOp>(op, sizeOp.getSelf(), op.getIdx());
rewriter.replaceOpWithNewOp<AtenSizeIntOp>(op, sizeOp.getSelf(),
op.getIdx());
return success();
});
}
@@ -1924,11 +1974,13 @@ OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) {
void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) {
auto lhsListConstruct = op.getA().getDefiningOp<Torch::PrimListConstructOp>();
auto lhsListConstruct =
op.getA().getDefiningOp<Torch::PrimListConstructOp>();
if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct))
return failure();

auto rhsListConstruct = op.getB().getDefiningOp<Torch::PrimListConstructOp>();
auto rhsListConstruct =
op.getB().getDefiningOp<Torch::PrimListConstructOp>();
if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct))
return failure();

@@ -2046,7 +2098,8 @@ LogicalResult PrimTupleConstructOp::verify() {
void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) {
auto tupleConstruct = op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
auto tupleConstruct =
op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
if (!tupleConstruct)
return failure();

@@ -2096,7 +2149,8 @@ void PrimUninitializedOp::getCanonicalizationPatterns(
void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) {
auto tupleConstruct = op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
auto tupleConstruct =
op.getTup().getDefiningOp<Torch::PrimTupleConstructOp>();
if (!tupleConstruct)
return failure();

@@ -2242,9 +2296,7 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef<Attribute> operands,
// AtenAliasOp
//===----------------------------------------------------------------------===//

OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) {
return getOperand();
}
OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { return getOperand(); }

//===----------------------------------------------------------------------===//
// AtenFloordivIntOp
7 changes: 7 additions & 0 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6824,6 +6824,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" func.func @\"__torch_mlir_shape_fn.aten.resize_\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.as_strided\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.list<int> {\n"
" return %arg1 : !torch.list<int>\n"
" }\n"
" func.func @\"__torch_mlir_shape_fn.aten.max_pool2d\"(%arg0: !torch.list<int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.list<int>, %arg4: !torch.list<int>, %arg5: !torch.bool) -> !torch.list<int> {\n"
" %0 = call @__torch__.torch.jit._shape_functions.max_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool) -> !torch.list<int>\n"
" return %0 : !torch.list<int>\n"
@@ -8520,6 +8523,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.as_strided\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>, %arg3: !torch.optional<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
" }\n"
" func.func @\"__torch_mlir_dtype_fn.aten.roll\"(%arg0: !torch.tuple<int, int>, %arg1: !torch.list<int>, %arg2: !torch.list<int>) -> !torch.int {\n"
" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple<int, int> -> !torch.int, !torch.int\n"
" return %0#1 : !torch.int\n"
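The generated C++ strings in this file come from Python definitions in torch-mlir's abstract interpretation library, which are not part of the shown hunks. The following is only a sketch of what the corresponding Python source plausibly looks like; the file and naming conventions are assumptions inferred from the generated functions above. The shape function simply returns the requested `size`, and the dtype function returns the input's dtype.

```python
from typing import List, Optional, Tuple

# Assumed source form in abstract_interp_lib_gen.py (names inferred from the
# generated MLIR above, not taken from the shown hunks).
def aten〇as_strided〡shape(self: List[int], size: List[int], stride: List[int],
                           storage_offset: Optional[int] = None) -> List[int]:
    return size

def aten〇as_strided〡dtype(self_rank_dtype: Tuple[int, int], size: List[int],
                           stride: List[int],
                           storage_offset: Optional[int] = None) -> int:
    self_rank, self_dtype = self_rank_dtype
    return self_dtype
```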
115 changes: 115 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -4585,6 +4585,120 @@ class DecomposeAtenSignOp : public OpRewritePattern<AtenSignOp> {
};
} // namespace

namespace {
// Decompose `aten.as_strided` into `aten.flatten.using_ints`, `aten.gather` and
// `aten.view`.
// This decomposition is a little hacky: aten.as_strided is a view-like op,
// while aten.gather is not. That should still be acceptable in torch-mlir,
// since view-like ops are already assumed to have value semantics.
class DecomposeAtenAsStridedOp : public OpRewritePattern<AtenAsStridedOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenAsStridedOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
auto inputType = op.getSelf().getType().dyn_cast<BaseTensorType>();
if (!inputType || !inputType.hasSizes()) {
return rewriter.notifyMatchFailure(
op, "only handle input tensor with shape information");
}
auto outType = op.getType();
ArrayRef<int64_t> inputSizes = inputType.getSizes();

SmallVector<int64_t> outSizes;
if (!matchPattern(op.getSize(), m_TorchListOfConstantInts(outSizes))) {
return rewriter.notifyMatchFailure(
op, "out size must be a list of constant integers");
}
SmallVector<int64_t> strides;
if (!matchPattern(op.getStride(), m_TorchListOfConstantInts(strides))) {
return rewriter.notifyMatchFailure(
op, "strides must be a list of constant integers");
}
if (strides.size() != outSizes.size()) {
return rewriter.notifyMatchFailure(op,
"the stride size is expected to be "
"the same as that of output tensor");
}
int64_t storageOffset = 0;
if (!op.getStorageOffset().getType().isa<Torch::NoneType>() &&
!matchPattern(op.getStorageOffset(),
m_TorchConstantInt(&storageOffset))) {
return rewriter.notifyMatchFailure(
op, "storage offset must be a constant integer");
}

int64_t inputTotalSize = 1;
for (auto inputDimSize : inputSizes) {
inputTotalSize *= inputDimSize;
}

auto flattenInputType = inputType.getWithSizesAndDtype(
SmallVector<int64_t>(1, inputTotalSize), inputType.getOptionalDtype());
Value startDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value endDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(inputSizes.size() - 1));
Value flattenInput = rewriter.create<Torch::AtenFlattenUsingIntsOp>(
loc, flattenInputType, op.getSelf(), startDim, endDim);

int64_t outTotalSize = 1;
for (auto outDimSize : outSizes) {
outTotalSize *= outDimSize;
}

// Generate the indices for the gather op.
DenseSet<int64_t> visitedStoragePlaces;
SmallVector<int64_t> gatherIndices;
for (int64_t i = 0; i < outTotalSize; ++i) {
int64_t newI = i;
int64_t index = 0;
for (int64_t d = outSizes.size() - 1; d >= 0; --d) {
index += newI % outSizes[d] * strides[d];
newI /= outSizes[d];
}

// We cannot handle the situation where the view is "overlapped".
// Ref: https://pytorch.org/docs/stable/generated/torch.as_strided.html
if (visitedStoragePlaces.find(index) != visitedStoragePlaces.end()) {
return rewriter.notifyMatchFailure(
op, "multiple indices of new tensor are mapped to the same storage "
"location");
}

visitedStoragePlaces.insert(index);
gatherIndices.push_back(index + storageOffset);
}
auto gatherOpType = inputType.getWithSizesAndDtype(
SmallVector<int64_t>(1, outTotalSize), inputType.getOptionalDtype());
Value gatherDim = rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(0));
Value gatherIndex = rewriter.create<Torch::ValueTensorLiteralOp>(
loc, DenseIntElementsAttr::get(
RankedTensorType::get(static_cast<int64_t>(outTotalSize),
rewriter.getI64Type()),
gatherIndices));
Value sparseGrad = rewriter.create<Torch::ConstantBoolOp>(loc, false);
auto gatherOp = rewriter.create<Torch::AtenGatherOp>(
loc, gatherOpType, flattenInput, gatherDim, gatherIndex, sparseGrad);

SmallVector<Value> viewSizes;
for (auto outDimSize : outSizes) {
viewSizes.push_back(rewriter.create<Torch::ConstantIntOp>(
loc, rewriter.getI64IntegerAttr(outDimSize)));
}
auto viewSizesListConstructOp = rewriter.create<Torch::PrimListConstructOp>(
loc, Torch::ListType::get(rewriter.getType<Torch::IntType>()),
ValueRange(viewSizes));

rewriter.replaceOpWithNewOp<AtenViewOp>(op, op.getType(), gatherOp,
viewSizesListConstructOp);
return success();
}
};
} // namespace
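A Python sketch (an illustration, not part of this commit) of the computation the decomposition above performs: flatten the input, build a gather index `storage_offset + sum(i_d * stride[d])` for each output position, reject overlapping views, gather, then reshape to the requested size.

```python
import torch

def decompose_as_strided(self: torch.Tensor, size, stride, storage_offset=0):
    """Emulate as_strided via flatten + gather + view, mirroring the pattern above."""
    flat = self.flatten()                      # aten.flatten.using_ints over all dims
    out_numel = 1
    for s in size:
        out_numel *= s
    indices, seen = [], set()
    for i in range(out_numel):
        # Convert the linear output index into a multi-index, then into a
        # storage index using the requested strides.
        rem, idx = i, 0
        for d in range(len(size) - 1, -1, -1):
            idx += (rem % size[d]) * stride[d]
            rem //= size[d]
        if idx in seen:
            raise ValueError("overlapping as_strided views are not supported")
        seen.add(idx)
        indices.append(idx + (storage_offset or 0))
    gathered = torch.gather(flat, 0, torch.tensor(indices, dtype=torch.int64))
    return gathered.view(*size)                # aten.view back to the output shape

x = torch.arange(9, dtype=torch.float32)
y = decompose_as_strided(x, size=(2, 2), stride=(3, 1))
assert torch.equal(y, torch.as_strided(x, (2, 2), (3, 1)))
```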

namespace {
class DecomposeComplexOpsPass
: public DecomposeComplexOpsBase<DecomposeComplexOpsPass> {
@@ -4754,6 +4868,7 @@ class DecomposeComplexOpsPass
addPatternIfTargetOpIsIllegal<DecomposeAtenScalarTensor>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenScatterValueOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenSignOp>(patterns);
addPatternIfTargetOpIsIllegal<DecomposeAtenAsStridedOp>(patterns);

GreedyRewriteConfig config;
config.useTopDownTraversal = true;