Skip to content

Commit

Permalink
SSA Graph Builder Factory
Browse files Browse the repository at this point in the history
* Use Builder Chain to decorate new builders. It is easy to extend
  builders.
* Make graphviz path as a build strategy, not a FLAGS.
  • Loading branch information
reyoung committed Jun 6, 2018
1 parent 106ee9d commit d9af153
Show file tree
Hide file tree
Showing 15 changed files with 295 additions and 89 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/framework/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto glog lod_rank_table feed_fetch_method)


cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS graph_builder_factory threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)

cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/details/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ cc_library(rpc_op_handle SRCS rpc_op_handle.cc DEPS framework_proto scope place

cc_library(ssa_graph SRCS ssa_graph.cc DEPS var_handle op_handle_base)
cc_library(ssa_graph_builder SRCS ssa_graph_builder.cc DEPS ssa_graph)
cc_library(ssa_graph_printer SRCS ssa_graph_printer.cc DEPS ssa_graph_builder)

cc_library(variable_visitor SRCS variable_visitor.cc DEPS lod_tensor selected_rows)

Expand All @@ -28,6 +29,9 @@ cc_library(gather_op_handle SRCS gather_op_handle.cc DEPS op_handle_base scope d
cc_library(multi_devices_graph_builder SRCS multi_devices_graph_builder.cc DEPS ssa_graph_builder computation_op_handle
scale_loss_grad_op_handle rpc_op_handle ${multi_devices_graph_builder_deps} reduce_op_handle broadcast_op_handle)


cc_library(graph_builder_factory SRCS graph_builder_factory.cc DEPS multi_devices_graph_builder ssa_graph_printer)

cc_library(ssa_graph_executor SRCS ssa_graph_executor.cc DEPS ssa_graph framework_proto)
cc_library(threaded_ssa_graph_executor SRCS threaded_ssa_graph_executor.cc DEPS fetch_op_handle ssa_graph_executor scope
simple_threadpool device_context)
Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/details/broadcast_op_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,8 +59,8 @@ struct BroadcastOpHandle : public OpHandleBase {
void RunImpl() override;

private:
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
#ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_;
#endif
Expand Down
4 changes: 4 additions & 0 deletions paddle/fluid/framework/details/build_strategy.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

#pragma once

#include <string>

namespace paddle {
namespace framework {
namespace details {
Expand All @@ -29,6 +31,8 @@ struct BuildStrategy {

ReduceStrategy reduce_{ReduceStrategy::kAllReduce};
GradientScaleStrategy gradient_scale_{GradientScaleStrategy::kCoeffNumDevice};

std::string debug_graphviz_path_{""};
};

} // namespace details
Expand Down
47 changes: 47 additions & 0 deletions paddle/fluid/framework/details/graph_builder_factory.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/graph_builder_factory.h"
#include <fstream>
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/ssa_graph_printer.h"

namespace paddle {
namespace framework {
namespace details {
std::unique_ptr<SSAGraphBuilder> SSAGraphBuilderFactory::Create() {
std::unique_ptr<SSAGraphBuilder> res(
#ifdef PADDLE_WITH_CUDA
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
local_scopes_, nccl_ctxs_, strategy_)
#else
new MultiDevSSAGraphBuilder(places_, loss_var_name_, param_names_,
local_scopes_, strategy_)
#endif
); // NOLINT

if (!strategy_.debug_graphviz_path_.empty()) {
std::unique_ptr<std::ostream> fout(
new std::ofstream(strategy_.debug_graphviz_path_));
PADDLE_ENFORCE(fout->good());
std::unique_ptr<GraphvizSSAGraphPrinter> graphviz_printer(
new GraphvizSSAGraphPrinter());
res.reset(new SSAGraghBuilderWithPrinter(
std::move(fout), std::move(graphviz_printer), std::move(res)));
}
return res;
}
} // namespace details
} // namespace framework
} // namespace paddle
67 changes: 67 additions & 0 deletions paddle/fluid/framework/details/graph_builder_factory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_builder.h"
#include "paddle/fluid/platform/place.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/nccl_helper.h"
#endif

namespace paddle {
namespace framework {
class Scope;
namespace details {

class SSAGraphBuilderFactory {
public:
SSAGraphBuilderFactory(const std::vector<platform::Place>& places,
const std::string& loss_var_name,
const std::unordered_set<std::string>& param_names,
const std::vector<Scope*>& local_scopes,
const BuildStrategy& strategy)
: places_(places),
loss_var_name_(loss_var_name),
param_names_(param_names),
local_scopes_(local_scopes),
strategy_(strategy) {}

#ifdef PADDLE_WITH_CUDA
void SetNCCLContextMap(platform::NCCLContextMap* nccl_ctxs) {
nccl_ctxs_ = nccl_ctxs;
}
#endif

std::unique_ptr<SSAGraphBuilder> Create();

private:
std::vector<platform::Place> places_;
std::string loss_var_name_;
std::unordered_set<std::string> param_names_;
std::vector<Scope*> local_scopes_;
BuildStrategy strategy_;

#ifdef PADDLE_WITH_CUDA
platform::NCCLContextMap* nccl_ctxs_;
#endif
};

} // namespace details
} // namespace framework
} // namespace paddle
9 changes: 0 additions & 9 deletions paddle/fluid/framework/details/multi_devices_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,6 @@
#include "paddle/fluid/framework/details/nccl_all_reduce_op_handle.h"
#endif

DEFINE_string(ssa_graph_path, "/tmp/ssa_graph.dot",
"the ssa graph path only print with GLOG_v=10,"
"default /tmp/graph.dot");

namespace paddle {
namespace framework {
namespace details {
Expand Down Expand Up @@ -277,11 +273,6 @@ std::unique_ptr<SSAGraph> MultiDevSSAGraphBuilder::Build(
*/
AddOutputToLeafOps(&result);

if (VLOG_IS_ON(10)) {
std::ofstream fout(FLAGS_ssa_graph_path);
PrintGraphviz(*graph, fout);
}

return std::unique_ptr<SSAGraph>(graph);
}

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/details/nccl_all_reduce_op_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ struct NCCLAllReduceOpHandle : public OpHandleBase {
void RunImpl() override;

private:
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;
const platform::NCCLContextMap &nccl_ctxs_;
};

Expand Down
4 changes: 2 additions & 2 deletions paddle/fluid/framework/details/reduce_op_handle.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ namespace framework {
namespace details {

struct ReduceOpHandle : public OpHandleBase {
const std::vector<Scope *> &local_scopes_;
const std::vector<platform::Place> &places_;
std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_;

#ifdef PADDLE_WITH_CUDA
const platform::NCCLContextMap *nccl_ctxs_;
Expand Down
58 changes: 0 additions & 58 deletions paddle/fluid/framework/details/ssa_graph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,64 +73,6 @@ void SSAGraphBuilder::CreateOpOutput(SSAGraph *graph, OpHandleBase *op_handle,
op_handle->AddOutput(var);
}

template <typename Callback>
void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
}
}
}

for (auto &var : graph.dep_vars_) {
callback(*var);
}
}

void SSAGraphBuilder::PrintGraphviz(const SSAGraph &graph, std::ostream &sout) {
size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars;

sout << "digraph G {\n";

IterAllVar(graph, [&](const VarHandleBase &var) {
auto *var_ptr = &var;
auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr);

size_t cur_var_id = var_id++;
vars[var_ptr] = cur_var_id;

if (var_handle_ptr) {
sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_
<< "\\n"
<< var_handle_ptr->place_ << "\\n"
<< var_handle_ptr->version_ << "\"]" << std::endl;
} else if (dummy_ptr) {
sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl;
}
});

size_t op_id = 0;
for (auto &op : graph.ops_) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
for (auto in : op->Inputs()) {
std::string var_name = "var_" + std::to_string(vars[in]);
sout << var_name << " -> " << op_name << std::endl;
}

for (auto out : op->Outputs()) {
std::string var_name = "var_" + std::to_string(vars[out]);
sout << op_name << " -> " << var_name << std::endl;
}
}

sout << "}\n";
}

void SSAGraphBuilder::AddOutputToLeafOps(SSAGraph *graph) {
for (auto &op : graph->ops_) {
if (!op->Outputs().empty()) {
Expand Down
2 changes: 0 additions & 2 deletions paddle/fluid/framework/details/ssa_graph_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,6 @@ class SSAGraphBuilder {
const platform::Place &place, size_t place_offset);

static void AddOutputToLeafOps(SSAGraph *graph);

static void PrintGraphviz(const SSAGraph &graph, std::ostream &sout);
};
} // namespace details
} // namespace framework
Expand Down
83 changes: 83 additions & 0 deletions paddle/fluid/framework/details/ssa_graph_printer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/details/ssa_graph_printer.h"
#include <string>
#include "paddle/fluid/framework/details/ssa_graph.h"

namespace paddle {
namespace framework {
namespace details {

template <typename Callback>
static inline void IterAllVar(const SSAGraph &graph, Callback callback) {
for (auto &each : graph.vars_) {
for (auto &pair1 : each) {
for (auto &pair2 : pair1.second) {
callback(*pair2);
}
}
}

for (auto &var : graph.dep_vars_) {
callback(*var);
}
}

void GraphvizSSAGraphPrinter::Print(const SSAGraph &graph,
std::ostream &sout) const {
size_t var_id = 0;
std::unordered_map<const VarHandleBase *, size_t> vars;

sout << "digraph G {\n";

IterAllVar(graph, [&](const VarHandleBase &var) {
auto *var_ptr = &var;
auto *var_handle_ptr = dynamic_cast<const VarHandle *>(var_ptr);
auto *dummy_ptr = dynamic_cast<const DummyVarHandle *>(var_ptr);

size_t cur_var_id = var_id++;
vars[var_ptr] = cur_var_id;

if (var_handle_ptr) {
sout << "var_" << cur_var_id << " [label=\"" << var_handle_ptr->name_
<< "\\n"
<< var_handle_ptr->place_ << "\\n"
<< var_handle_ptr->version_ << "\"]" << std::endl;
} else if (dummy_ptr) {
sout << "var_" << cur_var_id << " [label=\"dummy\"]" << std::endl;
}
});

size_t op_id = 0;
for (auto &op : graph.ops_) {
std::string op_name = "op_" + std::to_string(op_id++);
sout << op_name << " [label=\"" << op->Name() << "\", shape=rect]"
<< std::endl;
for (auto in : op->Inputs()) {
std::string var_name = "var_" + std::to_string(vars[in]);
sout << var_name << " -> " << op_name << std::endl;
}

for (auto out : op->Outputs()) {
std::string var_name = "var_" + std::to_string(vars[out]);
sout << op_name << " -> " << var_name << std::endl;
}
}

sout << "}\n";
}
} // namespace details
} // namespace framework
} // namespace paddle
Loading

0 comments on commit d9af153

Please sign in to comment.