Skip to content

Commit

Permalink
Squashed commit of the following:
Browse files Browse the repository at this point in the history
commit bafe339
Author: Ramiro Leal-Cavazos <[email protected]>
Date:   Mon May 8 21:26:56 2023 +0000

    Add dtype functions for aten.atan and prims.squeeze

commit bebf695
Author: Ramiro Leal-Cavazos <[email protected]>
Date:   Mon May 8 21:26:10 2023 +0000

    Remove duplicate code from merge with main

commit 0d11895
Author: Ramiro Leal-Cavazos <[email protected]>
Date:   Fri May 5 21:39:02 2023 +0000

    Update LLVM tag

commit 73d5c07
Merge: 899d8bc eaaaeb6
Author: Ramiro Leal-Cavazos <[email protected]>
Date:   Fri May 5 21:30:09 2023 +0000

    Merge remote-tracking branch 'upstream/main' into merge-main

commit 899d8bc
Author: Ramiro Leal-Cavazos <[email protected]>
Date:   Mon Mar 13 21:39:14 2023 +0000

    Add dtype functions for `aten.ge.Tensor` and `aten.le.Tensor`

commit f58f9c2
Merge: ce7abf4 4912c39
Author: Ramiro Leal-Cavazos <[email protected]>
Date:   Mon Mar 13 21:32:00 2023 +0000

    Merge branch 'main' into merge-main

