[TORCH] Add decomposition of aten.linear op
This commit adds a decomposition of the `aten.linear` op into `aten.t`, `aten.matmul`, and `aten.add.Tensor` ops, and removes the now-redundant direct lowering of `aten.linear` from TorchToLinalg.
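
For reference, `aten.linear(input, weight, bias)` computes `input @ weight.T + bias`, i.e. `out[..., j] = sum_k input[..., k] * weight[j, k] + bias[j]`. A minimal before/after sketch of the decomposition at the Torch dialect level (SSA names and shapes here are illustrative, not taken from a real test case):

    // Before: a linear op with a rank-2 input and a rank-1 bias.
    %out = torch.aten.linear %input, %weight, %bias
        : !torch.vtensor<[?,3],f32>, !torch.vtensor<[4,3],f32>, !torch.vtensor<[4],f32>
        -> !torch.vtensor<[?,4],f32>

    // After: transpose the weight, matmul, then add the bias with alpha = 1.
    %wt  = torch.aten.t %weight : !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,4],f32>
    %mm  = torch.aten.matmul %input, %wt
        : !torch.vtensor<[?,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[?,4],f32>
    %one = torch.constant.float 1.000000e+00
    %out = torch.aten.add.Tensor %mm, %bias, %one
        : !torch.vtensor<[?,4],f32>, !torch.vtensor<[4],f32>, !torch.float
        -> !torch.vtensor<[?,4],f32>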

Signed-off-by: Gaurav Shukla
Shukla-Gaurav committed Jul 15, 2022
1 parent: 3589134 · commit: c79cc85
Showing 2 changed files with 68 additions and 172 deletions.
187 changes: 15 additions & 172 deletions lib/Conversion/TorchToLinalg/Linear.cpp
@@ -295,6 +295,21 @@ class ConvertAtenMatmulOp : public OpConversionPattern<AtenMatmulOp> {
op, "unable to perform broadcast operation");
}

if (maxRank == 3) {
Value zeroTensor = createZeroInitTensor(
rewriter, loc,
ValueRange{broadcastedBatchShape[0], lhsDim0, rhsDim1},
elementType);
Value matmul =
rewriter
.create<linalg::BatchMatmulOp>(
loc, zeroTensor.getType(),
ValueRange{broadcastedLhs, broadcastedRhs}, zeroTensor)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}

// Check if the result of the matrix multiplication has more than one
// dynamic batch dimension.
ArrayRef<int64_t> batchDimsInt = resultType.getShape().drop_back(2);
@@ -453,176 +468,6 @@ class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
};
} // namespace

namespace {
// See comments in convertMmOp and the heading for this section for general
// considerations. This function needs to be auto-generated.
class ConvertAtenLinearOp : public OpConversionPattern<AtenLinearOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenLinearOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
MLIRContext *context = op->getContext();
Location loc = op->getLoc();
Value input = adaptor.input();
Value weight = adaptor.weight();
Value bias = adaptor.bias();
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
auto inputType = input.getType().cast<RankedTensorType>();
auto weightType = weight.getType().cast<RankedTensorType>();

if (inputType.getRank() != 2 && inputType.getRank() != 3) {
return rewriter.notifyMatchFailure(
op, "expected input to be rank 2 or rank 3");
}

if (!bias.getType().isa<Torch::NoneType>()) {
auto biasType = bias.getType().cast<RankedTensorType>();
// Only handle the case of rank 2 `weight` for now.
// TODO: Insert the appropriate reshape to collapse any leading dimensions.
if (weightType.getRank() != 2 || biasType.getRank() != 1) {
return rewriter.notifyMatchFailure(
op, "expected weight to be rank 2 and bias to be rank 1");
}
// TODO: Handle type promotion. What are ATen's promotion rules?
if (inputType.getElementType() != weightType.getElementType() ||
inputType.getElementType() != biasType.getElementType()) {
return rewriter.notifyMatchFailure(op, "unimplemented: type promotion");
}
// TODO: We can handle a static size 1 here at some complexity cost, but the
// dynamic case is not representable in linalg. We don't handle either for
// now. Biases are generally statically shaped for most models (since for
// inference they are constants, and for training they don't change shape
// typically), so this is not too constraining.
auto biasSize = bias.getType().cast<RankedTensorType>().getShape()[0];
if (biasSize == 1 || biasSize == ShapedType::kDynamicSize)
return rewriter.notifyMatchFailure(
op, "unimplemented: size-1 broadcasting for aten::LinearOp");
}


Value batchDim = nullptr;
int restDim = 0;
if (inputType.getRank() == 3) {
batchDim = getDimOp(rewriter, loc, input, 0);
restDim = 1;
}

Value inputDim0 = getDimOp(rewriter, loc, input, restDim + 0);
Value inputDim1 = getDimOp(rewriter, loc, input, restDim + 1);
Value weightDim0 = getDimOp(rewriter, loc, weight, 0);
Value weightDim1 = getDimOp(rewriter, loc, weight, 1);
Value contractingDimEqual = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, inputDim1, weightDim1);
rewriter.create<cf::AssertOp>(
loc, contractingDimEqual,
rewriter.getStringAttr(
"mismatching contracting dimension for aten.linear"));

if (!bias.getType().isa<Torch::NoneType>()) {
Value biasDim0 = getDimOp(rewriter, loc, bias, 0);
// Here we take advantage of ruling out the size-1 case above.
// In the static-size-1 case, we will not emit this check at all.
Value biasSizeCorrect = rewriter.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::eq, weightDim0, biasDim0);
rewriter.create<cf::AssertOp>(
loc, biasSizeCorrect,
rewriter.getStringAttr("mismatching bias size for aten.linear"));
}

Value initTensor;
SmallVector<AffineMap> broadcastIndexingMaps;
Value transposedWeightInitTensor;
if (inputType.getRank() > 2) {
initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{batchDim, inputDim0, weightDim0},
inputType.getElementType());
transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{batchDim, weightDim1, weightDim0},
weightType.getElementType());
broadcastIndexingMaps = {
AffineMap::get(
/*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
{rewriter.getAffineDimExpr(1 + restDim)}, context),
rewriter.getMultiDimIdentityMap(inputType.getRank())};
} else {
initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{inputDim0, weightDim0},
inputType.getElementType());
transposedWeightInitTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{weightDim1, weightDim0}, weightType.getElementType());
broadcastIndexingMaps = {
AffineMap::get(
/*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
{rewriter.getAffineDimExpr(1)}, context),
rewriter.getMultiDimIdentityMap(inputType.getRank())};
}

SmallVector<StringRef> iteratorTypes(inputType.getRank(), "parallel");
Value broadcasted;
if (!bias.getType().isa<Torch::NoneType>()) {
broadcasted =
rewriter
.create<linalg::GenericOp>(
loc, initTensor.getType(), bias, initTensor,
/*indexingMaps=*/broadcastIndexingMaps,
/*iteratorTypes=*/iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
} else {
Type elementType =
input.getType().cast<RankedTensorType>().getElementType();
Value c0float = rewriter.create<arith::ConstantOp>(
loc, FloatAttr::get(elementType, 0.0));
broadcasted = rewriter.create<linalg::FillOp>(loc, c0float, initTensor)
.getResult(0);
}
// We need a matmul with dimension ordering (N, K) * (M, K), so transpose
// the weights to fit into linalg::MatmulOp which is (N, K) * (K, M).
// TODO: This whole aten.linear lowering should eventually be generated from
// a single linalg ODS generator statement. Both the bias and matmul part.
SmallVector<AffineMap> transposeIndexingMaps = {
AffineMap::get(
/*dimCount=*/inputType.getRank(), /*symbolCount=*/0,
{rewriter.getAffineDimExpr(1 + restDim),
rewriter.getAffineDimExpr(0 + restDim)},
context),
rewriter.getMultiDimIdentityMap(inputType.getRank())};
Value transposedWeights =
rewriter
.create<linalg::GenericOp>(
loc, transposedWeightInitTensor.getType(), weight,
transposedWeightInitTensor,
/*indexingMaps=*/transposeIndexingMaps,
/*iteratorTypes=*/iteratorTypes,
[](OpBuilder &b, Location loc, ValueRange args) {
b.create<linalg::YieldOp>(loc, args[0]);
})
.getResult(0);
Value matmul;
if (batchDim)
matmul = rewriter
.create<linalg::BatchMatmulOp>(
loc, broadcasted.getType(),
ValueRange{input, transposedWeights}, broadcasted)
.getResult(0);
else
matmul = rewriter
.create<linalg::MatmulOp>(
loc, broadcasted.getType(),
ValueRange{input, transposedWeights}, broadcasted)
.getResult(0);

Type newResultType = getTypeConverter()->convertType(op.getType());
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, matmul);
return success();
}
};
} // namespace

namespace {
class ConvertAtenConvolutionOp : public OpConversionPattern<AtenConvolutionOp> {
public:
@@ -768,8 +613,6 @@ void mlir::torch::torch_to_linalg::populateLinearPatternsAndLegality(
patterns.add<ConvertAtenMatmulOp>(typeConverter, context);
target.addIllegalOp<AtenBmmOp>();
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
target.addIllegalOp<AtenLinearOp>();
patterns.add<ConvertAtenLinearOp>(typeConverter, context);
target.addIllegalOp<AtenConvolutionOp>();
patterns.add<ConvertAtenConvolutionOp>(typeConverter, context);
}
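
Since `aten.linear` no longer has a direct TorchToLinalg lowering, the batched (rank-3) case it used to cover now arrives here as a rank-3 `aten.matmul`, and the new branch in ConvertAtenMatmulOp lowers it straight to `linalg.batch_matmul`. A sketch of the IR that branch produces, assuming `createZeroInitTensor` materializes a zero-filled `linalg.init_tensor` (dynamic shapes shown for illustration):

    %c0 = arith.constant 0.000000e+00 : f32
    %init = linalg.init_tensor [%batch, %lhsDim0, %rhsDim1] : tensor<?x?x?xf32>
    %zero = linalg.fill ins(%c0 : f32) outs(%init : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    // Accumulate broadcastedLhs * broadcastedRhs into the zero-filled tensor.
    %bmm = linalg.batch_matmul
        ins(%lhs, %rhs : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
        outs(%zero : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
    // Cast to the result type computed by the type converter.
    %out = tensor.cast %bmm : tensor<?x?x?xf32> to tensor<2x4x5xf32>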
53 changes: 53 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
@@ -1728,6 +1728,57 @@ class DecomposeAtenFullOp : public OpRewritePattern<AtenFullOp> {
};
} // namespace

namespace {
// Decompose `aten.linear` op into `aten.t`, `aten.matmul`, and `aten.add.Tensor` ops.
class DecomposeAtenLinearOp : public OpRewritePattern<AtenLinearOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenLinearOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();
Value input = op.input();
Value weight = op.weight();
Value bias = op.bias();

BaseTensorType inputType = input.getType().cast<BaseTensorType>();
if (!inputType.hasSizes() || inputType.getSizes().size() < 2)
return rewriter.notifyMatchFailure(
op, "expected input to be rank 2 or greater");

BaseTensorType weightType = weight.getType().cast<BaseTensorType>();
// The weight must be a rank-2 matrix.
if (!weightType.hasSizes() || weightType.getSizes().size() != 2)
return rewriter.notifyMatchFailure(op, "expected weight to be rank 2");

ArrayRef<int64_t> wShape = weightType.getSizes();
SmallVector<int64_t> transposeShape;
transposeShape.push_back(wShape[1]);
transposeShape.push_back(wShape[0]);
Type transposeType = weightType.getWithSizesAndDtype(
llvm::makeArrayRef(transposeShape), weightType.getDtype());
Value transposeWeight =
rewriter.create<AtenTOp>(loc, transposeType, weight);

Value matmul = rewriter.create<AtenMatmulOp>(loc, op.getType(), input,
transposeWeight);
if (bias.getType().isa<Torch::NoneType>()) {
rewriter.replaceOp(op, matmul);
return success();
}

BaseTensorType biasType = bias.getType().cast<BaseTensorType>();
if (!biasType.hasSizes() || biasType.getSizes().size() != 1)
return rewriter.notifyMatchFailure(op, "expected bias to be rank 1");

Value alpha =
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1));
rewriter.replaceOpWithNewOp<AtenAddTensorOp>(op, op.getType(), matmul,
bias, alpha);
return success();
}
};
} // namespace

namespace {
// Decompose `aten.full_like` op into `aten.empty_like` and `aten.fill` ops.
class DecomposeAtenFullLikeOp : public OpRewritePattern<AtenFullLikeOp> {
@@ -2356,6 +2407,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<AtenHardtanhOp>();
patterns.add<DecomposeAtenFullOp>(context);
target.addIllegalOp<AtenFullOp>();
patterns.add<DecomposeAtenLinearOp>(context);
target.addIllegalOp<AtenLinearOp>();
patterns.add<DecomposeAtenFullLikeOp>(context);
target.addIllegalOp<AtenFullLikeOp>();
patterns.add<DecomposeAtenIndexPutOp>(context);
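
One behavior of the new decomposition worth calling out: when `bias` is the Torch `none` value, the pattern simply replaces the op with the matmul and emits no `aten.add.Tensor`. A sketch under the same illustrative shapes as above:

    %none = torch.constant.none
    %out = torch.aten.linear %x, %w, %none
        : !torch.vtensor<[?,3],f32>, !torch.vtensor<[4,3],f32>, !torch.none
        -> !torch.vtensor<[?,4],f32>

    // decomposes to
    %wt  = torch.aten.t %w : !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,4],f32>
    %out = torch.aten.matmul %x, %wt
        : !torch.vtensor<[?,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[?,4],f32>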