Skip to content

Commit

Permalink
support dynamic dims for decomposeAtenEyeOp
Browse files Browse the repository at this point in the history
  • Loading branch information
Xida Ren committed Apr 22, 2024
1 parent ec9fdbc commit 12094d4
Showing 1 changed file with 25 additions and 22 deletions.
47 changes: 25 additions & 22 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include <cstdint>
#include <optional>
#include <set>

using namespace mlir;
Expand Down Expand Up @@ -1059,44 +1060,46 @@ class DecomposeAtenEyeMOp : public OpRewritePattern<AtenEyeMOp> {
LogicalResult matchAndRewrite(AtenEyeMOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
int64_t n;

if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
return rewriter.notifyMatchFailure(op,
"unimplemented: n must be constant");
int64_t m;
if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
return rewriter.notifyMatchFailure(op,
"unimplemented: m must be constant");
Value none = rewriter.create<ConstantNoneOp>(loc);
auto outType = op.getType().dyn_cast<BaseTensorType>();
if (!outType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
if (!outType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
if (n < 0) {
return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0");
}
if (m < 0) {
return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0");
}

Value none = rewriter.create<ConstantNoneOp>(loc);
auto context = op.getContext();
auto int64Dtype = getDtypeIntValueForType(
rewriter, loc,
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);

int64_t n;
Type rangeNType;
if (matchPattern(op.getN(), m_TorchConstantInt(&n))) {
rangeNType = outType.getWithSizesAndDtype(std::nullopt, si64Type);
} else {
if (n < 0)
return rewriter.notifyMatchFailure(op,
"n must be greater or equal to 0");
rangeNType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);
}
Value rangeN = rewriter.create<AtenArangeOp>(
- loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
+ loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
/*device=*/op.getDevice(), /*pin_memory=*/none);

auto arangeType1 =
outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
int64_t m;
Type rangeMType;
if (matchPattern(op.getM(), m_TorchConstantInt(&m))) {
rangeMType = outType.getWithSizesAndDtype(std::nullopt, si64Type);
} else {
if (m < 0)
return rewriter.notifyMatchFailure(op,
"m must be greater or equal to 0");
rangeMType = outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
}
Value rangeM = rewriter.create<AtenArangeOp>(
- loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
+ loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);

Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
Expand Down

0 comments on commit 12094d4

Please sign in to comment.