diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 02f6761de189..dfb00aee209a 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -1313,6 +1313,15 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = torch_to_linalg::convertTensorToElementType(rewriter, loc, conv, resultElementType); } + if (is1DGroupConv) { + // Squeezing the last dim of the result of conv. + auto squeezeInputInfo = squeezeTensor(rewriter, op, conv, /*dim=*/-1); + if (failed(squeezeInputInfo)) { + return rewriter.notifyMatchFailure( + op, "cannot generate unsqueeze tensor"); + } + conv = squeezeInputInfo.value(); + } rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); }