Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revert "[MHLO] reimplement linear lowering torchToMhlo" #1744

Merged
merged 1 commit into from
Dec 22, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 32 additions & 56 deletions lib/Conversion/TorchToMhlo/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,10 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
auto lhsRank = lhsTy.getRank();
auto rhsRank = rhsTy.getRank();

if (lhsRank < 1)
return op.emitError("aten.Linear called but input rank 0");
if (rhsRank != 2)
return op.emitError("aten.Linear called but weight rank not 2");
if (lhsRank != 2 && lhsRank != 3)
return op.emitError("aten.Linear called but input rank not 2 or 3");
if (rhsRank != 2 && rhsRank != 3)
return op.emitError("aten.Linear called but weight rank not 2 or 3");

return success();
}
Expand Down Expand Up @@ -406,59 +406,33 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {

auto lhsTy = lhs.getType().cast<RankedTensorType>();
auto rhsTy = rhs.getType().cast<RankedTensorType>();
auto lhsRank = lhsTy.getRank();
auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
rhsTy.getRank() - lhsTy.getRank());

const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
options.dimSizeIndexBits);
auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
auto nBatchDims = resultRank - 2;
auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));

auto lhsResultDim = nBatchDims;
auto rhsResultDim = nBatchDims + 1;
auto lhsContractingDim = nBatchDims + 1;
auto rhsContractingDim = nBatchDims;

auto loc = op->getLoc();
Value dotLhs;
SmallVector<Value> resultDims;
// vector * matrix or matrix * matrix can directly use mhlo.dot_general
if (lhsTy.getRank() <= 2) {
dotLhs = lhs;
} else {
// Broadcasting the weight and then using bmm would lead to excessive data
// copying and extra compute, degrading performance.
// Instead, reshape the input to a 2-D tensor, use dot to perform the
// matrix-matrix multiply, and finally reshape back to the output shape,
// which yields better performance.

// [x_1, x_2, ..., x_n, in_features] * [in_features, out_features]
// -> [x_1 * x_2 * ... * x_n , in_features] * [in_features, out_features]
auto dotLhsTy = RankedTensorType::get(
{ShapedType::kDynamicSize, lhsTy.getShape()[lhsRank - 1]},
lhsTy.getElementType());
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
Value numel = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));

for (int i = 0; i < lhsRank - 1; ++i) {
Value dimValue = rewriter.create<tensor::DimOp>(loc, lhs, i);
resultDims.push_back(dimValue);
numel = rewriter.create<arith::MulIOp>(
loc, numel,
rewriter.create<arith::IndexCastOp>(loc, intType, dimValue));
}
Value lhsLastRankDim = rewriter.create<arith::IndexCastOp>(
loc, intType, rewriter.create<tensor::DimOp>(loc, lhs, lhsRank - 1));
resultDims.push_back(rewriter.create<tensor::DimOp>(loc, rhs, 1));
Value reshapeDim =
rewriter
.create<mlir::tensor::FromElementsOp>(
op->getLoc(), ValueRange{numel, lhsLastRankDim})
.getResult();
dotLhs = rewriter.create<mhlo::DynamicReshapeOp>(loc, dotLhsTy, lhs,
reshapeDim);
}
Value matmulOutput =
rewriter.create<mhlo::DotOp>(loc, dotLhs, rhs, nullptr);
auto outTy =
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
// reshape to [x_1, x_2, ..., x_n, out_features]
if (dotLhs != lhs) {
matmulOutput = rewriter.create<mhlo::DynamicReshapeOp>(
loc, outTy, matmulOutput,
rewriter.create<mlir::tensor::FromElementsOp>(loc, resultDims));
}
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
lhsContractingDim, rhsContractingDim);
mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
mhlo::DotDimensionNumbersAttr::get(
rewriter.getContext(),
/*lhsBatchingDimensions=*/batchDims,
/*rhsBatchingDimensions=*/batchDims,
/*lhsContractingDimensions=*/{lhsContractingDim},
/*rhsContractingDimensions=*/{rhsContractingDim});
Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);

Value matmulPlusBias = matmulOutput;
if (!biasTy.template isa<Torch::NoneType>()) {
Expand All @@ -469,7 +443,9 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
.getResult();
}

