Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit

Permalink
Add native_dropout_backward & native_layer_norm_backward decomposition
Browse files Browse the repository at this point in the history
  • Loading branch information
Tanyo Kwok authored and bladedisc committed Aug 31, 2022
1 parent b60ac1c commit 3f695e2
Showing 1 changed file with 123 additions and 1 deletion.
124 changes: 123 additions & 1 deletion lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1293,6 +1293,25 @@ class DecomposeAtenDropoutOp : public OpRewritePattern<AtenDropoutOp> {
};
} // namespace

// Decompose aten.native_dropout_backward into: grad_output * mask * scale.
namespace {
class DecomposeAtenNativeDropoutBackwardOp
    : public OpRewritePattern<AtenNativeDropoutBackwardOp> {
public:
  using OpRewritePattern::OpRewritePattern;

  /// Replaces the op with two multiplications: an AtenMulTensorOp of the
  /// incoming gradient with the mask, then an AtenMulScalarOp of that product
  /// with the scale. Both new ops carry the result type of the original op.
  /// The rewrite is unconditional and always reports success.
  LogicalResult matchAndRewrite(AtenNativeDropoutBackwardOp op,
                                PatternRewriter &rewriter) const override {
    auto resultTy = op.getType();
    Value incomingGrad = op.grad_output();
    Value keepMask = op.mask();
    Value scaleFactor = op.scale();

    // Step 1: apply the mask to the incoming gradient.
    Value maskedGrad = rewriter.create<AtenMulTensorOp>(
        op.getLoc(), resultTy, incomingGrad, keepMask);

    // Step 2: scale the masked gradient; this op takes the place of `op`.
    rewriter.replaceOpWithNewOp<AtenMulScalarOp>(op, resultTy, maskedGrad,
                                                 scaleFactor);
    return success();
  }
};
} // namespace

// Decompose aten.var into: aten.var.dim op.
namespace {
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
Expand Down Expand Up @@ -1681,6 +1700,105 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern<AtenLayerNormOp> {
};
} // namespace

