From 7aa57c3d49cd169b7a620a2bc0dd78573cdf25a7 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Wed, 24 Jun 2020 17:41:50 -0700 Subject: [PATCH] feat(aten::dropout_): Remove inplace dropout Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/conversion/conversion_blacklist.cpp | 3 ++- core/lowering/passes/remove_dropout.cpp | 14 ++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/core/conversion/conversion_blacklist.cpp b/core/conversion/conversion_blacklist.cpp index 01b49c469b..8d26c38384 100644 --- a/core/conversion/conversion_blacklist.cpp +++ b/core/conversion/conversion_blacklist.cpp @@ -22,7 +22,8 @@ const std::unordered_set& get_non_convertable_nodes() { "prim::GetAttr", "prim::CallMethod", "prim::Drop", - "aten:dropout", + "aten::dropout", + "aten::dropout_" }; return nonconvertable_nodes; } diff --git a/core/lowering/passes/remove_dropout.cpp b/core/lowering/passes/remove_dropout.cpp index b5776f2bfa..7c47df909a 100644 --- a/core/lowering/passes/remove_dropout.cpp +++ b/core/lowering/passes/remove_dropout.cpp @@ -20,6 +20,20 @@ void RemoveDropout(std::shared_ptr& graph) { remove_dropout.RegisterRewritePattern( dropout_pattern, no_dropout_pattern); remove_dropout.runOnGraph(graph); + + std::string dropout_inplace_pattern = R"IR( + graph(%input, %4, %5): + %6 = aten::dropout_(%input, %4, %5) + return (%6))IR"; + std::string no_dropout_inplace_pattern = R"IR( + graph(%input, %4, %5): + return (%input))IR"; + + torch::jit::SubgraphRewriter remove_dropout_inplace_pattern; + remove_dropout_inplace_pattern.RegisterRewritePattern( + dropout_inplace_pattern, no_dropout_inplace_pattern); + remove_dropout_inplace_pattern.runOnGraph(graph); + LOG_GRAPH("Post remove dropout: " << *graph); }