Skip to content

Commit

Permalink
Revert "reimplement linear lowering torchToMhlo (#1524)" (#1744)
Browse files Browse the repository at this point in the history
This reverts commit 50b5245.
  • Loading branch information
Tanyo Kwok authored Dec 22, 2022
1 parent 60a1392 commit 297fd3a
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 113 deletions.
88 changes: 32 additions & 56 deletions lib/Conversion/TorchToMhlo/Linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,10 +373,10 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
auto lhsRank = lhsTy.getRank();
auto rhsRank = rhsTy.getRank();

if (lhsRank < 1)
return op.emitError("aten.Linear called but input rank 0");
if (rhsRank != 2)
return op.emitError("aten.Linear called but weight rank not 2");
if (lhsRank != 2 && lhsRank != 3)
return op.emitError("aten.Linear called but input rank not 2 or 3");
if (rhsRank != 2 && rhsRank != 3)
return op.emitError("aten.Linear called but weight rank not 2 or 3");

return success();
}
Expand Down Expand Up @@ -406,59 +406,33 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {

auto lhsTy = lhs.getType().cast<RankedTensorType>();
auto rhsTy = rhs.getType().cast<RankedTensorType>();
auto lhsRank = lhsTy.getRank();
auto leadingRank = std::max(lhsTy.getRank() - rhsTy.getRank(),
rhsTy.getRank() - lhsTy.getRank());

const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
getBmmBroadcast(rewriter, op, lhs, rhs, leadingRank,
options.dimSizeIndexBits);
auto resultRank = std::max(lhsTy.getRank(), rhsTy.getRank());
auto nBatchDims = resultRank - 2;
auto batchDims = llvm::to_vector<4>(llvm::seq<int64_t>(0, nBatchDims));

auto lhsResultDim = nBatchDims;
auto rhsResultDim = nBatchDims + 1;
auto lhsContractingDim = nBatchDims + 1;
auto rhsContractingDim = nBatchDims;

auto loc = op->getLoc();
Value dotLhs;
SmallVector<Value> resultDims;
// vector * matrix or matrix * matrix can directly use mhlo.dot_general
if (lhsTy.getRank() <= 2) {
dotLhs = lhs;
} else {
// Broadcast weight and then use bmm would lead to too much data copy,
// and more compute, then decreace the performance
// Instead, reshape input to 2-D tensor, then use dot to perform
// matrix-matrix multiply, and finnaly reshape to the output shape,
// would get better performance

// [x_1, x_2, ..., x_n, in_features] * [in_features, out_features]
// -> [x_1 * x_2 * ... * x_n , in_features] * [in_features, out_features]
auto dotLhsTy = RankedTensorType::get(
{ShapedType::kDynamicSize, lhsTy.getShape()[lhsRank - 1]},
lhsTy.getElementType());
const auto &options = ConvertAtenOp<AtenOpT>::getOptions();
Type intType = rewriter.getIntegerType(options.dimSizeIndexBits);
Value numel = rewriter.create<arith::ConstantOp>(
loc, rewriter.getIntegerAttr(intType, 1));

for (int i = 0; i < lhsRank - 1; ++i) {
Value dimValue = rewriter.create<tensor::DimOp>(loc, lhs, i);
resultDims.push_back(dimValue);
numel = rewriter.create<arith::MulIOp>(
loc, numel,
rewriter.create<arith::IndexCastOp>(loc, intType, dimValue));
}
Value lhsLastRankDim = rewriter.create<arith::IndexCastOp>(
loc, intType, rewriter.create<tensor::DimOp>(loc, lhs, lhsRank - 1));
resultDims.push_back(rewriter.create<tensor::DimOp>(loc, rhs, 1));
Value reshapeDim =
rewriter
.create<mlir::tensor::FromElementsOp>(
op->getLoc(), ValueRange{numel, lhsLastRankDim})
.getResult();
dotLhs = rewriter.create<mhlo::DynamicReshapeOp>(loc, dotLhsTy, lhs,
reshapeDim);
}
Value matmulOutput =
rewriter.create<mhlo::DotOp>(loc, dotLhs, rhs, nullptr);
auto outTy =
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
// reshape to [x_1, x_2, ..., x_n, out_features]
if (dotLhs != lhs) {
matmulOutput = rewriter.create<mhlo::DynamicReshapeOp>(
loc, outTy, matmulOutput,
rewriter.create<mlir::tensor::FromElementsOp>(loc, resultDims));
}
castContractingDim(rewriter, op, lhs, rhs, lhsResultDim, rhsResultDim,
lhsContractingDim, rhsContractingDim);
mhlo::DotDimensionNumbersAttr dotDimensionNumbers =
mhlo::DotDimensionNumbersAttr::get(
rewriter.getContext(),
/*lhsBatchingDimensions=*/batchDims,
/*rhsBatchingDimensions=*/batchDims,
/*lhsContractingDimensions=*/{lhsContractingDim},
/*rhsContractingDimensions=*/{rhsContractingDim});
Value matmulOutput = rewriter.create<mhlo::DotGeneralOp>(
op->getLoc(), outTy, lhs, rhs, dotDimensionNumbers, nullptr);

Value matmulPlusBias = matmulOutput;
if (!biasTy.template isa<Torch::NoneType>()) {
Expand All @@ -469,7 +443,9 @@ class ConvertAtenLinearOp : public ConvertAtenMatmulBaseOp<AtenOpT> {
.getResult();
}

rewriter.replaceOp(op, matmulPlusBias);
auto resultTy =
ConvertAtenOp<AtenOpT>::getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultTy, matmulPlusBias);
return success();
}
};
Expand Down
57 changes: 0 additions & 57 deletions test/Conversion/TorchToMhlo/linear.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -499,60 +499,3 @@ func.func @torch.aten.convolution$transposed_groups(%arg0: !torch.vtensor<[1,2,7
%3 = torch.aten.convolution %arg0, %arg1, %none, %2, %0, %1, %true, %0, %int2 : !torch.vtensor<[1,2,7,7],f32>, !torch.vtensor<[2,2,3,3],f32>, !torch.none, !torch.list<int>, !torch.list<int>, !torch.list<int>, !torch.bool, !torch.list<int>, !torch.int -> !torch.vtensor<[1,4,15,15],f32>
return %3 : !torch.vtensor<[1,4,15,15],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.linear(
// CHECK-NOT: mhlo.dynamic_reshape
// CHECK: mhlo.transpose
// CHECK: mhlo.dot
// CHECK: chlo.broadcast_add
func.func @torch.aten.linear(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[4,5],f32> {
%1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[4,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[4,5],f32>
return %1 : !torch.vtensor<[4,5],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.linear$nobias(
// CHECK-NOT: mhlo.dynamic_reshape
// CHECK: mhlo.transpose
// CHECK: mhlo.dot
// CHECK-NOT: chlo.broadcast_add
func.func @torch.aten.linear$nobias(%arg0: !torch.vtensor<[4,3],f32>, %arg1: !torch.vtensor<[5,3],f32>) -> !torch.vtensor<[4,5],f32> {
%none = torch.constant.none
%1 = torch.aten.linear %arg0, %arg1, %none : !torch.vtensor<[4,3],f32>, !torch.vtensor<[5,3],f32>, !torch.none -> !torch.vtensor<[4,5],f32>
return %1 : !torch.vtensor<[4,5],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.linear$dynamic(
// CHECK: mhlo.transpose
// CHECK: arith.muli
// CHECK: arith.muli
// CHECK: tensor.from_elements
// CHECK: mhlo.dynamic_reshape
// CHECK: mhlo.dot
// CHECK: mhlo.dynamic_reshape
// CHECK: chlo.broadcast_add
func.func @torch.aten.linear$dynamic(%arg0: !torch.vtensor<[?,?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[?,?,5],f32> {
%1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?,5],f32>
return %1 : !torch.vtensor<[?,?,5],f32>
}

// -----

// CHECK-LABEL: func.func @torch.aten.linear$dynamic4D(
// CHECK: mhlo.transpose
// CHECK: arith.muli
// CHECK: arith.muli
// CHECK: tensor.from_elements
// CHECK: mhlo.dynamic_reshape
// CHECK: mhlo.dot
// CHECK: mhlo.dynamic_reshape
// CHECK: chlo.broadcast_add
func.func @torch.aten.linear$dynamic4D(%arg0: !torch.vtensor<[?,?,?,3],f32>, %arg1: !torch.vtensor<[5,3],f32>, %arg2: !torch.vtensor<[5],f32>) -> !torch.vtensor<[?,?,?,5],f32> {
%1 = torch.aten.linear %arg0, %arg1, %arg2 : !torch.vtensor<[?,?,?,3],f32>, !torch.vtensor<[5,3],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[?,?,?,5],f32>
return %1 : !torch.vtensor<[?,?,?,5],f32>
}

0 comments on commit 297fd3a

Please sign in to comment.