Skip to content
This repository has been archived by the owner on Jun 19, 2024. It is now read-only.

Commit

Permalink
Add native_dropout_backward & native_layer_norm_backward decomposition (
Browse files Browse the repository at this point in the history
  • Loading branch information
Tanyo Kwok authored and bladedisc committed Sep 23, 2022
1 parent 83c6659 commit d2b4c40
Showing 1 changed file with 23 additions and 0 deletions.
23 changes: 23 additions & 0 deletions lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1309,6 +1309,25 @@ class DecomposeAtenDropoutOp : public OpRewritePattern<AtenDropoutOp> {
};
} // namespace

// grad_output * mask * scale
namespace {
class DecomposeAtenNativeDropoutBackwardOp
: public OpRewritePattern<AtenNativeDropoutBackwardOp> {
public:
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(AtenNativeDropoutBackwardOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

Value maskedGradOutput = rewriter.create<AtenMulTensorOp>(
loc, op.getType(), op.grad_output(), op.mask());
rewriter.replaceOpWithNewOp<AtenMulScalarOp>(op, op.getType(),
maskedGradOutput, op.scale());
return success();
}
};
} // namespace

// Decompose aten.var into: aten.var.dim op.
namespace {
class DecomposeAtenVarOp : public OpRewritePattern<AtenVarOp> {
Expand Down Expand Up @@ -2973,6 +2992,8 @@ class DecomposeComplexOpsPass
patterns.add<DecomposeAtenLayerNormOp>(context);
target.addIllegalOp<AtenNativeLayerNormOp>();
patterns.add<DecomposeAtenNativeLayerNormOp>(context);
target.addIllegalOp<AtenNativeLayerNormBackwardOp>();
patterns.add<DecomposeAtenNativeLayerNormBackwardOp>(context);

target.addIllegalOp<AtenNativeBatchNormOp>();
patterns.add<DecomposeAtenNativeBatchNormOp>(context);
Expand Down Expand Up @@ -3042,6 +3063,8 @@ class DecomposeComplexOpsPass
target.addIllegalOp<Aten_ToCopyOp>();
patterns.add<DecomposeAtenDropoutOp>(context);
target.addIllegalOp<AtenDropoutOp>();
patterns.add<DecomposeAtenNativeDropoutBackwardOp>(context);
target.addIllegalOp<AtenNativeDropoutBackwardOp>();
target.addIllegalOp<AtenNewEmptyOp>();
patterns.add<DecomposeAtenNewEmptyOp>(context);
patterns.add<DecomposeAtenIndexPutHackedTwinOp>(context);
Expand Down

0 comments on commit d2b4c40

Please sign in to comment.