refactor forward_graph_extract_pass (PaddlePaddle#104)
gglin001 authored Aug 30, 2021
1 parent 9fb5f68 commit 8ceadf0
Showing 1 changed file with 24 additions and 65 deletions.
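
In short, the refactor replaces string-based bookkeeping with node-pointer sets: the `all_ops_name` map, the `forward_var_names` set, and the `not_contains` lambda are dropped in favor of `std::unordered_set<ir::Node*>` lookups via `count()`; the trailing loop that re-scanned the graph for backward ops is folded into the main `rm_nodes` pass; unused includes are trimmed; and `USE_PASS(graph_viz_pass)` is commented out. A minimal, self-contained sketch of why pointer membership is the stricter test (using a hypothetical `MiniNode` type, not Paddle's `ir::Node`): two distinct nodes can share an operator name, so a name set can report a false positive where a pointer set does not.

    // Hypothetical MiniNode type for illustration only, not Paddle's ir::Node.
    #include <iostream>
    #include <string>
    #include <unordered_set>

    struct MiniNode {
      std::string name;
    };

    int main() {
      MiniNode mul_a{"mul"};
      MiniNode mul_b{"mul"};  // a different node with the same op name

      std::unordered_set<std::string> by_name{mul_a.name};  // pre-refactor style
      std::unordered_set<MiniNode*> by_node{&mul_a};        // post-refactor style

      std::cout << by_name.count(mul_b.name) << "\n";  // prints 1: false positive
      std::cout << by_node.count(&mul_b) << "\n";      // prints 0: exact match
      return 0;
    }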
paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.cc (24 additions, 65 deletions)
@@ -14,18 +14,6 @@
 
 #include "paddle/fluid/framework/ir/ipu/forward_graph_extract_pass.h"
 
-#include <glog/logging.h>
-
-#include <algorithm>
-#include <array>
-#include <fstream>
-#include <iosfwd>
-#include <memory>
-#include <ostream>
-#include <string>
-#include <unordered_map>
-#include <unordered_set>
-
 #include "paddle/fluid/framework/ir/pass_tester_helper.h"
 
 namespace paddle {
@@ -54,81 +42,68 @@ void ForwardGraphExtractPass::ApplyImpl(ir::Graph* graph) const {
       {OpRole::kOptimize, {}}, {OpRole::kRPC, {}},
       {OpRole::kDist, {}},     {OpRole::kLRSched, {}},
       {OpRole::kLoss, {}},     {OpRole::kNotSpecified, {}}};
-  std::unordered_map<OpRole, std::unordered_set<std::string>> all_ops_name{
-      {OpRole::kForward, {}},  {OpRole::kBackward, {}},
-      {OpRole::kOptimize, {}}, {OpRole::kRPC, {}},
-      {OpRole::kDist, {}},     {OpRole::kLRSched, {}},
-      {OpRole::kLoss, {}},     {OpRole::kNotSpecified, {}}};
 
   for (auto* node : graph->Nodes()) {
     if (!node->IsOp()) {
       continue;
     }
     auto op_role = BOOST_GET_MUTABLE(int, node->Op()->GetAttr("op_role"));
     if (op_role == static_cast<int>(OpRole::kForward)) {
       all_ops[OpRole::kForward].insert(node);
-      all_ops_name[OpRole::kForward].insert(node->Name());
     } else if (op_role == static_cast<int>(OpRole::kBackward)) {
       all_ops[OpRole::kBackward].insert(node);
     } else if (op_role == static_cast<int>(OpRole::kOptimize)) {
       all_ops[OpRole::kOptimize].insert(node);
-      all_ops_name[OpRole::kOptimize].insert(node->Name());
     } else if (op_role == static_cast<int>(OpRole::kRPC)) {
     } else if (op_role == static_cast<int>(OpRole::kDist)) {
     } else if (op_role == static_cast<int>(OpRole::kLRSched)) {
     } else if (op_role == static_cast<int>(OpRole::kLoss)) {
       all_ops[OpRole::kLoss].insert(node);
-      all_ops_name[OpRole::kLoss].insert(node->Name());
     } else if (op_role == static_cast<int>(OpRole::kNotSpecified)) {
       LOG(WARNING) << "Op: " << node->Name() << " OpRole is NotSpecified ";
     }
   }
 
-  std::unordered_set<std::string> forward_var_names;
+  std::unordered_set<ir::Node*> forward_vars;
   std::unordered_set<ir::Node*> backward_vars;
-  // std::unordered_set<ir::Node*> forward_vars;
   std::unordered_set<ir::Node*> control_vars;
   // forward_vars
   for (auto& nodes : std::array<std::unordered_set<ir::Node*>, 2>{
-           all_ops[OpRole::kForward], all_ops[OpRole::kLoss],
-           // all_ops[OpRole::kOptimize],
-       }) {
+           all_ops[OpRole::kForward], all_ops[OpRole::kLoss]}) {
     for (auto* node : nodes) {
-      for (auto& name_map : node->Op()->Inputs()) {
-        for (auto& name : name_map.second) {
-          forward_var_names.insert(name);
-        }
+      for (auto* in_node : node->inputs) {
+        forward_vars.insert(in_node);
       }
-      for (auto& name_map : node->Op()->Outputs()) {
-        for (auto& name : name_map.second) {
-          forward_var_names.insert(name);
-        }
+      for (auto* out_node : node->outputs) {
+        forward_vars.insert(out_node);
       }
     }
   }
 
-  auto not_contains = [&](const std::string& name,
-                          std::unordered_set<std::string>& names) {
-    return names.find(name) == names.end();
-  };
-
   // control_vars & backward_vars
   for (auto* node : graph->Nodes()) {
     if (!node->IsVar()) {
       continue;
     }
     if (node->IsCtrlVar()) {
       control_vars.insert(node);
    }
     for (auto* in_node : node->inputs) {
-      if (!not_contains(in_node->Name(), all_ops_name[OpRole::kOptimize])) {
+      if (all_ops[OpRole::kOptimize].count(in_node)) {
         backward_vars.insert(node);
       }
     }
   }
 
   // all removed node
   std::unordered_set<ir::Node*> rm_nodes;
   for (auto* node : graph->Nodes()) {
-    if (backward_vars.find(node) != backward_vars.end()) {
+    if (backward_vars.count(node)) {
       rm_nodes.insert(node);
     } else if (control_vars.count(node)) {
       rm_nodes.insert(node);
-    } else if (not_contains(node->Name(), all_ops_name[OpRole::kForward]) &&
-               not_contains(node->Name(), all_ops_name[OpRole::kLoss]) &&
-               // not_contains(node->Name(), all_ops_name[OpRole::kOptimize]) &&
-               not_contains(node->Name(), forward_var_names)) {
+    } else if (all_ops[OpRole::kBackward].count(node)) {
+      rm_nodes.insert(node);
+    } else if (all_ops[OpRole::kForward].count(node) == 0 &&
+               all_ops[OpRole::kLoss].count(node) == 0 &&
+               forward_vars.count(node) == 0) {
       rm_nodes.insert(node);
     }
   }
@@ -152,26 +127,10 @@ void ForwardGraphExtractPass::ApplyImpl(ir::Graph* graph) const {
       }
     }
   }
 
     VLOG(10) << "\t" << node->Name();
     graph->RemoveNode(node);
   }
 
-  // TODO(alleng, zhixin) refactor this part
-  std::unordered_set<ir::Node*> rm_nodes_backward;
-  for (auto node : graph->Nodes()) {
-    if (!node->IsOp()) {
-      continue;
-    }
-    auto op_role = BOOST_GET_MUTABLE(int, node->Op()->GetAttr("op_role"));
-    if (op_role == 1) {
-      rm_nodes_backward.insert(node);
-    }
-  }
-  for (auto* node : rm_nodes_backward) {
-    graph->RemoveNode(node);
-  }
-
   VLOG(10) << "Post Graph: ";
   VLOG(10) << DebugString(graph);

@@ -191,4 +150,4 @@ void ForwardGraphExtractPass::ApplyImpl(ir::Graph* graph) const {
 REGISTER_PASS(forward_graph_extract_pass,
               paddle::framework::ir::ForwardGraphExtractPass);
 
-USE_PASS(graph_viz_pass);
+// USE_PASS(graph_viz_pass);
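
After the refactor, the whole of `ApplyImpl` reduces to: bucket op nodes by `op_role`, mark every input and output of forward and loss ops as forward vars, then remove backward ops, optimizer-produced vars, control vars, and anything outside the forward sets. The sketch below restates that flow under simplifying assumptions: `PseudoNode` and `Role` are hypothetical stand-ins for Paddle's `ir::Node` and `OpRole`, and control-variable handling is omitted.

    // Simplified restatement of the post-refactor logic; PseudoNode and Role
    // are hypothetical stand-ins, not Paddle's real ir::Node / OpRole API.
    #include <unordered_set>
    #include <vector>

    enum class Role { kForward, kLoss, kBackward, kOptimize, kOther };

    struct PseudoNode {
      bool is_op = false;
      Role role = Role::kOther;
      std::vector<PseudoNode*> inputs;   // for a var node: its producer ops
      std::vector<PseudoNode*> outputs;  // for a var node: its consumer ops
    };

    // Returns the nodes the pass would delete, leaving only the forward graph.
    std::unordered_set<PseudoNode*> NodesToRemove(
        const std::vector<PseudoNode*>& nodes) {
      std::unordered_set<PseudoNode*> fwd_ops, bwd_ops, opt_ops;
      for (auto* n : nodes) {
        if (!n->is_op) continue;
        if (n->role == Role::kForward || n->role == Role::kLoss) fwd_ops.insert(n);
        if (n->role == Role::kBackward) bwd_ops.insert(n);
        if (n->role == Role::kOptimize) opt_ops.insert(n);
      }

      // Every var node touched by a forward/loss op stays.
      std::unordered_set<PseudoNode*> fwd_vars;
      for (auto* op : fwd_ops) {
        for (auto* v : op->inputs) fwd_vars.insert(v);
        for (auto* v : op->outputs) fwd_vars.insert(v);
      }

      // Var nodes produced by optimizer ops are backward-only.
      std::unordered_set<PseudoNode*> bwd_vars;
      for (auto* n : nodes) {
        if (n->is_op) continue;
        for (auto* producer : n->inputs) {
          if (opt_ops.count(producer)) bwd_vars.insert(n);
        }
      }

      // Remove backward vars, backward ops, and anything outside the forward set.
      std::unordered_set<PseudoNode*> rm;
      for (auto* n : nodes) {
        if (bwd_vars.count(n) || bwd_ops.count(n) ||
            (fwd_ops.count(n) == 0 && fwd_vars.count(n) == 0)) {
          rm.insert(n);
        }
      }
      return rm;
    }

The recurring design choice is to test membership on node pointers with `unordered_set::count`, which both shortens the `find() != end()` pattern and cannot conflate two distinct nodes that happen to share a name.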