rewriter.replaceOp(op, matmulPlusBias);
auto resultTy =
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, matmulPlusBias);
return success();
}
};
Expand Down
57 changes: 0 additions & 57 deletions test/Conversion/TorchToMhlo/linear.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -499,60 +499,3 @@ func.func @torch.aten.convolution$transposed_groups(%arg0: !torch.vtensor<[1,2,7
%3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %0, %int2 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,15,15],f32>
return %3 : !torch.vtensor<[1,4,15,15],f32>
}

// -----

// Fully static 2-D input x 2-D weight linear, with bias.
// Expected lowering (per the CHECK lines): transpose the weight, perform a
// plain mhlo.dot — no mhlo.dynamic_reshape of the input is needed for the
// static 2-D case — then add the bias via a broadcasted chlo.broadcast_add.
// CHECK-LABEL: func.func @torch.aten.linear(
// CHECK-NOT: mhlo.dynamic_reshape
// CHECK: mhlo.transpose
// CHECK: mhlo.dot
// CHECK: chlo.broadcast_add
func.func @torch.aten.linear(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[4,5],f32> {
%1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[4,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[4,5],f32>
return %1 : !torch.vtensor<[4,5],f32>
}

// -----

// Same static 2-D linear as above, but with bias = torch.constant.none.
// The CHECK-NOT on chlo.broadcast_add verifies that no bias addition is
// emitted when the bias operand is None; the matmul itself still lowers to
// a weight transpose followed by mhlo.dot, with no dynamic reshape.
// CHECK-LABEL: func.func @torch.aten.linear$nobias(
// CHECK-NOT: mhlo.dynamic_reshape
// CHECK: mhlo.transpose
// CHECK: mhlo.dot
// CHECK-NOT: chlo.broadcast_add
func.func @torch.aten.linear$nobias(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[4,5],f32> {
%none = torch.constant.none
%1 = torch.aten.linear %arg0, %arg1, %none : !torch.vtensor<[4,3],f32>, !torch.vtensor<[5,3],f32>, !torch.none -> !torch.vtensor<[4,5],f32>
return %1 : !torch.vtensor<[4,5],f32>
}

// -----

// 3-D input with dynamic leading dims ([?,?,3]) times a static 2-D weight.
// Expected lowering (per the CHECK lines): the dynamic leading dims are
// multiplied together (two arith.muli) to build the flattened row count,
// tensor.from_elements assembles the target shape, mhlo.dynamic_reshape
// collapses the input to 2-D, mhlo.dot performs the matmul, a second
// mhlo.dynamic_reshape restores the [?,?,5] result, then the bias is added.
// CHECK-LABEL: func.func @torch.aten.linear$dynamic(
// CHECK: mhlo.transpose
// CHECK: arith.muli
// CHECK: arith.muli
// CHECK: tensor.from_elements
// CHECK: mhlo.dynamic_reshape
// CHECK: mhlo.dot
// CHECK: mhlo.dynamic_reshape
// CHECK: chlo.broadcast_add
func.func @torch.aten.linear$dynamic(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[?,?,5],f32> {
%1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?,5],f32>
return %1 : !torch.vtensor<[?,?,5],f32>
}

// -----

// 4-D input with dynamic leading dims ([?,?,?,3]): exercises the same
// flatten-to-2D path as the 3-D dynamic test above, confirming it
// generalizes to higher-rank inputs — leading dims are folded into one
// row dimension via arith.muli, the input is dynamically reshaped to 2-D,
// multiplied with mhlo.dot, reshaped back to [?,?,?,5], and bias-added.
// CHECK-LABEL: func.func @torch.aten.linear$dynamic4D(
// CHECK: mhlo.transpose
// CHECK: arith.muli
// CHECK: arith.muli
// CHECK: tensor.from_elements
// CHECK: mhlo.dynamic_reshape
// CHECK: mhlo.dot
// CHECK: mhlo.dynamic_reshape
// CHECK: chlo.broadcast_add
func.func @torch.aten.linear$dynamic4D(%arg0: !torch.vtensor<[?,?,?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[?,?,?,5],f32> {
%1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?,?,5],f32>
return %1 : !torch.vtensor<[?,?,?,5],f32>
}