commit ce7abf4
Author: Jiahao Li <[email protected]>
Date:   Wed Feb 22 06:54:41 2023 +0800

    Add dtype functions for ops that take dtype from 2nd operand (llvm#1891)

commit 63945a2
Author: Ramiro Leal-Cavazos <[email protected]>
Date:   Mon Feb 13 17:56:09 2023 -0800

    Change dtype functions interface to take ints tuple for each tensor (llvm#1865)

    The original design for the dtype functions outlined in
    llvm#1462 was unable to properly
    handle ops that take optional tensors as an input when the optional
    tensor has a value of None. By the time the op gets imported into
    torch-mlir, if an optional value is None, all information about the
    original type is lost from the op type signature, preventing
    torch-mlir from knowing if a value of None was from an optional tensor
    or not, which was crucial in the original design since each tensor
    argument must be turned into two separate arguments for the dtype
    function.

    This commit changes the interface to dtype functions such that each
    tensor turns into a tuple of two ints, the first representing the rank
    of the tensor and the second the dtype of the tensor. Since now there
    is a one-to-one correspondence between the operands of an op and the
    operands of its dtype function, there is no ambiguity about which
    operand of the op corresponds with which operand of the dtype
    function.

    To test the implementation, this commit defines dtype functions for
    the convolution ops, all of which take one optional tensor as an
    argument.

commit 981ac88
Author: Ramiro Leal-Cavazos <[email protected]>
Date:   Wed Feb 1 22:30:27 2023 +0000

    Add dtype functions for two tensor promotion ops (llvm#1831)

    This commit adds dtype functions for ops in RefineTypes under the
    category of "Promote the two dtypes". The only ops not added here are
    convolution ops, since they take an optional tensor argument, and the
    dtype pipeline currently does not correctly handle that case. I will
    add a follow up patch fixing this.

    This commit also adds two helper functions that perform a very
    thorough testing of dtype functions. The helper function
    `_check_two_tensor_op` is able to independently test invalid input
    dtypes and invalid output dtypes.

    Lastly, this commit also XFAILs "MobilenetV3Module_basic".

commit 83d4e89
Author: Jiahao Li <[email protected]>
Date:   Sat Jan 21 02:39:41 2023 +0800

    Add dtype functions for floating point ops (llvm#1813)

commit 8cae5ba
Author: Ramiro Leal-Cavazos <[email protected]>
Date:   Mon Jan 16 14:32:23 2023 -0800

    Add dtype functions for comparison ops (llvm#1806)

    This commit adds dtype functions for comparison ops that always return
    a tensor of dtype `i1`.

commit 5b77c15
Author: Ramiro Leal-Cavazos <[email protected]>
Date:   Mon Jan 16 20:27:49 2023 +0000

    Add CI to `dtype-functions-staging` branch

commit ac94ba2
Author: Ramiro Leal-Cavazos <[email protected]>
Date:   Thu Jan 12 22:41:04 2023 +0000

    Move dtype functions into their own section in lib gen file

    In order to easily keep track of the dtype functions that have been
    moved to `abstract_interp_lib_gen.py` and make it easier to add new
    ones, this commit groups all the dtype functions together, rather than
    having them interspersed between the shape functions.
  • Loading branch information
ramiro050 committed May 9, 2023
1 parent 51e0a2c commit 89484b2
Show file tree
Hide file tree
Showing 14 changed files with 4,991 additions and 1,781 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/buildAndTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Build and Test

on:
pull_request:
branches: [ main ]
branches: [ main, dtype-functions-staging ]
push:
branches: [ main ]
workflow_dispatch:
Expand Down
2,713 changes: 2,574 additions & 139 deletions lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp

Large diffs are not rendered by default.

701 changes: 1 addition & 700 deletions lib/Dialect/Torch/Transforms/RefineTypes.cpp

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -176,20 +176,22 @@ FailureOr<Value> Torch::adjustFunctionArg(
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}

// !torch.union<int, float> is the type used for `Scalar` inputs. At
// compile time, such inputs will usually be resolved to an `int` or a `float`
// so we need to derefine to match the library function signature.
// !torch.union<int, float> or !torch.union<int, float, none> is the type used
// for (optional) `Scalar` inputs. At compile time, such inputs will usually
// be resolved to an `int` or a `float` so we need to derefine to match the
// library function signature.
if (auto unionType = desiredType.dyn_cast<Torch::UnionType>()) {
if (llvm::all_of(unionType.getContainedTypes(), [](Type containedType) {
return containedType.isa<Torch::IntType, Torch::FloatType>();
return containedType
.isa<Torch::IntType, Torch::FloatType, Torch::NoneType>();
}))
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}

// If the operand is NoneType, then we just need to derefine it to the
// optional type in the function signature.
if (operandType.isa<Torch::NoneType>()) {
assert(desiredType.isa<Torch::OptionalType>() &&
assert(!desiredType.isa<Torch::NoneType>() &&
"Don't expect library functions to have NoneType parameters");
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,248 @@
//===----------------------------------------------------------------------===//

#include "SimplifyAbstractInterpCalculationsUtils.h"
#include "mlir/IR/IRMapping.h"
#include "torch-mlir/Dialect/Torch/Utils/Utils.h"

using namespace mlir;
using namespace mlir::torch;
using namespace mlir::torch::Torch;

namespace {
class FoldPrimUncheckedCastOp : public OpRewritePattern<PrimUncheckedCastOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimUncheckedCastOp op,
PatternRewriter &rewriter) const override {
if (!isValidSubtype(op.getX().getType(), op.getResult().getType())) {
return rewriter.notifyMatchFailure(
op, "input tensor type is not a valid subtype of result type");
}
rewriter.replaceOp(op, op.getX());
return success();
}
};
} // namespace

namespace {
// TODO: Only unroll inside the shape calculation region.
// Maybe do this by only applying patterns and folding greedily on the ops
// inside the region + the shape.calculate op itself?
class FullyUnrollPrimLoopOp : public OpRewritePattern<PrimLoopOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimLoopOp op,
PatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
if (!op.isForLike())
return rewriter.notifyMatchFailure(op, "Loop is not for-like");
int64_t maxTripCount;
if (!matchPattern(op.getMaxTripCount(), m_TorchConstantInt(&maxTripCount)))
return rewriter.notifyMatchFailure(
op, "Expected `maxTripCount` to be a constant int");
;
SmallVector<Value> indices;
for (int64_t i = 0; i < maxTripCount; i++) {
// TODO: Add convenience builder.
indices.push_back(rewriter.create<ConstantIntOp>(
loc, rewriter.getIntegerAttr(IntegerType::get(context, 64), i)));
}
Block *beforeBlock = op->getBlock();
Block *afterBlock = rewriter.splitBlock(op->getBlock(), op->getIterator());

SmallVector<Block *> blocksToMerge;
IRMapping bvm;
// TODO: Helper for region().front()
auto condition =
cast<PrimLoopConditionOp>(op.getRegion().front().getTerminator());
for (int64_t i = 0; i < maxTripCount; i++) {
SmallVector<Value> iterArgs;
if (i == 0) {
llvm::append_range(iterArgs, op.getIterArgsInit());
} else {
llvm::append_range(
iterArgs, llvm::map_range(condition.getIterArgs(),
[&](Value v) { return bvm.lookup(v); }));
}
bvm.clear();
bvm.map(op.getRegion().front().getArgument(0), indices[i]);
bvm.map(op.getRegion().front().getArguments().slice(1), iterArgs);

op.getRegion().cloneInto(afterBlock->getParent(),
afterBlock->getIterator(), bvm);
Block *clonedBlock = bvm.lookup(&op.getRegion().front());
rewriter.eraseOp(clonedBlock->getTerminator());
blocksToMerge.push_back(clonedBlock);
}

blocksToMerge.push_back(afterBlock);
for (Block *block : blocksToMerge)
rewriter.mergeBlocks(block, beforeBlock);
if (maxTripCount == 0) {
rewriter.replaceOp(op, op.getIterArgsInit());
} else {
rewriter.replaceOp(op, llvm::to_vector<6>(llvm::map_range(
condition.getIterArgs(),
[&](Value v) { return bvm.lookup(v); })));
}
return success();
}
};
} // namespace

namespace {
class AbstractlyInterpretListOpsWithinABlock
: public OpRewritePattern<PrimListConstructOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(PrimListConstructOp op,
PatternRewriter &rewriter) const override {
Block *block = op->getBlock();
auto allUsers = llvm::to_vector<6>(op->getUsers());

// Sort the users into program order.
auto getParentInBlock = [&](Operation *op) {
while (op->getBlock() != block)
op = op->getParentOp();
return op;
};
// Use a stable sort for deterministic results when users are nested in two
// regions of the same parent op.
llvm::stable_sort(allUsers, [&](Operation *lhs, Operation *rhs) {
return getParentInBlock(lhs)->isBeforeInBlock(getParentInBlock(rhs));
});

// We cannot interpret all ops. So first do a check to see up until which
// point we can interpret.
int numUsersToInterpret = 0;
for (int i = 0, e = allUsers.size(); i != e; i++, numUsersToInterpret++) {
Operation *user = allUsers[i];
// If a user potentially mutates the list, then we require it to be in the
// same block for our simple abstract interpretation to work (we can't,
// for example, handle an "append" operation in a loop or other region).
// However, if the op is read-only, then from the purpose of our abstract
// interpretation, we can handle it effectively as though it was at the
// same position as the corresponding parent op in the block under
// consideration.
if (potentiallyMutatesListOperands(user)) {
if (user->getBlock() != block)
break;
}
}

// Truncate the list of users to the number of users we're going to
// interpret.
allUsers.resize(numUsersToInterpret);
auto usersToInterpret = ArrayRef(allUsers).take_front(numUsersToInterpret);

// For each mutating op (which must be in the same block), we save the
// current state of the list as a vector of Value's. These will then
// be converted to PrimListConstructOp's at the correct program points.
SmallVector<SmallVector<Value>> listLiterals;
SmallVector<Value> runningList;
llvm::append_range(runningList, op->getOperands());
bool generatedNewLiteral = false;
for (Operation *user : usersToInterpret) {
if (auto append = dyn_cast<AtenAppendTOp>(user)) {
if (!append.use_empty())
return rewriter.notifyMatchFailure(
op, "Expected `AtenAppendTOp` to not have users");
if (append.getSelf() == op) {
runningList.push_back(append.getEl());
generatedNewLiteral = true;
}
listLiterals.push_back(runningList);
continue;
}
if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
if (!insert.use_empty())
return rewriter.notifyMatchFailure(
op, "Expected `AtenInsertTOp` to not have users");
int64_t index;
if (!matchPattern(insert.getIdx(), m_TorchConstantInt(&index)))
return rewriter.notifyMatchFailure(
op, "Expected `idx` of `AtenInsertTOp` to be a constant int");
// The index might be statically out of bounds.
if (index < 0 || index > static_cast<int64_t>(runningList.size()))
return rewriter.notifyMatchFailure(
op, "Index in `AtenInsertTOp` is out of bounds");
if (insert.getSelf() == op) {
runningList.insert(runningList.begin() + index, insert.getEl());
generatedNewLiteral = true;
}
listLiterals.push_back(runningList);
continue;
}
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
if (!setItem.use_empty())
return rewriter.notifyMatchFailure(
op, "Expected `Aten_SetItemTOp` to not have users");
std::optional<int64_t> indexOpt = matchLegalConstantIndexIntoListOfSize(
setItem.getIdx(), runningList.size());
// The index might be statically out of bounds.
if (!indexOpt)
return rewriter.notifyMatchFailure(
op, "Index in `Aten_SetItemTOp` is out of bounds");
if (setItem.getL() == op) {
runningList[*indexOpt] = setItem.getEl();
generatedNewLiteral = true;
}
listLiterals.push_back(runningList);
continue;
}
// If this user potentially mutates the list and isn't handled above, then
// we can't abstractly interpret any further.
if (potentiallyMutatesListOperands(user))
break;
}

if (!generatedNewLiteral)
return rewriter.notifyMatchFailure(op, "No new literal created");

// Rewrite all users to use the appropriate list literals.
Value latestLiteral = rewriter.create<PrimListConstructOp>(
op->getLoc(), op.getType(), op->getOperands());
int nextLiteral = 0;
for (Operation *user : usersToInterpret) {
if (auto append = dyn_cast<AtenAppendTOp>(user)) {
rewriter.setInsertionPoint(append);
latestLiteral = rewriter.create<PrimListConstructOp>(
append->getLoc(), op.getType(), listLiterals[nextLiteral++]);
if (append.getSelf() == op)
rewriter.eraseOp(append);
continue;
}
if (auto insert = dyn_cast<AtenInsertTOp>(user)) {
rewriter.setInsertionPoint(insert);
latestLiteral = rewriter.create<PrimListConstructOp>(
insert->getLoc(), op.getType(), listLiterals[nextLiteral++]);
if (insert.getSelf() == op)
rewriter.eraseOp(insert);
continue;
}
if (auto setItem = dyn_cast<Aten_SetItemTOp>(user)) {
rewriter.setInsertionPoint(setItem);
latestLiteral = rewriter.create<PrimListConstructOp>(
setItem->getLoc(), op.getType(), listLiterals[nextLiteral++]);
if (setItem.getL() == op)
rewriter.eraseOp(setItem);
continue;
}
for (OpOperand &opOperand : user->getOpOperands()) {
if (opOperand.get() == op.getResult()) {
opOperand.set(latestLiteral);
}
}
}

// Any remaining uses should use the updated value of the latest literal.
rewriter.replaceOp(op, latestLiteral);
return success();
}
};
} // namespace

LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,
int resultNum,
Type newResultType,
Expand Down Expand Up @@ -97,3 +334,18 @@ LogicalResult Torch::updateCalculateOpResultTypes(Operation *calculateOp,

return success();
}

void mlir::torch::Torch::populateFoldPrimUncheckedCastOpPattern(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.insert<FoldPrimUncheckedCastOp>(context);
}

void mlir::torch::Torch::populateFullyUnrollPrimLoopOpPattern(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.insert<FullyUnrollPrimLoopOp>(context);
}

void mlir::torch::Torch::populateAbstractlyInterpretListOpsWithinABlockPattern(
RewritePatternSet &patterns, MLIRContext *context) {
patterns.insert<AbstractlyInterpretListOpsWithinABlock>(context);
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ LogicalResult updateCalculateOpResultTypes(Operation *calculateOp,
int resultNum, Type newResultType,
PatternRewriter &rewriter);

void populateFoldPrimUncheckedCastOpPattern(RewritePatternSet &patterns,
MLIRContext *context);
void populateFullyUnrollPrimLoopOpPattern(RewritePatternSet &patterns,
MLIRContext *context);
void populateAbstractlyInterpretListOpsWithinABlockPattern(
RewritePatternSet &patterns, MLIRContext *context);

} // namespace Torch
} // namespace torch
} // namespace mlir
Expand Down
7 changes: 7 additions & 0 deletions lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,17 @@ class SimplifyDtypeCalculationsPass
MLIRContext *context = &getContext();

RewritePatternSet patterns(context);
populateFullyUnrollPrimLoopOpPattern(patterns, context);
populateAbstractlyInterpretListOpsWithinABlockPattern(patterns, context);
populateFoldPrimUncheckedCastOpPattern(patterns, context);
patterns.insert<RefineDtypeCalculateOp>(context);
patterns.insert<DecomposePromoteDtypesOp>(context);
patterns.insert<RefineNumToTensorScalarOpType>(context);

PrimIfOp::getCanonicalizationPatterns(patterns, context);
Aten__Getitem__TOp::getCanonicalizationPatterns(patterns, context);
PrimTupleUnpackOp::getCanonicalizationPatterns(patterns, context);

// TODO: Debug visitation order to make this more efficient.
// A single linear scan should suffice.
GreedyRewriteConfig config;
Expand Down
Loading

0 comments on commit 89484b2

Please sign in to comment.