Skip to content

Commit

Permalink
Fix failing test
Browse files Browse the repository at this point in the history
  • Loading branch information
vivekkhandelwal1 committed Dec 10, 2024
1 parent fd771f0 commit 99943ef
Showing 1 changed file with 9 additions and 0 deletions.
9 changes: 9 additions & 0 deletions lib/Conversion/TorchToLinalg/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1313,6 +1313,15 @@ class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv,
resultElementType);
}
if (is1DGroupConv) {
// Squeezing the last dim of the result of conv.
auto squeezeInputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1);
if (failed(squeezeInputInfo)) {
return rewriter.notifyMatchFailure(
op, "cannot generate unsqueeze tensor");
}
conv = squeezeInputInfo.value();
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, conv);
return success();
}
Expand Down

0 comments on commit 99943ef

Please sign in to comment.