From 2fc612da6463170f8e55d9509f3e2751a713593a Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Fri, 3 Sep 2021 14:47:37 -0700 Subject: [PATCH] fix(//core/lowering): Fixes module level fallback recursion This commit fixes module level fallback by using method calls to determine modules to recurse down too. This should be robust to names other than forward used for methods as well as ignoring functional modules. Signed-off-by: Naren Dasan Signed-off-by: Naren Dasan --- core/lowering/lowering.cpp | 21 ++++++++-------- core/lowering/passes/module_fallback.cpp | 31 ++++++++++++++++++++---- 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/core/lowering/lowering.cpp b/core/lowering/lowering.cpp index 506a4934fe..4be3e403aa 100644 --- a/core/lowering/lowering.cpp +++ b/core/lowering/lowering.cpp @@ -37,7 +37,9 @@ void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { torch::jit::EliminateCommonSubexpression(g); } torch::jit::EliminateDeadCode(g); - passes::MarkNodesForFallback(g, true); + if (lower_info.forced_fallback_modules.size() > 0) { + passes::MarkNodesForFallback(g, true); + } passes::UnpackHardSwish(g); passes::EliminateExceptionOrPassPattern(g); passes::ReduceToOperation(g); @@ -60,12 +62,13 @@ void LowerGraph(std::shared_ptr& g, LowerInfo lower_info) { LOG_GRAPH(*g); } -torch::jit::Module LowerModule( - const torch::jit::Module& mod, - std::string method_name, - std::unordered_set forced_fallback_modules) { - passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules); - LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph()); +torch::jit::Module LowerModule(const torch::jit::Module& mod, std::string method_name, const LowerInfo& lower_info) { + std::unordered_set forced_fallback_modules( + lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end()); + if (forced_fallback_modules.size() > 0) { + passes::NotateModuleForFallback(mod, "", method_name, forced_fallback_modules); + LOG_GRAPH("After MLF notation pass: " << *mod.get_method(method_name).graph()); + } auto mod_ = torch::jit::freeze_module(mod); LOG_GRAPH("After freeze: " << *mod_.get_method(method_name).graph()); return mod_; @@ -77,9 +80,7 @@ std::pair, std::vector> L const LowerInfo& lower_info) { LOG_DEBUG(lower_info); LOG_GRAPH("Before lowering: " << *mod.get_method(method_name).graph()); - std::unordered_set forced_fallback_modules( - lower_info.forced_fallback_modules.begin(), lower_info.forced_fallback_modules.end()); - auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, forced_fallback_modules); + auto lowered_mod = lower_info.unfreeze_module ? mod : LowerModule(mod, method_name, lower_info); auto g = lowered_mod.get_method(method_name).graph(); LOG_GRAPH("LibTorch Lowering"); diff --git a/core/lowering/passes/module_fallback.cpp b/core/lowering/passes/module_fallback.cpp index 99b2578dec..c2bdc9739a 100644 --- a/core/lowering/passes/module_fallback.cpp +++ b/core/lowering/passes/module_fallback.cpp @@ -61,8 +61,29 @@ void NotateModuleForFallback( LOG_GRAPH("Notated graph: " << *g); } - for (const auto sub_mod : mod.named_children()) { - NotateModuleForFallback(sub_mod.value, sub_mod.name, method_name, forced_fallback_modules); + if (mod.named_children().size() > 0) { + for (const auto n : nodes) { + std::string sub_method_name = ""; + if (n->kind() == torch::jit::prim::CallMethod) { + sub_method_name = n->s(c10::Symbol::attr("name")); + auto sub_mod_val = n->input(0); + auto sub_mod_src_n = sub_mod_val->node(); + if (!sub_mod_src_n->hasAttributeS("name")) { + LOG_GRAPH("Node: " << util::node_info(sub_mod_src_n) << " manages a module with no name, skipping"); + break; + } + auto sub_mod_name = sub_mod_src_n->s(c10::Symbol::attr("name")); + for (const auto sub_mod : mod.named_children()) { + // Theres probably a way to directly access the module we care about + if (sub_mod.name == sub_mod_name) { + LOG_GRAPH( + "Looking at .() next: " << sub_mod_name << "." << sub_method_name + << "() (lowering.passes.NotateModuleForFallback)"); + NotateModuleForFallback(sub_mod.value, sub_mod.name, sub_method_name, forced_fallback_modules); + } + } + } + } } } @@ -74,7 +95,7 @@ void MarkNodesForFallback(std::shared_ptr& g, bool delete_del auto n = *it; if (!mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) { if (n->s(c10::Symbol::attr("compilation_edge")) == "start") { - LOG_DEBUG("Starting to mark new segmented block targeted for torch"); + LOG_GRAPH("Starting to mark new segmented block targeted for torch"); mark.push(true); if (delete_delims) { it.destroyCurrent(); @@ -82,7 +103,7 @@ void MarkNodesForFallback(std::shared_ptr& g, bool delete_del } } else if (mark.top() && n->kind() == torch::jit::prim::Enter && n->hasAttributeS("compilation_edge")) { if (n->s(c10::Symbol::attr("compilation_edge")) == "start") { - LOG_DEBUG("Found the start of another segmented block targeted for torch while actively marking a block"); + LOG_GRAPH("Found the start of another segmented block targeted for torch while actively marking a block"); mark.push(true); if (delete_delims) { it.destroyCurrent(); @@ -90,7 +111,7 @@ void MarkNodesForFallback(std::shared_ptr& g, bool delete_del } } else if (mark.top() && n->kind() == torch::jit::prim::Exit && n->hasAttributeS("compilation_edge")) { if (n->s(c10::Symbol::attr("compilation_edge")) == "end") { - LOG_DEBUG("Found the end of segmented block targeted for torch while actively marking a block"); + LOG_GRAPH("Found the end of segmented block targeted for torch while actively marking a block"); mark.pop(); if (delete_delims) { it.destroyCurrent();