Skip to content

Commit

Permalink
feat(aten::dropout_): Remove inplace dropout
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Jun 25, 2020
1 parent 19c91f2 commit 7aa57c3
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
3 changes: 2 additions & 1 deletion core/conversion/conversion_blacklist.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ const std::unordered_set<std::string>& get_non_convertable_nodes() {
"prim::GetAttr",
"prim::CallMethod",
"prim::Drop",
"aten:dropout",
"aten::dropout",
"aten::dropout_"
};
return nonconvertable_nodes;
}
Expand Down
14 changes: 14 additions & 0 deletions core/lowering/passes/remove_dropout.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,20 @@ void RemoveDropout(std::shared_ptr<torch::jit::Graph>& graph) {
remove_dropout.RegisterRewritePattern(
dropout_pattern, no_dropout_pattern);
remove_dropout.runOnGraph(graph);

std::string dropout_inplace_pattern = R"IR(
graph(%input, %4, %5):
%6 = aten::dropout_(%input, %4, %5)
return (%6))IR";
std::string no_dropout_inplace_pattern = R"IR(
graph(%input, %4, %5):
return (%input))IR";

torch::jit::SubgraphRewriter remove_dropout_inplace_pattern;
remove_dropout_inplace_pattern.RegisterRewritePattern(
dropout_inplace_pattern, no_dropout_inplace_pattern);
remove_dropout_inplace_pattern.runOnGraph(graph);

LOG_GRAPH("Post remove dropout: " << *graph);
}

Expand Down

0 comments on commit 7aa57c3

Please sign in to comment.