Skip to content

Commit

Permalink
[async] Generate DOT graph for StateFlowGraph (#1852)
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye authored Sep 7, 2020
1 parent 32bc391 commit b8a544d
Show file tree
Hide file tree
Showing 5 changed files with 109 additions and 27 deletions.
1 change: 1 addition & 0 deletions misc/visualize_state_flow_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,4 @@ def bar():
bar()

ti.core.print_sfg()
print(ti.dump_dot())
10 changes: 10 additions & 0 deletions python/taichi/misc/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,12 +249,22 @@ def veci(*args, **kwargs):
return core_veci(*args, **kwargs)


def dump_dot(filepath=None):
from taichi.core import ti_core
d = ti_core.dump_dot()
if filepath is not None:
with open(filepath, 'w') as fh:
fh.write(d)
return d


__all__ = [
'vec',
'veci',
'core_vec',
'core_veci',
'deprecated',
'dump_dot',
'obsolete',
'get_traceback',
'set_gdb_trigger',
Expand Down
93 changes: 81 additions & 12 deletions taichi/program/state_flow_graph.cpp
Original file line number Diff line number Diff line change
@@ -1,20 +1,40 @@
#include "taichi/program/state_flow_graph.h"

#include <sstream>
#include <unordered_set>

TLANG_NAMESPACE_BEGIN

std::string StateFlowGraph::Node::string() const {
return fmt::format("[node: {}:{}]", task_name, launch_id);
}

StateFlowGraph::StateFlowGraph() {
nodes_.push_back(std::make_unique<Node>());
initial_node_ = nodes_.back().get();
initial_node_->task_name = "initial_state";
initial_node_->launch_id = 0;
}

void StateFlowGraph::insert_task(const TaskMeta &task_meta) {
auto node = std::make_unique<Node>();
node->kernel_name = task_meta.kernel_name;
node->task_name = task_meta.kernel_name;
{
int &id = task_name_to_launch_ids_[node->task_name];
node->launch_id = id;
++id;
}
for (auto input_state : task_meta.input_states) {
if (latest_state_owner.find(input_state) == latest_state_owner.end()) {
latest_state_owner[input_state] = initial_node;
if (latest_state_owner_.find(input_state) == latest_state_owner_.end()) {
latest_state_owner_[input_state] = initial_node_;
}
insert_state_flow(latest_state_owner[input_state], node.get(), input_state);
insert_state_flow(latest_state_owner_[input_state], node.get(),
input_state);
}
for (auto output_state : task_meta.output_states) {
latest_state_owner[output_state] = node.get();
latest_state_owner_[output_state] = node.get();
}
nodes.push_back(std::move(node));
nodes_.push_back(std::move(node));
}

void StateFlowGraph::insert_state_flow(Node *from, Node *to, AsyncState state) {
Expand All @@ -27,25 +47,74 @@ void StateFlowGraph::insert_state_flow(Node *from, Node *to, AsyncState state) {
void StateFlowGraph::print_edges(const StateFlowGraph::Edges &edges) {
for (auto &edge : edges) {
auto input_node = edge.second;
fmt::print(" {} -> node {} @ {}\n", edge.first.name(),
input_node->kernel_name, (void *)input_node);
fmt::print(" {} -> {}\n", edge.first.name(), input_node->string());
}
}

void StateFlowGraph::print() {
fmt::print("=== State Flow Graph ===\n");
for (auto &node : nodes) {
fmt::print("Node {} {}\n", node->kernel_name, (void *)node.get());
for (auto &node : nodes_) {
fmt::print("{}\n", node->string());
if (!node->input_edges.empty()) {
fmt::print(" Inputs:\n", node->kernel_name, (void *)node.get());
fmt::print(" Inputs:\n");
print_edges(node->input_edges);
}
if (!node->output_edges.empty()) {
fmt::print(" Outputs:\n", node->kernel_name, (void *)node.get());
fmt::print(" Outputs:\n");
print_edges(node->output_edges);
}
}
fmt::print("=======================\n");
}

std::string StateFlowGraph::dump_dot() {
using SFGNode = StateFlowGraph::Node;
std::stringstream ss;
ss << "digraph {\n";
auto node_id = [](const SFGNode *n) {
// https://graphviz.org/doc/info/lang.html ID naming
return fmt::format("n_{}_{}", n->task_name, n->launch_id);
};
// Specify the node styles
std::unordered_set<const SFGNode *> latest_state_nodes;
for (const auto &p : latest_state_owner_) {
latest_state_nodes.insert(p.second);
}
for (const auto &nd : nodes_) {
const auto *n = nd.get();
ss << " " << fmt::format("{} [label=\"{}\"", node_id(n), n->string());
if (n == initial_node_) {
ss << ",shape=box";
} else if (latest_state_nodes.find(n) != latest_state_nodes.end()) {
ss << ",peripheries=2";
}
ss << "]\n";
}
ss << "\n";
{
// DFS
std::unordered_set<const SFGNode *> visited;
std::vector<const SFGNode *> stack;
stack.push_back(initial_node_);
while (!stack.empty()) {
auto *from = stack.back();
stack.pop_back();
if (visited.find(from) == visited.end()) {
visited.insert(from);
for (const auto &p : from->output_edges) {
auto *to = p.second;
stack.push_back(to);

ss << " "
<< fmt::format("{} -> {} [label=\"{}\"]", node_id(from),
node_id(to), p.first.name())
<< '\n';
}
}
}
}
ss << "}\n"; // closes "dirgraph {"
return ss.str();
}

TLANG_NAMESPACE_END
30 changes: 15 additions & 15 deletions taichi/program/state_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,33 +24,33 @@ class StateFlowGraph {
struct Node {
// TODO: make use of IRHandle here
IRNode *root;
std::string kernel_name;
std::string task_name;
// Incremental ID to identify the i-th launch of the task.
int launch_id;
// For |input_edges|, each state could map to exactly one node.
// For |output_edges|, each state could map to at least one node.
Edges input_edges, output_edges;
};

StateToNodeMapping latest_state_owner;

std::vector<std::unique_ptr<Node>> nodes;

Node *initial_node; // The initial node holds all the initial states.
std::string string() const;
};

StateFlowGraph() {
nodes.push_back(std::make_unique<Node>());
initial_node = nodes.back().get();
initial_node->kernel_name = "initial_state";
}
StateFlowGraph();

void print_edges(const Edges &edges);

void print();

void dump_dot(const std::string &fn) {
// TODO: export the graph to Dot format for GraphViz
}
std::string dump_dot();

void insert_task(const TaskMeta &task_meta);

void insert_state_flow(Node *from, Node *to, AsyncState state);

private:
std::vector<std::unique_ptr<Node>> nodes_;
Node *initial_node_; // The initial node holds all the initial states.
StateToNodeMapping latest_state_owner_;
std::unordered_map<std::string, int> task_name_to_launch_ids_;
};

TLANG_NAMESPACE_END
2 changes: 2 additions & 0 deletions taichi/python/export_lang.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -649,6 +649,8 @@ void export_lang(py::module &m) {
});

m.def("print_sfg", [] { get_current_program().async_engine->sfg->print(); });
m.def("dump_dot",
[] { return get_current_program().async_engine->sfg->dump_dot(); });
}

TI_NAMESPACE_END

0 comments on commit b8a544d

Please sign in to comment.