diff --git a/misc/visualize_state_flow_graph.py b/misc/visualize_state_flow_graph.py index 3b32131d91b96..c379b06f304a5 100644 --- a/misc/visualize_state_flow_graph.py +++ b/misc/visualize_state_flow_graph.py @@ -25,3 +25,4 @@ def bar(): bar() ti.core.print_sfg() +print(ti.dump_dot()) diff --git a/python/taichi/misc/util.py b/python/taichi/misc/util.py index 224bcee70bf48..1dc98fd17e2f6 100644 --- a/python/taichi/misc/util.py +++ b/python/taichi/misc/util.py @@ -249,12 +249,22 @@ def veci(*args, **kwargs): return core_veci(*args, **kwargs) +def dump_dot(filepath=None): + from taichi.core import ti_core + d = ti_core.dump_dot() + if filepath is not None: + with open(filepath, 'w') as fh: + fh.write(d) + return d + + __all__ = [ 'vec', 'veci', 'core_vec', 'core_veci', 'deprecated', + 'dump_dot', 'obsolete', 'get_traceback', 'set_gdb_trigger', diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index 5276a11995159..95c76d37d16d5 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -1,20 +1,40 @@ #include "taichi/program/state_flow_graph.h" +#include +#include + TLANG_NAMESPACE_BEGIN +std::string StateFlowGraph::Node::string() const { + return fmt::format("[node: {}:{}]", task_name, launch_id); +} + +StateFlowGraph::StateFlowGraph() { + nodes_.push_back(std::make_unique()); + initial_node_ = nodes_.back().get(); + initial_node_->task_name = "initial_state"; + initial_node_->launch_id = 0; +} + void StateFlowGraph::insert_task(const TaskMeta &task_meta) { auto node = std::make_unique(); - node->kernel_name = task_meta.kernel_name; + node->task_name = task_meta.kernel_name; + { + int &id = task_name_to_launch_ids_[node->task_name]; + node->launch_id = id; + ++id; + } for (auto input_state : task_meta.input_states) { - if (latest_state_owner.find(input_state) == latest_state_owner.end()) { - latest_state_owner[input_state] = initial_node; + if (latest_state_owner_.find(input_state) == latest_state_owner_.end()) { + latest_state_owner_[input_state] = initial_node_; } - insert_state_flow(latest_state_owner[input_state], node.get(), input_state); + insert_state_flow(latest_state_owner_[input_state], node.get(), + input_state); } for (auto output_state : task_meta.output_states) { - latest_state_owner[output_state] = node.get(); + latest_state_owner_[output_state] = node.get(); } - nodes.push_back(std::move(node)); + nodes_.push_back(std::move(node)); } void StateFlowGraph::insert_state_flow(Node *from, Node *to, AsyncState state) { @@ -27,25 +47,74 @@ void StateFlowGraph::insert_state_flow(Node *from, Node *to, AsyncState state) { void StateFlowGraph::print_edges(const StateFlowGraph::Edges &edges) { for (auto &edge : edges) { auto input_node = edge.second; - fmt::print(" {} -> node {} @ {}\n", edge.first.name(), - input_node->kernel_name, (void *)input_node); + fmt::print(" {} -> {}\n", edge.first.name(), input_node->string()); } } void StateFlowGraph::print() { fmt::print("=== State Flow Graph ===\n"); - for (auto &node : nodes) { - fmt::print("Node {} {}\n", node->kernel_name, (void *)node.get()); + for (auto &node : nodes_) { + fmt::print("{}\n", node->string()); if (!node->input_edges.empty()) { - fmt::print(" Inputs:\n", node->kernel_name, (void *)node.get()); + fmt::print(" Inputs:\n"); print_edges(node->input_edges); } if (!node->output_edges.empty()) { - fmt::print(" Outputs:\n", node->kernel_name, (void *)node.get()); + fmt::print(" Outputs:\n"); print_edges(node->output_edges); } } fmt::print("=======================\n"); } +std::string StateFlowGraph::dump_dot() { + using SFGNode = StateFlowGraph::Node; + std::stringstream ss; + ss << "digraph {\n"; + auto node_id = [](const SFGNode *n) { + // https://graphviz.org/doc/info/lang.html ID naming + return fmt::format("n_{}_{}", n->task_name, n->launch_id); + }; + // Specify the node styles + std::unordered_set latest_state_nodes; + for (const auto &p : latest_state_owner_) { + latest_state_nodes.insert(p.second); + } + for (const auto &nd : nodes_) { + const auto *n = nd.get(); + ss << " " << fmt::format("{} [label=\"{}\"", node_id(n), n->string()); + if (n == initial_node_) { + ss << ",shape=box"; + } else if (latest_state_nodes.find(n) != latest_state_nodes.end()) { + ss << ",peripheries=2"; + } + ss << "]\n"; + } + ss << "\n"; + { + // DFS + std::unordered_set visited; + std::vector stack; + stack.push_back(initial_node_); + while (!stack.empty()) { + auto *from = stack.back(); + stack.pop_back(); + if (visited.find(from) == visited.end()) { + visited.insert(from); + for (const auto &p : from->output_edges) { + auto *to = p.second; + stack.push_back(to); + + ss << " " + << fmt::format("{} -> {} [label=\"{}\"]", node_id(from), + node_id(to), p.first.name()) + << '\n'; + } + } + } + } + ss << "}\n"; // closes "dirgraph {" + return ss.str(); +} + TLANG_NAMESPACE_END diff --git a/taichi/program/state_flow_graph.h b/taichi/program/state_flow_graph.h index 173c96ddbc110..cf6dc4028e9fd 100644 --- a/taichi/program/state_flow_graph.h +++ b/taichi/program/state_flow_graph.h @@ -24,33 +24,33 @@ class StateFlowGraph { struct Node { // TODO: make use of IRHandle here IRNode *root; - std::string kernel_name; + std::string task_name; + // Incremental ID to identify the i-th launch of the task. + int launch_id; + // For |input_edges|, each state could map to exactly one node. + // For |output_edges|, each state could map to at least one node. Edges input_edges, output_edges; - }; - - StateToNodeMapping latest_state_owner; - - std::vector> nodes; - Node *initial_node; // The initial node holds all the initial states. + std::string string() const; + }; - StateFlowGraph() { - nodes.push_back(std::make_unique()); - initial_node = nodes.back().get(); - initial_node->kernel_name = "initial_state"; - } + StateFlowGraph(); void print_edges(const Edges &edges); void print(); - void dump_dot(const std::string &fn) { - // TODO: export the graph to Dot format for GraphViz - } + std::string dump_dot(); void insert_task(const TaskMeta &task_meta); void insert_state_flow(Node *from, Node *to, AsyncState state); + + private: + std::vector> nodes_; + Node *initial_node_; // The initial node holds all the initial states. + StateToNodeMapping latest_state_owner_; + std::unordered_map task_name_to_launch_ids_; }; TLANG_NAMESPACE_END diff --git a/taichi/python/export_lang.cpp b/taichi/python/export_lang.cpp index 304422c7f7e88..904894fda0bf8 100644 --- a/taichi/python/export_lang.cpp +++ b/taichi/python/export_lang.cpp @@ -649,6 +649,8 @@ void export_lang(py::module &m) { }); m.def("print_sfg", [] { get_current_program().async_engine->sfg->print(); }); + m.def("dump_dot", + [] { return get_current_program().async_engine->sfg->dump_dot(); }); } TI_NAMESPACE_END