diff --git a/lib/Conversion/TorchToMhlo/Linear.cpp b/lib/Conversion/TorchToMhlo/Linear.cpp index 7c98a08504781..e53d4a5de48aa 100644 --- a/lib/Conversion/TorchToMhlo/Linear.cpp +++ b/lib/Conversion/TorchToMhlo/Linear.cpp @@ -71,8 +71,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input, return result.getResult(); } -void castContractingDim(PatternRewriter &rewriter, Operation *op, Value &lhs, - Value &rhs, int64_t lhsContractingDim, +RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op, Value &lhs, + Value &rhs, int64_t lhsResultDim, int64_t rhsResultDim, int64_t lhsContractingDim, int64_t rhsContractingDim) { auto lhsTy = lhs.getType().dyn_cast(); auto rhsTy = rhs.getType().dyn_cast(); @@ -98,6 +98,16 @@ void castContractingDim(PatternRewriter &rewriter, Operation *op, Value &lhs, rhs = rewriter.create(op->getLoc(), newRankTy, rhs); } } + SmallVector outShape; + outShape.append(lhsShape.begin(), lhsShape.end()); + for (int k = 0; k < lhsShape.size(); ++ k) { + if (outShape[k] == ShapedType::kDynamicSize && rhsShape[k] > 0) { + outShape[k] = rhsShape[k]; + } + } + outShape[lhsResultDim] = lhsShape[lhsResultDim]; + outShape[rhsResultDim] = rhsShape[rhsResultDim]; + return RankedTensorType::get(outShape, lhsTy.getElementType()); } void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs, @@ -212,6 +222,9 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { options.dimSizeIndexBits); } auto batchDims = llvm::to_vector<4>(llvm::seq(0, nBatchDims)); + + auto lhsResultDim = nBatchDims; + auto rhsResultDim = nBatchDims + 1; auto lhsContractingDim = nBatchDims + 1; auto rhsContractingDim = nBatchDims; if (lhsRank == 1) @@ -224,17 +237,18 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp { /*rhsBatchingDimensions=*/batchDims, /*lhsContractingDimensions=*/{lhsContractingDim}, /*rhsContractingDimensions=*/{rhsContractingDim}); - auto resultTy = ConvertAtenOp::getTypeConverter() - ->convertType(op.getType()) - .template cast(); - - castContractingDim(rewriter, op, lhs, rhs, lhsContractingDim, + auto outTy = castContractingDim(rewriter, op, lhs, rhs, + lhsResultDim, rhsResultDim, lhsContractingDim, rhsContractingDim); output = rewriter - .create(op->getLoc(), resultTy, lhs, rhs, + .create(op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr) .getResult(); - + auto resultTy = ConvertAtenOp::getTypeConverter() + ->convertType(op.getType()) + .template cast(); + output = rewriter + .create(op->getLoc(), resultTy, output).getResult(); return success(); } @@ -386,10 +400,14 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank()); auto nBatchDims = resultRank - 2; auto batchDims = llvm::to_vector<4>(llvm::seq(0, nBatchDims)); + + auto lhsResultDim = nBatchDims; + auto rhsResultDim = nBatchDims + 1; auto lhsContractingDim = nBatchDims + 1; auto rhsContractingDim = nBatchDims; - castContractingDim(rewriter, op, lhs, rhs, lhsContractingDim, + auto outTy = castContractingDim(rewriter, op, lhs, rhs, + lhsResultDim, rhsResultDim, lhsContractingDim, rhsContractingDim); mhlo::DotDimensionNumbersAttr dotDimensionNumbers = mhlo::DotDimensionNumbersAttr::get( @@ -398,24 +416,22 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp { /*rhsBatchingDimensions=*/batchDims, /*lhsContractingDimensions=*/{lhsContractingDim}, /*rhsContractingDimensions=*/{rhsContractingDim}); - - auto resultTy = - ConvertAtenOp::getTypeConverter()->convertType(op.getType()); - Value matmulOutput = rewriter.create( - op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr); + op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr); Value matmulPlusBias = matmulOutput; if (!biasTy.template isa()) { // Bias addition broadcasts to the matmul output shape. matmulPlusBias = rewriter - .create(op->getLoc(), resultTy, + .create(op->getLoc(), outTy, matmulOutput, bias, nullptr) .getResult(); } - rewriter.replaceOpWithNewOp(op, resultTy, matmulPlusBias); + auto resultTy = + ConvertAtenOp::getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, resultTy, matmulPlusBias); return success(); } };