Skip to content

Commit

Permalink
[torch] aten.eye should use dynamic dims when no static dims are av…
Browse files Browse the repository at this point in the history
…ailable (#3202)

Co-authored-by: Xida Ren <[email protected]>
  • Loading branch information
renxida and Xida Ren authored Apr 30, 2024
1 parent 72349f7 commit 315dc6c
Showing 1 changed file with 23 additions and 24 deletions.
47 changes: 23 additions & 24 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1059,44 +1059,44 @@ class DecomposeAtenEyeMOp : public OpRewritePattern<AtenEyeMOp> {
LogicalResult matchAndRewrite(AtenEyeMOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
int64_t n;

if (!matchPattern(op.getN(), m_TorchConstantInt(&n)))
return rewriter.notifyMatchFailure(op,
"unimplemented: n must be constant");
int64_t m;
if (!matchPattern(op.getM(), m_TorchConstantInt(&m)))
return rewriter.notifyMatchFailure(op,
"unimplemented: m must be constant");
Value none = rewriter.create<ConstantNoneOp>(loc);
auto outType = dyn_cast<BaseTensorType>(op.getType());
auto outType = op.getType().dyn_cast<BaseTensorType>();
if (!outType)
return rewriter.notifyMatchFailure(
op, "Only tensor types input are currently supported");
if (!outType.hasDtype()) {
return rewriter.notifyMatchFailure(op, "result should have dtype");
}
if (n < 0) {
return rewriter.notifyMatchFailure(op, "n must be greater or equal to 0");
}
if (m < 0) {
return rewriter.notifyMatchFailure(op, "m must be greater or equal to 0");
}

Value none = rewriter.create<ConstantNoneOp>(loc);
auto context = op.getContext();
auto int64Dtype = getDtypeIntValueForType(
rewriter, loc,
rewriter.getIntegerType(/*width=*/64, /*isSigned=*/true));
auto si64Type = IntegerType::get(context, 64, IntegerType::Signed);
auto arangeType = outType.getWithSizesAndDtype(llvm::ArrayRef(n), si64Type);

int64_t n = kUnknownSize;
int64_t m = kUnknownSize;
// prioritize getting shape from output shape
if (outType.hasSizes() && outType.getSizes().size() == 2) {
n = outType.getSizes().front();
m = outType.getSizes().back();
}
// if output shape is not available, try to get shape from input
if (n == kUnknownSize)
matchPattern(op.getN(), m_TorchConstantInt(&n));
if (m == kUnknownSize)
matchPattern(op.getM(), m_TorchConstantInt(&m));

// prepare two unsqueezed ranges that are equal on and only on the diagonal
auto rangeNSize = llvm::SmallVector<int64_t, 1>({n});
Type rangeNType = outType.getWithSizesAndDtype(rangeNSize, si64Type);
Value rangeN = rewriter.create<AtenArangeOp>(
loc, arangeType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
loc, rangeNType, op.getN(), /*dtype=*/int64Dtype, /*layout=*/none,
/*device=*/op.getDevice(), /*pin_memory=*/none);

auto arangeType1 =
outType.getWithSizesAndDtype(llvm::ArrayRef(m), si64Type);
auto rangeMSize = llvm::SmallVector<int64_t, 1>({m});
Type rangeMType = outType.getWithSizesAndDtype(rangeMSize, si64Type);
Value rangeM = rewriter.create<AtenArangeOp>(
loc, arangeType1, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
loc, rangeMType, op.getM(), /*dtype=*/int64Dtype, /*layout=*/none,
/*device=*/none, /*pin_memory=*/none);

Value constMinusOne = rewriter.create<Torch::ConstantIntOp>(
Expand All @@ -1109,7 +1109,6 @@ class DecomposeAtenEyeMOp : public OpRewritePattern<AtenEyeMOp> {
}
Value unsqzRangeN = *unsqzTensorInfo;

// compare unsqueezed input with boundaries
auto eqType = ValueTensorType::get(
context, cast<BaseTensorType>(op.getType()).getSizes(),
IntegerType::get(context, 1));
Expand Down

0 comments on commit 315dc6c

Please sign in to comment.