namespace {
// Decomposes the native layer-norm backward op into elementwise and reduction
// ops.
//
// With x_hat = (input - mean) * rstd and g = grad_out (multiplied by weight
// when a weight is present), the three replacement values are:
//   grad_input  = rstd * (g - mean(g) - x_hat * mean(g * x_hat))
//   grad_weight = sum(grad_out * x_hat) over the leading (outer) dims
//   grad_bias   = sum(grad_out)         over the leading (outer) dims
// where mean(.) reduces over the trailing normalized dims with keepdim=true.
// grad_weight / grad_bias are a None constant when weight / bias is None.
class DecomposeAtenNativeLayerNormBackwardOp
    : public OpRewritePattern<AtenNativeLayerNormBackwardOp> {
public:
  using OpRewritePattern<AtenNativeLayerNormBackwardOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(AtenNativeLayerNormBackwardOp op,
                                PatternRewriter &rewriter) const override {
    Location loc = op.getLoc();
    auto context = op.getContext();

    // All match-failure checks happen before any op is created, so a failed
    // match leaves the IR untouched.
    auto inputTy = op.input().getType().cast<BaseTensorType>();
    if (!inputTy.hasSizes())
      return rewriter.notifyMatchFailure(
          op, "input tensor should have known sizes.");
    int64_t inputRank = inputTy.getSizes().size();

    // The number of normalized dims is read off the list construct; without
    // it the reduction dims cannot be determined, so bail out instead of
    // silently treating the normalized shape as empty.
    Value normalizedShape = op.normalized_shape();
    SmallVector<Value> normalizedShapeSizesTorchInt;
    if (!getListConstructElements(normalizedShape,
                                  normalizedShapeSizesTorchInt))
      return rewriter.notifyMatchFailure(
          op, "normalized_shape should be constructed from a list construct.");
    int64_t normalizedRank = normalizedShapeSizesTorchInt.size();
    // A normalized shape longer than the input rank would make `axis`
    // negative and the outer-dim vector below would get a bogus size.
    if (normalizedRank > inputRank)
      return rewriter.notifyMatchFailure(
          op, "normalized_shape should not have more dims than the input.");

    // Dims [axis, inputRank) are normalized (reduced for mean/rstd grads);
    // dims [0, axis) are the outer dims (reduced for weight/bias grads).
    int64_t axis = inputRank - normalizedRank;
    SmallVector<int64_t> reduceDimInts(normalizedRank);
    SmallVector<int64_t> outerDimInts(axis);
    std::iota(reduceDimInts.begin(), reduceDimInts.end(), axis);
    std::iota(outerDimInts.begin(), outerDimInts.end(), 0);
    auto sizeListType = ListType::get(IntType::get(context));

    // Materializes a list-of-int value holding one constant per entry.
    auto fromIntsToList = [&](ArrayRef<int64_t> dimInts) -> Value {
      SmallVector<Value> dimVals;
      dimVals.reserve(dimInts.size());
      std::transform(dimInts.begin(), dimInts.end(),
                     std::back_inserter(dimVals), [&](int64_t d) {
                       return rewriter.create<Torch::ConstantIntOp>(
                           loc, rewriter.getI64IntegerAttr(d));
                     });
      Value dimList =
          rewriter.create<PrimListConstructOp>(loc, sizeListType, dimVals);
      return dimList;
    };
    // build reduce & outer dims
    auto reduceDimList = fromIntsToList(reduceDimInts);
    auto outerDimList = fromIntsToList(outerDimInts);
    // `one` is passed as the trailing scalar operand of the subtractions.
    Value one = rewriter.create<Torch::ConstantIntOp>(
        loc, rewriter.getI64IntegerAttr(1));

    Value cstFalse = rewriter.create<Torch::ConstantBoolOp>(loc, false);
    Value none = rewriter.create<Torch::ConstantNoneOp>(loc);

    // x_hat = (input - mean) * rstd
    Value inputSubMean = rewriter.create<AtenSubTensorOp>(
        loc, inputTy, op.input(), op.mean(), one);
    Value xHat =
        rewriter.create<AtenMulTensorOp>(loc, inputTy, inputSubMean, op.rstd());

    // grad(x_hat): grad_out, scaled by weight when a weight is present.
    // NOTE(review): when axis == 0 the outer-dim list is empty; confirm that
    // the sum over an empty dim list does not reduce over every dimension,
    // otherwise the weight/bias grads below are wrong for that case.
    Value xHatGrad = op.grad_out();
    Value weight = op.weight();
    Value wGrad = none;
    if (!weight.getType().isa<Torch::NoneType>()) {
      xHatGrad = rewriter.create<AtenMulTensorOp>(loc, xHatGrad.getType(),
                                                  xHatGrad, weight);
      // grad(weight) = sum over outer dims of grad_out * x_hat
      wGrad = rewriter.create<AtenSumDimIntListOp>(
          loc, weight.getType(),
          rewriter.create<AtenMulTensorOp>(loc, inputTy, op.grad_out(), xHat),
          outerDimList, cstFalse, none);
    }
    Value bias = op.bias();
    Value bGrad = none;
    if (!bias.getType().isa<Torch::NoneType>()) {
      // grad(bias) = sum over outer dims of grad_out
      bGrad = rewriter.create<AtenSumDimIntListOp>(
          loc, bias.getType(), op.grad_out(), outerDimList, cstFalse, none);
    }

    Value cstTrue = rewriter.create<Torch::ConstantBoolOp>(loc, true);
    // grad(mean): mean of grad(x_hat) over the normalized dims (keepdim)
    Value meanGrad = rewriter.create<AtenMeanDimOp>(
        loc, op.mean().getType(), xHatGrad, reduceDimList, cstTrue, none);
    // grad(rstd): x_hat * mean(grad(x_hat) * x_hat) over the normalized dims
    Value xHatGradMulXHat =
        rewriter.create<AtenMulTensorOp>(loc, inputTy, xHatGrad, xHat);
    Value rstdGrad0 = rewriter.create<AtenMeanDimOp>(
        loc, op.rstd().getType(), xHatGradMulXHat, reduceDimList, cstTrue,
        none);
    Value rstdGrad1 =
        rewriter.create<AtenMulTensorOp>(loc, inputTy, xHat, rstdGrad0);

    // grad(input) = rstd * (grad(x_hat) - grad(mean) - grad(rstd))
    Value inner =
        rewriter.create<AtenSubTensorOp>(loc, inputTy, xHatGrad, meanGrad, one);
    inner =
        rewriter.create<AtenSubTensorOp>(loc, inputTy, inner, rstdGrad1, one);
    Value gradInput =
        rewriter.create<AtenMulTensorOp>(loc, inputTy, op.rstd(), inner);

    rewriter.replaceOp(op, {gradInput, wGrad, bGrad});

    return success();
  }
};
} // namespace

namespace {
class DecomposeAtenNativeLayerNormOp
: public OpRewritePattern<AtenNativeLayerNormOp> {
Expand Down Expand Up @@ -2771,7 +2889,9 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeAtenLayerNormOp>(context);
target.addIllegalOp<AtenNativeLayerNormOp>();
patterns.add<DecomposeAtenNativeLayerNormOp>(context);

target.addIllegalOp<AtenNativeLayerNormBackwardOp>();
patterns.add<DecomposeAtenNativeLayerNormBackwardOp>(context);

target.addIllegalOp<AtenNativeBatchNormOp>();
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
target.addIllegalOp<AtenConvolutionOverrideableOp>();
Expand Down Expand Up @@ -2836,6 +2956,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<Aten_ToCopyOp>();
patterns.add<DecomposeAtenDropoutOp>(context);
target.addIllegalOp<AtenDropoutOp>();
patterns.add<DecomposeAtenNativeDropoutBackwardOp>(context);
target.addIllegalOp<AtenNativeDropoutBackwardOp>();
target.addIllegalOp<AtenNewEmptyOp>();
patterns.add<DecomposeAtenNewEmptyOp>(context);
patterns.add<DecomposeAtenIndexPutHackedTwinOp>(context);
Expand Down

0 comments on commit 3f695e2

Please sign in to comment.