This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit

Fix dot product result type setting
TanyoKwok committed Sep 19, 2022
1 parent c1b556a commit bfc5d4e
Showing 1 changed file with 33 additions and 17 deletions.
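
In short: the mhlo.dot_general result type is now derived from the refined operand types via castContractingDim, and only afterwards reconciled with the type produced by the type converter. A minimal sketch of the revised flow, following the names in the diff below (loc and the surrounding conversion-pattern boilerplate are assumed to be in scope):

    // Sketch only, not part of the commit: mirrors the pattern introduced below.
    // loc, op, rewriter, lhs, rhs, the *ResultDim / *ContractingDim indices and
    // dotDimensionNumbers are assumed to be in scope.
    RankedTensorType outTy =
        castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
                           lhsContractingDim, rhsContractingDim);
    Value dot = rewriter.create<mhlo::DotGeneralOp>(
        loc, outTy, lhs, rhs, dotDimensionNumbers, nullptr);
    auto resultTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
                        ->convertType(op.getType())
                        .template cast<RankedTensorType>();
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, dot);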
lib/Conversion/TorchToMhlo/Linear.cpp: 50 changes (33 additions, 17 deletions)
@@ -71,8 +71,8 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
   return result.getResult();
 }
 
-void castContractingDim(PatternRewriter &rewriter, Operation *op, Value &lhs,
-                        Value &rhs, int64_t lhsContractingDim,
+RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op, Value &lhs,
+                                    Value &rhs, int64_t lhsResultDim, int64_t rhsResultDim, int64_t lhsContractingDim,
                         int64_t rhsContractingDim) {
   auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
   auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
@@ -98,6 +98,16 @@ void castContractingDim(PatternRewriter &rewriter, Operation *op, Value &lhs,
       rhs = rewriter.create<tensor::CastOp>(op->getLoc(), newRankTy, rhs);
     }
   }
+  SmallVector<int64_t> outShape;
+  outShape.append(lhsShape.begin(), lhsShape.end());
+  for (int k = 0; k < lhsShape.size(); ++ k) {
+    if (outShape[k] == ShapedType::kDynamicSize && rhsShape[k] > 0) {
+      outShape[k] = rhsShape[k];
+    }
+  }
+  outShape[lhsResultDim] = lhsShape[lhsResultDim];
+  outShape[rhsResultDim] = rhsShape[rhsResultDim];
+  return RankedTensorType::get(outShape, lhsTy.getElementType());
 }
 
 void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
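
To make the shape merge above concrete, here is a small self-contained sketch in plain C++ (independent of MLIR; kDynamic is a stand-in for ShapedType::kDynamicSize, and mergeDotShape is a hypothetical helper mirroring the logic added above). For a batched matmul of lhs [?, 4, 256, 64] with rhs [2, 4, 64, ?], nBatchDims is 2, so lhsResultDim = 2 and rhsResultDim = 3:

    #include <cstdint>
    #include <iostream>
    #include <vector>

    // Stand-in for ShapedType::kDynamicSize (illustrative only).
    constexpr int64_t kDynamic = -1;

    // Mirrors the merge added to castContractingDim: start from the lhs shape,
    // fill dynamic dims from the rhs where it is static, then pin the two
    // result dims to the corresponding operand dims.
    std::vector<int64_t> mergeDotShape(const std::vector<int64_t> &lhs,
                                       const std::vector<int64_t> &rhs,
                                       size_t lhsResultDim, size_t rhsResultDim) {
      std::vector<int64_t> out(lhs);
      for (size_t k = 0; k < out.size(); ++k)
        if (out[k] == kDynamic && rhs[k] > 0)
          out[k] = rhs[k];
      out[lhsResultDim] = lhs[lhsResultDim];
      out[rhsResultDim] = rhs[rhsResultDim];
      return out;
    }

    int main() {
      std::vector<int64_t> lhs = {kDynamic, 4, 256, 64}; // [?, 4, 256, 64]
      std::vector<int64_t> rhs = {2, 4, 64, kDynamic};   // [2, 4, 64, ?]
      for (int64_t d : mergeDotShape(lhs, rhs, /*lhsResultDim=*/2,
                                     /*rhsResultDim=*/3))
        std::cout << d << ' ';
      std::cout << '\n'; // prints: 2 4 256 -1
    }

Static batch extents are taken from whichever operand provides them, and the two result dimensions always come from the lhs and rhs respectively; the final two assignments also undo any spurious fill of a result dim from the other operand (relevant when there are no batch dims, since the result and contracting dims then share indices across the two operands).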
@@ -212,6 +222,9 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp<AtenOpT> {
                                      options.dimSizeIndexBits);
     }
     auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
+
+    auto lhsResultDim = nBatchDims;
+    auto rhsResultDim = nBatchDims + 1;
     auto lhsContractingDim = nBatchDims + 1;
     auto rhsContractingDim = nBatchDims;
     if (lhsRank == 1)
@@ -224,17 +237,18 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp<AtenOpT> {
             /*rhsBatchingDimensions=*/batchDims,
             /*lhsContractingDimensions=*/{lhsContractingDim},
             /*rhsContractingDimensions=*/{rhsContractingDim});
-    auto resultTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
-                        ->convertType(op.getType())
-                        .template cast<RankedTensorType>();
-
-    castContractingDim(rewriter, op, lhs, rhs, lhsContractingDim,
+    auto outTy = castContractingDim(rewriter, op, lhs, rhs,
+                                    lhsResultDim, rhsResultDim, lhsContractingDim,
                        rhsContractingDim);
     output = rewriter
-                 .create<mhlo::DotGeneralOp>(op->getLoc(), resultTy, lhs, rhs,
+                 .create<mhlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
                                              dotDimensionNumbers, nullptr)
                  .getResult();
-
+    auto resultTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
+                        ->convertType(op.getType())
+                        .template cast<RankedTensorType>();
+    output = rewriter
+                 .create<tensor::CastOp>(op->getLoc(), resultTy, output).getResult();
     return success();
   }
 
@@ -386,10 +400,14 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
     auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
     auto nBatchDims = resultRank - 2;
     auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
+
+    auto lhsResultDim = nBatchDims;
+    auto rhsResultDim = nBatchDims + 1;
     auto lhsContractingDim = nBatchDims + 1;
     auto rhsContractingDim = nBatchDims;
 
-    castContractingDim(rewriter, op, lhs, rhs, lhsContractingDim,
+    auto outTy = castContractingDim(rewriter, op, lhs, rhs,
+                                    lhsResultDim, rhsResultDim, lhsContractingDim,
                        rhsContractingDim);
     mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
         mhlo::DotDimensionNumbersAttr::get(
@@ -398,24 +416,22 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
             /*rhsBatchingDimensions=*/batchDims,
             /*lhsContractingDimensions=*/{lhsContractingDim},
             /*rhsContractingDimensions=*/{rhsContractingDim});
-
-    auto resultTy =
-        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
-
     Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
-        op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr);
+        op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
 
     Value matmulPlusBias = matmulOutput;
     if (!biasTy.template isa<Torch::NoneType>()) {
       // Bias addition broadcasts to the matmul output shape.
       matmulPlusBias =
           rewriter
-              .create<chlo::BroadcastAddOp>(op->getLoc(), resultTy,
+              .create<chlo::BroadcastAddOp>(op->getLoc(), outTy,
                                             matmulOutput, bias, nullptr)
               .getResult();
     }
 
-    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, resultTy, matmulPlusBias);
+    auto resultTy =
+        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, matmulPlusBias);
     return success();
   }
 };
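
One further note on the tail of the linear lowering above: because the dot and the bias add are now built with outTy, which shares the converted result type's element type but may carry more static shape information, the final reconciliation is purely a shape adjustment; that is presumably why the trailing mhlo::ConvertOp (an element-type conversion) is replaced by a tensor::CastOp. A minimal sketch of the revised tail, assuming loc, outTy, resultTy, lhs, rhs, bias and dotDimensionNumbers are in scope as in the hunk above:

    // Illustrative only; follows the pattern in the diff above.
    Value dot = rewriter.create<mhlo::DotGeneralOp>(
        loc, outTy, lhs, rhs, dotDimensionNumbers, nullptr);
    Value biased = rewriter
                       .create<chlo::BroadcastAddOp>(loc, outTy, dot, bias,
                                                     /*broadcast_dimensions=*/nullptr)
                       .getResult();
    // tensor.cast only refines or relaxes static vs. dynamic dims; the element
    // type must already match the converted result type.
    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, biased);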
