Fix dot product result type setting (#19)
Tanyo Kwok authored and bladedisc committed Sep 23, 2022
1 parent d226b00 commit 3858cc8
Showing 1 changed file with 55 additions and 25 deletions.
80 changes: 55 additions & 25 deletions lib/Conversion/TorchToMhlo/Linear.cpp
@@ -71,9 +71,11 @@ Value getPermutedTensor(PatternRewriter &rewriter, Operation *op, Value input,
   return result.getResult();
 }
 
-void castContractingDim(PatternRewriter &rewriter, Operation *op, Value &lhs,
-                        Value &rhs, int64_t lhsContractingDim,
-                        int64_t rhsContractingDim) {
+RankedTensorType castContractingDim(PatternRewriter &rewriter, Operation *op,
+                                    Value &lhs, Value &rhs,
+                                    int64_t lhsResultDim, int64_t rhsResultDim,
+                                    int64_t lhsContractingDim,
+                                    int64_t rhsContractingDim) {
   auto lhsTy = lhs.getType().dyn_cast<RankedTensorType>();
   auto rhsTy = rhs.getType().dyn_cast<RankedTensorType>();
 
@@ -98,6 +100,32 @@ void castContractingDim(PatternRewriter &rewriter, Operation *op, Value &lhs,
       rhs = rewriter.create<tensor::CastOp>(op->getLoc(), newRankTy, rhs);
     }
   }
+  SmallVector<int64_t> outShape;
+  // set batch dims, will skip invalid dimensions
+  for (int k = 0; k < lhsShape.size(); ++k) {
+    if (k == lhsResultDim || k == lhsContractingDim)
+      continue;
+    outShape.push_back(lhsShape[k]);
+  }
+  for (int k = 0, b = 0; k < rhsShape.size(); ++k) {
+    if (b >= outShape.size())
+      break;
+    if (k == rhsResultDim || k == rhsContractingDim)
+      continue;
+    if (outShape[b] == ShapedType::kDynamicSize && rhsShape[k] >= 0) {
+      outShape[b] = rhsShape[k];
+    }
+    b++;
+  }
+
+  // set result dimensions
+  if (lhsResultDim < lhsShape.size() && lhsResultDim >= 0) {
+    outShape.push_back(lhsShape[lhsResultDim]);
+  }
+  if (rhsResultDim < rhsShape.size() && rhsResultDim >= 0) {
+    outShape.push_back(rhsShape[rhsResultDim]);
+  }
+  return RankedTensorType::get(outShape, lhsTy.getElementType());
 }
 
 void getBmmBroadcast(PatternRewriter &rewriter, Operation *op, Value &inpLhs,
@@ -212,10 +240,15 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp<AtenOpT> {
                       options.dimSizeIndexBits);
     }
     auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
+
+    auto lhsResultDim = nBatchDims;
+    auto rhsResultDim = nBatchDims + 1;
     auto lhsContractingDim = nBatchDims + 1;
     auto rhsContractingDim = nBatchDims;
-    if (lhsRank == 1)
+    if (lhsRank == 1) {
+      lhsResultDim = nBatchDims + 1;
       lhsContractingDim = nBatchDims;
+    }
 
     mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
         mhlo::DotDimensionNumbersAttr::get(
@@ -224,17 +257,13 @@ class ConvertAtenMatmulBaseOp : public ConvertAtenOp<AtenOpT> {
             /*rhsBatchingDimensions=*/batchDims,
             /*lhsContractingDimensions=*/{lhsContractingDim},
             /*rhsContractingDimensions=*/{rhsContractingDim});
-    auto resultTy = ConvertAtenOp<AtenOpT>::getTypeConverter()
-                        ->convertType(op.getType())
-                        .template cast<RankedTensorType>();
-
-    castContractingDim(rewriter, op, lhs, rhs, lhsContractingDim,
-                       rhsContractingDim);
+    auto outTy =
+        castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
+                           lhsContractingDim, rhsContractingDim);
     output = rewriter
-                 .create<mhlo::DotGeneralOp>(op->getLoc(), resultTy, lhs, rhs,
+                 .create<mhlo::DotGeneralOp>(op->getLoc(), outTy, lhs, rhs,
                                              dotDimensionNumbers, nullptr)
                  .getResult();
-
     return success();
   }
 
@@ -386,36 +415,37 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
     auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
     auto nBatchDims = resultRank - 2;
     auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));
+
+    auto lhsResultDim = nBatchDims;
+    auto rhsResultDim = nBatchDims + 1;
     auto lhsContractingDim = nBatchDims + 1;
     auto rhsContractingDim = nBatchDims;
-
-    castContractingDim(rewriter, op, lhs, rhs, lhsContractingDim,
-                       rhsContractingDim);
+    auto outTy =
+        castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
+                           lhsContractingDim, rhsContractingDim);
     mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
         mhlo::DotDimensionNumbersAttr::get(
            rewriter.getContext(),
            /*lhsBatchingDimensions=*/batchDims,
            /*rhsBatchingDimensions=*/batchDims,
            /*lhsContractingDimensions=*/{lhsContractingDim},
            /*rhsContractingDimensions=*/{rhsContractingDim});
 
-    auto resultTy =
-        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
-
     Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
-        op->getLoc(), resultTy, lhs, rhs, dotDimensionNumbers, nullptr);
+        op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);
 
     Value matmulPlusBias = matmulOutput;
     if (!biasTy.template isa<Torch::NoneType>()) {
       // Bias addition broadcasts to the matmul output shape.
-      matmulPlusBias =
-          rewriter
-              .create<chlo::BroadcastAddOp>(op->getLoc(), resultTy,
-                                            matmulOutput, bias, nullptr)
-              .getResult();
+      matmulPlusBias = rewriter
+                           .create<chlo::BroadcastAddOp>(
+                               op->getLoc(), outTy, matmulOutput, bias, nullptr)
+                           .getResult();
     }
 
-    rewriter.replaceOpWithNewOp<mhlo::ConvertOp>(op, resultTy, matmulPlusBias);
+    auto resultTy =
+        ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
+    rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, matmulPlusBias);
     return success();
   }
 };
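
For reference, here is a minimal standalone sketch (not part of the commit) of the output-shape rule that the new castContractingDim return value encodes: batch dimensions are taken from the lhs, dynamic batch sizes are refined with static sizes from the rhs, and the lhs/rhs result dimensions are appended, with an out-of-range result dimension simply skipped (the rank-1 operand case). The helper name dotGeneralOutShape and the kDynamic sentinel are illustrative stand-ins for the SmallVector and ShapedType::kDynamicSize machinery used in the diff.

// shape_sketch.cpp -- illustrative only; mirrors the outShape logic added above.
#include <cassert>
#include <cstdint>
#include <vector>

constexpr int64_t kDynamic = -1; // stand-in for ShapedType::kDynamicSize

std::vector<int64_t> dotGeneralOutShape(const std::vector<int64_t> &lhsShape,
                                        const std::vector<int64_t> &rhsShape,
                                        int64_t lhsResultDim, int64_t rhsResultDim,
                                        int64_t lhsContractingDim,
                                        int64_t rhsContractingDim) {
  std::vector<int64_t> outShape;
  // Batch dims: every lhs dim that is neither the result nor the contracting dim.
  for (int64_t k = 0; k < (int64_t)lhsShape.size(); ++k) {
    if (k == lhsResultDim || k == lhsContractingDim)
      continue;
    outShape.push_back(lhsShape[k]);
  }
  // Refine dynamic batch sizes with static sizes known on the rhs.
  for (int64_t k = 0, b = 0; k < (int64_t)rhsShape.size(); ++k) {
    if (b >= (int64_t)outShape.size())
      break;
    if (k == rhsResultDim || k == rhsContractingDim)
      continue;
    if (outShape[b] == kDynamic && rhsShape[k] >= 0)
      outShape[b] = rhsShape[k];
    b++;
  }
  // Result dims; out of range means that side contributes no result dimension.
  if (lhsResultDim >= 0 && lhsResultDim < (int64_t)lhsShape.size())
    outShape.push_back(lhsShape[lhsResultDim]);
  if (rhsResultDim >= 0 && rhsResultDim < (int64_t)rhsShape.size())
    outShape.push_back(rhsShape[rhsResultDim]);
  return outShape;
}

int main() {
  // Batched matmul [?, 4, 8] x [2, 8, 16]: one batch dim, so
  // lhsResultDim = 1, rhsResultDim = 2, lhsContractingDim = 2, rhsContractingDim = 1.
  // The dynamic batch size is refined to 2 from the rhs, giving [2, 4, 16].
  auto out = dotGeneralOutShape({kDynamic, 4, 8}, {2, 8, 16}, 1, 2, 2, 1);
  assert((out == std::vector<int64_t>{2, 4, 16}));
  return 0;
}

Creating the mhlo::DotGeneralOp with this shape-inferred type, and only casting to the type converter's result type when replacing the op (the tensor::CastOp in ConvertAtenLinearOp), is the pattern the call-site changes above implement.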
