From bac6179f01886fd2c455bd429cd5b11842138a6d Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Sun, 29 Nov 2020 18:43:11 -0500 Subject: [PATCH 1/4] update unique_id --- taichi/analysis/cfg_analysis.cpp | 8 +++- taichi/ir/analysis.h | 3 +- taichi/program/async_utils.cpp | 66 +++++++++++++++++++---------- taichi/program/async_utils.h | 14 +++--- taichi/program/ir_bank.cpp | 9 ++++ taichi/program/ir_bank.h | 15 +++++++ taichi/program/state_flow_graph.cpp | 13 +++--- taichi/program/state_flow_graph.h | 6 +++ 8 files changed, 94 insertions(+), 40 deletions(-) diff --git a/taichi/analysis/cfg_analysis.cpp b/taichi/analysis/cfg_analysis.cpp index b9c23cb078257..5b6595d702296 100644 --- a/taichi/analysis/cfg_analysis.cpp +++ b/taichi/analysis/cfg_analysis.cpp @@ -1,15 +1,19 @@ #include "taichi/ir/analysis.h" #include "taichi/ir/control_flow_graph.h" #include "taichi/program/async_utils.h" +#include "taichi/program/ir_bank.h" TLANG_NAMESPACE_BEGIN namespace irpass::analysis { -void get_meta_input_value_states(IRNode *root, TaskMeta *meta) { +void get_meta_input_value_states(IRNode *root, + TaskMeta *meta, + IRBank *ir_bank) { auto cfg = analysis::build_cfg(root); auto snodes = cfg->gather_loaded_snodes(); for (auto &snode : snodes) { - meta->input_states.emplace(snode, AsyncState::Type::value); + meta->input_states.insert( + ir_bank->get_async_state(snode, AsyncState::Type::value)); } } } // namespace irpass::analysis diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index c1aaf8597ec65..15bc8e612fa65 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -53,6 +53,7 @@ enum AliasResult { same, uncertain, different }; class ControlFlowGraph; struct TaskMeta; +class IRBank; // IR Analysis namespace irpass::analysis { @@ -75,7 +76,7 @@ std::unordered_map gather_uniquely_accessed_pointers( std::unique_ptr> gather_used_atomics( IRNode *root); std::vector get_load_pointers(Stmt *load_stmt); -void get_meta_input_value_states(IRNode *root, TaskMeta *meta); +void get_meta_input_value_states(IRNode *root, TaskMeta *meta, IRBank *ir_bank); Stmt *get_store_data(Stmt *store_stmt); std::vector get_store_destination(Stmt *store_stmt); bool has_store_or_atomic(IRNode *root, const std::vector &vars); diff --git a/taichi/program/async_utils.cpp b/taichi/program/async_utils.cpp index a9f3214d8ff4d..07e392f6f0e4b 100644 --- a/taichi/program/async_utils.cpp +++ b/taichi/program/async_utils.cpp @@ -66,7 +66,7 @@ std::string AsyncState::name() const { return prefix + "_" + type_name; } -std::size_t AsyncState::get_unique_id(void *ptr, AsyncState::Type type) { +std::size_t AsyncState::perfect_hash(void *ptr, AsyncState::Type type) { static_assert((int)Type::undefined < 8); static_assert(std::alignment_of() % 8 == 0); static_assert(std::alignment_of() % 8 == 0); @@ -145,7 +145,7 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { meta.name = t.kernel->name + "_" + offloaded_task_type_name(root_stmt->task_type); meta.type = root_stmt->task_type; - get_meta_input_value_states(root_stmt, &meta); + get_meta_input_value_states(root_stmt, &meta, ir_bank); meta.loop_unique = gather_uniquely_accessed_pointers(root_stmt); std::unordered_set activates, deactivates; @@ -157,7 +157,8 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { if (auto global_store = stmt->cast()) { if (auto ptr = global_store->ptr->cast()) { for (auto &snode : ptr->snodes.data) { - meta.output_states.emplace(snode, AsyncState::Type::value); + meta.output_states.insert( + ir_bank->get_async_state(snode, AsyncState::Type::value)); } } } @@ -166,7 +167,8 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { for (auto &snode : ptr->snodes.data) { // input_state is already handled in // get_meta_input_value_states(). - meta.output_states.emplace(snode, AsyncState::Type::value); + meta.output_states.insert( + ir_bank->get_async_state(snode, AsyncState::Type::value)); } } } @@ -182,12 +184,15 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { activates.insert(sn); for (auto &child : sn->ch) { TI_ASSERT(child->type == SNodeType::place); - meta.input_states.emplace(child.get(), AsyncState::Type::value); - meta.output_states.emplace(child.get(), AsyncState::Type::value); + meta.input_states.insert( + ir_bank->get_async_state(child.get(), AsyncState::Type::value)); + meta.output_states.insert( + ir_bank->get_async_state(child.get(), AsyncState::Type::value)); } } else if (snode_op->op_type == SNodeOpType::is_active || snode_op->op_type == SNodeOpType::length) { - meta.input_states.emplace(sn, AsyncState::Type::mask); + meta.input_states.insert( + ir_bank->get_async_state(sn, AsyncState::Type::mask)); } else { TI_NOT_IMPLEMENTED } @@ -210,12 +215,13 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { } } if (stmt->is()) { - auto as = AsyncState(t.kernel); + auto as = ir_bank->get_async_state(t.kernel); meta.input_states.insert(as); meta.output_states.insert(as); } if (auto clear_list = stmt->cast()) { - meta.output_states.emplace(clear_list->snode, AsyncState::Type::list); + meta.output_states.insert( + ir_bank->get_async_state(clear_list->snode, AsyncState::Type::list)); } return false; }); @@ -237,11 +243,15 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { // Do not record dense SNodes' mask states. if (s->need_activation()) { - meta.input_states.emplace(s, AsyncState::Type::mask); - meta.output_states.emplace(s, AsyncState::Type::mask); + meta.input_states.insert( + ir_bank->get_async_state(s, AsyncState::Type::mask)); + meta.output_states.insert( + ir_bank->get_async_state(s, AsyncState::Type::mask)); if (is_gc_able(s->type)) { - meta.input_states.emplace(s, AsyncState::Type::allocator); - meta.output_states.emplace(s, AsyncState::Type::allocator); + meta.input_states.insert( + ir_bank->get_async_state(s, AsyncState::Type::allocator)); + meta.output_states.insert( + ir_bank->get_async_state(s, AsyncState::Type::allocator)); } } s = s->parent; @@ -268,7 +278,8 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { continue; } if (s->type == SNodeType::place) { - meta.output_states.emplace(s, AsyncState::Type::value); + meta.output_states.insert( + ir_bank->get_async_state(s, AsyncState::Type::value)); } else { for (auto &child : s->ch) { if (deactivates.count(child.get()) == 0) { @@ -288,22 +299,31 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) { if (root_stmt->task_type == OffloadedTaskType::listgen) { TI_ASSERT(root_stmt->snode->parent); meta.snode = root_stmt->snode; - meta.input_states.emplace(root_stmt->snode->parent, AsyncState::Type::list); - meta.input_states.emplace(root_stmt->snode, AsyncState::Type::list); + meta.input_states.insert(ir_bank->get_async_state(root_stmt->snode->parent, + AsyncState::Type::list)); + meta.input_states.insert( + ir_bank->get_async_state(root_stmt->snode, AsyncState::Type::list)); if (root_stmt->snode->need_activation()) { - meta.input_states.emplace(root_stmt->snode, AsyncState::Type::mask); + meta.input_states.insert( + ir_bank->get_async_state(root_stmt->snode, AsyncState::Type::mask)); } - meta.output_states.emplace(root_stmt->snode, AsyncState::Type::list); + meta.output_states.insert( + ir_bank->get_async_state(root_stmt->snode, AsyncState::Type::list)); } else if (root_stmt->task_type == OffloadedTaskType::struct_for) { meta.snode = root_stmt->snode; - meta.input_states.emplace(root_stmt->snode, AsyncState::Type::list); + meta.input_states.insert( + ir_bank->get_async_state(root_stmt->snode, AsyncState::Type::list)); } else if ((root_stmt->task_type == OffloadedTaskType::gc) && (is_gc_able(root_stmt->snode->type))) { meta.snode = root_stmt->snode; - meta.input_states.emplace(root_stmt->snode, AsyncState::Type::mask); - meta.input_states.emplace(root_stmt->snode, AsyncState::Type::allocator); - meta.output_states.emplace(root_stmt->snode, AsyncState::Type::mask); - meta.output_states.emplace(root_stmt->snode, AsyncState::Type::allocator); + meta.input_states.insert( + ir_bank->get_async_state(root_stmt->snode, AsyncState::Type::mask)); + meta.input_states.insert(ir_bank->get_async_state( + root_stmt->snode, AsyncState::Type::allocator)); + meta.output_states.insert( + ir_bank->get_async_state(root_stmt->snode, AsyncState::Type::mask)); + meta.output_states.insert(ir_bank->get_async_state( + root_stmt->snode, AsyncState::Type::allocator)); insert_value_states_top_down(root_stmt->snode); } diff --git a/taichi/program/async_utils.h b/taichi/program/async_utils.h index 598d2af17b986..508f55df7cae2 100644 --- a/taichi/program/async_utils.h +++ b/taichi/program/async_utils.h @@ -89,17 +89,13 @@ struct AsyncState { AsyncState() = default; // For SNode - AsyncState(SNode *snode, Type type) - : snode_or_global_tmp(snode), - type(type), - unique_id(get_unique_id(snode, type)) { + AsyncState(SNode *snode, Type type, std::size_t unique_id) + : snode_or_global_tmp(snode), type(type), unique_id(unique_id) { } // For global temporaries - AsyncState(Kernel *kernel) - : snode_or_global_tmp(kernel), - type(Type::value), - unique_id(get_unique_id(kernel, type)) { + AsyncState(Kernel *kernel, std::size_t unique_id) + : snode_or_global_tmp(kernel), type(Type::value), unique_id(unique_id) { } bool operator<(const AsyncState &other) const { @@ -124,7 +120,7 @@ struct AsyncState { return std::get(snode_or_global_tmp); } - static std::size_t get_unique_id(void *ptr, Type type); + static std::size_t perfect_hash(void *ptr, Type type); }; struct TaskFusionMeta { diff --git a/taichi/program/ir_bank.cpp b/taichi/program/ir_bank.cpp index 65deef95574b5..f65e520696197 100644 --- a/taichi/program/ir_bank.cpp +++ b/taichi/program/ir_bank.cpp @@ -194,4 +194,13 @@ std::pair IRBank::optimize_dse( return std::make_pair(ret_handle, false); } +AsyncState IRBank::get_async_state(SNode *snode, AsyncState::Type type) { + return AsyncState(snode, type, lookup_async_state_id(snode, type)); +} + +AsyncState IRBank::get_async_state(Kernel *kernel) { + return AsyncState(kernel, + lookup_async_state_id(kernel, AsyncState::Type::value)); +} + TLANG_NAMESPACE_END diff --git a/taichi/program/ir_bank.h b/taichi/program/ir_bank.h index 6e2079aa0cbec..c96f4a86f51fd 100644 --- a/taichi/program/ir_bank.h +++ b/taichi/program/ir_bank.h @@ -33,6 +33,10 @@ class IRBank { std::unordered_map meta_bank_; std::unordered_map fusion_meta_bank_; + AsyncState get_async_state(SNode *snode, AsyncState::Type type); + + AsyncState get_async_state(Kernel *kernel); + private: std::unordered_map hash_bank_; std::unordered_map> ir_bank_; @@ -72,8 +76,19 @@ class IRBank { } }; }; + std::unordered_map optimize_dse_bank_; + std::unordered_map async_state_to_unique_id_; + + std::size_t lookup_async_state_id(void *ptr, AsyncState::Type type) { + auto h = AsyncState::perfect_hash(ptr, type); + if (async_state_to_unique_id_.find(h) == async_state_to_unique_id_.end()) { + async_state_to_unique_id_.insert( + std::make_pair(h, async_state_to_unique_id_.size())); + } + return async_state_to_unique_id_[h]; + } }; TLANG_NAMESPACE_END diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index 8c4f46b4f6f94..036e3dda4a7f2 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -308,13 +308,15 @@ bool StateFlowGraph::optimize_listgen() { // Test if two list generations share the same mask and parent list auto snode = node_a->meta->snode; - auto list_state = AsyncState{snode, AsyncState::Type::list}; + auto list_state = + ir_bank_->get_async_state(snode, AsyncState::Type::list); auto parent_list_state = - AsyncState{snode->parent, AsyncState::Type::list}; + ir_bank_->get_async_state(snode->parent, AsyncState::Type::list); if (snode->need_activation()) { // Needs mask state - auto mask_state = AsyncState{snode, AsyncState::Type::mask}; + auto mask_state = + ir_bank_->get_async_state(snode, AsyncState::Type::mask); TI_ASSERT(get_or_insert(node_a->input_edges, mask_state).size() == 1); TI_ASSERT(get_or_insert(node_b->input_edges, mask_state).size() == 1); @@ -762,7 +764,8 @@ std::string StateFlowGraph::dump_dot(const std::optional &rankdir, std::stringstream ss; // TODO: expose an API that allows users to highlight a single state - AsyncState highlight_state{nullptr, AsyncState::Type::value}; + AsyncState highlight_state = + ir_bank_->get_async_state(nullptr, AsyncState::Type::value); ss << "digraph {\n"; auto node_id = [](const SFGNode *n) { @@ -1309,7 +1312,7 @@ bool StateFlowGraph::demote_activation() { for (int i = 1; i < (int)nodes_.size(); i++) { Node *node = nodes_[i].get(); auto snode = node->meta->snode; - auto list_state = AsyncState(snode, AsyncState::Type::list); + auto list_state = ir_bank_->get_async_state(snode, AsyncState::Type::list); // Currently we handle struct for only // TODO: handle serial and range for diff --git a/taichi/program/state_flow_graph.h b/taichi/program/state_flow_graph.h index 742e01524b6d3..a406466d41aa7 100644 --- a/taichi/program/state_flow_graph.h +++ b/taichi/program/state_flow_graph.h @@ -181,6 +181,12 @@ class StateFlowGraph { void benchmark_rebuild_graph(); + std::size_t lookup_async_state_id(void *ptr, AsyncState::Type); + + AsyncState get_async_state(SNode *snode, AsyncState::Type type); + + AsyncState get_async_state(Kernel *kernel); + private: std::vector> nodes_; Node *initial_node_; // The initial node holds all the initial states. From f11aea2a7a82cee8d099954143aaddce792b2a36 Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Sun, 29 Nov 2020 19:34:31 -0500 Subject: [PATCH 2/4] use vector instead of unordered map; 7% faster --- taichi/program/async_engine.cpp | 1 + taichi/program/ir_bank.cpp | 14 ++++++++-- taichi/program/ir_bank.h | 5 ++++ taichi/program/state_flow_graph.cpp | 43 ++++++++++++++++++++--------- taichi/program/state_flow_graph.h | 6 ++-- 5 files changed, 50 insertions(+), 19 deletions(-) diff --git a/taichi/program/async_engine.cpp b/taichi/program/async_engine.cpp index 5cf0baf0eaf7d..342f63655d56e 100644 --- a/taichi/program/async_engine.cpp +++ b/taichi/program/async_engine.cpp @@ -175,6 +175,7 @@ AsyncEngine::AsyncEngine(Program *program, : queue(&ir_bank_, compile_to_backend), program(program), sfg(std::make_unique(this, &ir_bank_)) { + ir_bank_.set_sfg(sfg.get()); } void AsyncEngine::launch(Kernel *kernel, Context &context) { diff --git a/taichi/program/ir_bank.cpp b/taichi/program/ir_bank.cpp index f65e520696197..d63ffcc806e85 100644 --- a/taichi/program/ir_bank.cpp +++ b/taichi/program/ir_bank.cpp @@ -5,6 +5,7 @@ #include "taichi/ir/transforms.h" #include "taichi/ir/analysis.h" #include "taichi/program/kernel.h" +#include "taichi/program/state_flow_graph.h" TLANG_NAMESPACE_BEGIN @@ -195,12 +196,19 @@ std::pair IRBank::optimize_dse( } AsyncState IRBank::get_async_state(SNode *snode, AsyncState::Type type) { - return AsyncState(snode, type, lookup_async_state_id(snode, type)); + auto id = lookup_async_state_id(snode, type); + sfg_->populate_latest_state_owner(id); + return AsyncState(snode, type, id); } AsyncState IRBank::get_async_state(Kernel *kernel) { - return AsyncState(kernel, - lookup_async_state_id(kernel, AsyncState::Type::value)); + auto id = lookup_async_state_id(kernel, AsyncState::Type::value); + sfg_->populate_latest_state_owner(id); + return AsyncState(kernel, id); +} + +void IRBank::set_sfg(StateFlowGraph *sfg) { + sfg_ = sfg; } TLANG_NAMESPACE_END diff --git a/taichi/program/ir_bank.h b/taichi/program/ir_bank.h index c96f4a86f51fd..e79072553ee42 100644 --- a/taichi/program/ir_bank.h +++ b/taichi/program/ir_bank.h @@ -6,6 +6,8 @@ TLANG_NAMESPACE_BEGIN +class StateFlowGraph; + class IRBank { public: uint64 get_hash(IRNode *ir); @@ -33,11 +35,14 @@ class IRBank { std::unordered_map meta_bank_; std::unordered_map fusion_meta_bank_; + void set_sfg(StateFlowGraph *sfg); + AsyncState get_async_state(SNode *snode, AsyncState::Type type); AsyncState get_async_state(Kernel *kernel); private: + StateFlowGraph *sfg_; std::unordered_map hash_bank_; std::unordered_map> ir_bank_; std::vector> trash_bin; // prevent IR from deleted diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index 036e3dda4a7f2..b41fb3caea2b1 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -139,7 +139,8 @@ void StateFlowGraph::clear() { // TODO: GC here? nodes_.resize(1); // Erase all nodes except the initial one initial_node_->output_edges.clear(); - latest_state_owner_.clear(); + std::fill(latest_state_owner_.begin(), latest_state_owner_.end(), + initial_node_); latest_state_readers_.clear(); first_pending_task_index_ = 1; @@ -150,7 +151,7 @@ void StateFlowGraph::mark_pending_tasks_as_executed() { std::vector> new_nodes; std::unordered_set state_owners; for (auto &owner : latest_state_owner_) { - state_owners.insert(state_owners.end(), owner.second); + state_owners.insert(state_owners.end(), owner); } for (auto &node : nodes_) { if (node->is_initial_node || state_owners.count(node.get()) > 0) { @@ -224,22 +225,20 @@ void StateFlowGraph::insert_tasks(const std::vector &records, void StateFlowGraph::insert_node(std::unique_ptr &&node) { for (auto input_state : node->meta->input_states) { - if (latest_state_owner_.find(input_state) == latest_state_owner_.end()) { - latest_state_owner_[input_state] = initial_node_; - } - insert_edge(latest_state_owner_[input_state], node.get(), input_state); + insert_edge(latest_state_owner_[input_state.unique_id], node.get(), + input_state); } for (auto output_state : node->meta->output_states) { if (get_or_insert(latest_state_readers_, output_state).empty()) { - if (latest_state_owner_.find(output_state) != latest_state_owner_.end()) { + if (latest_state_owner_[output_state.unique_id] != initial_node_) { // insert a WAW dependency edge - insert_edge(latest_state_owner_[output_state], node.get(), + insert_edge(latest_state_owner_[output_state.unique_id], node.get(), output_state); } else { insert(latest_state_readers_, output_state, initial_node_); } } - latest_state_owner_[output_state] = node.get(); + latest_state_owner_[output_state.unique_id] = node.get(); for (auto *d : get_or_insert(latest_state_readers_, output_state)) { // insert a WAR dependency edge insert_edge(d, node.get(), output_state); @@ -795,7 +794,7 @@ std::string StateFlowGraph::dump_dot(const std::optional &rankdir, std::unordered_set nodes_with_embedded_states; // TODO: make this configurable for (const auto &p : latest_state_owner_) { - latest_state_nodes.insert(p.second); + latest_state_nodes.insert(p); } bool highlight_single_state = false; @@ -1073,8 +1072,8 @@ void StateFlowGraph::delete_nodes( } for (auto &s : latest_state_owner_) { - if (nodes_to_delete.find(s.second) != nodes_to_delete.end()) { - s.second = initial_node_; + if (nodes_to_delete.find(s) != nodes_to_delete.end()) { + s = initial_node_; } } @@ -1104,7 +1103,7 @@ bool StateFlowGraph::optimize_dead_store() { // only focus on "value" states. continue; } - if (latest_state_owner_[s] == task) { + if (latest_state_owner_[s.unique_id] == task) { // Cannot eliminate the latest write, because it may form a state-flow // with the later kernel launches. // @@ -1379,6 +1378,24 @@ void StateFlowGraph::benchmark_rebuild_graph() { nodes_.size(), rebuild_t * 1e7, 1e7 * rebuild_t / nodes_.size()); } } +AsyncState StateFlowGraph::get_async_state(SNode *snode, + AsyncState::Type type) { + return ir_bank_->get_async_state(snode, type); +} + +AsyncState StateFlowGraph::get_async_state(Kernel *kernel) { + return ir_bank_->get_async_state(kernel); +} + +void StateFlowGraph::populate_latest_state_owner(std::size_t h) { + if (h >= latest_state_owner_.size()) { + std::size_t old_size = latest_state_owner_.size(); + latest_state_owner_.resize(h + 1); + for (int i = old_size; i < latest_state_owner_.size(); i++) { + latest_state_owner_[i] = initial_node_; + } + } +} void async_print_sfg() { get_current_program().async_engine->sfg->print(); diff --git a/taichi/program/state_flow_graph.h b/taichi/program/state_flow_graph.h index a406466d41aa7..f0a1a330980b5 100644 --- a/taichi/program/state_flow_graph.h +++ b/taichi/program/state_flow_graph.h @@ -181,18 +181,18 @@ class StateFlowGraph { void benchmark_rebuild_graph(); - std::size_t lookup_async_state_id(void *ptr, AsyncState::Type); - AsyncState get_async_state(SNode *snode, AsyncState::Type type); AsyncState get_async_state(Kernel *kernel); + void populate_latest_state_owner(std::size_t id); + private: std::vector> nodes_; Node *initial_node_; // The initial node holds all the initial states. int first_pending_task_index_; TaskMeta initial_meta_; - std::unordered_map latest_state_owner_; + std::vector latest_state_owner_; StateToNodesMap latest_state_readers_; std::unordered_map task_name_to_launch_ids_; IRBank *ir_bank_; From 2d9fab50f76c3906344bb739f4a4e6257fece6ba Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Sun, 29 Nov 2020 19:49:32 -0500 Subject: [PATCH 3/4] finalize --- taichi/program/state_flow_graph.cpp | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/taichi/program/state_flow_graph.cpp b/taichi/program/state_flow_graph.cpp index b41fb3caea2b1..9aee0319f7412 100644 --- a/taichi/program/state_flow_graph.cpp +++ b/taichi/program/state_flow_graph.cpp @@ -150,7 +150,7 @@ void StateFlowGraph::clear() { void StateFlowGraph::mark_pending_tasks_as_executed() { std::vector> new_nodes; std::unordered_set state_owners; - for (auto &owner : latest_state_owner_) { + for (auto owner : latest_state_owner_) { state_owners.insert(state_owners.end(), owner); } for (auto &node : nodes_) { @@ -307,15 +307,13 @@ bool StateFlowGraph::optimize_listgen() { // Test if two list generations share the same mask and parent list auto snode = node_a->meta->snode; - auto list_state = - ir_bank_->get_async_state(snode, AsyncState::Type::list); + auto list_state = get_async_state(snode, AsyncState::Type::list); auto parent_list_state = - ir_bank_->get_async_state(snode->parent, AsyncState::Type::list); + get_async_state(snode->parent, AsyncState::Type::list); if (snode->need_activation()) { // Needs mask state - auto mask_state = - ir_bank_->get_async_state(snode, AsyncState::Type::mask); + auto mask_state = get_async_state(snode, AsyncState::Type::mask); TI_ASSERT(get_or_insert(node_a->input_edges, mask_state).size() == 1); TI_ASSERT(get_or_insert(node_b->input_edges, mask_state).size() == 1); @@ -764,7 +762,7 @@ std::string StateFlowGraph::dump_dot(const std::optional &rankdir, // TODO: expose an API that allows users to highlight a single state AsyncState highlight_state = - ir_bank_->get_async_state(nullptr, AsyncState::Type::value); + get_async_state(nullptr, AsyncState::Type::value); ss << "digraph {\n"; auto node_id = [](const SFGNode *n) { @@ -793,7 +791,7 @@ std::string StateFlowGraph::dump_dot(const std::optional &rankdir, std::unordered_set latest_state_nodes; std::unordered_set nodes_with_embedded_states; // TODO: make this configurable - for (const auto &p : latest_state_owner_) { + for (auto p : latest_state_owner_) { latest_state_nodes.insert(p); } @@ -1311,7 +1309,7 @@ bool StateFlowGraph::demote_activation() { for (int i = 1; i < (int)nodes_.size(); i++) { Node *node = nodes_[i].get(); auto snode = node->meta->snode; - auto list_state = ir_bank_->get_async_state(snode, AsyncState::Type::list); + auto list_state = get_async_state(snode, AsyncState::Type::list); // Currently we handle struct for only // TODO: handle serial and range for @@ -1378,6 +1376,7 @@ void StateFlowGraph::benchmark_rebuild_graph() { nodes_.size(), rebuild_t * 1e7, 1e7 * rebuild_t / nodes_.size()); } } + AsyncState StateFlowGraph::get_async_state(SNode *snode, AsyncState::Type type) { return ir_bank_->get_async_state(snode, type); @@ -1387,10 +1386,10 @@ AsyncState StateFlowGraph::get_async_state(Kernel *kernel) { return ir_bank_->get_async_state(kernel); } -void StateFlowGraph::populate_latest_state_owner(std::size_t h) { - if (h >= latest_state_owner_.size()) { +void StateFlowGraph::populate_latest_state_owner(std::size_t id) { + if (id >= latest_state_owner_.size()) { std::size_t old_size = latest_state_owner_.size(); - latest_state_owner_.resize(h + 1); + latest_state_owner_.resize(id + 1); for (int i = old_size; i < latest_state_owner_.size(); i++) { latest_state_owner_[i] = initial_node_; } From 8aa808cd1cf87f78c3697db305ecd3abab97da3b Mon Sep 17 00:00:00 2001 From: Yuanming Hu Date: Sun, 29 Nov 2020 19:55:19 -0500 Subject: [PATCH 4/4] finalize --- taichi/program/ir_bank.cpp | 9 +++++++++ taichi/program/ir_bank.h | 9 +-------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/taichi/program/ir_bank.cpp b/taichi/program/ir_bank.cpp index d63ffcc806e85..2e1e97d811e7f 100644 --- a/taichi/program/ir_bank.cpp +++ b/taichi/program/ir_bank.cpp @@ -211,4 +211,13 @@ void IRBank::set_sfg(StateFlowGraph *sfg) { sfg_ = sfg; } +std::size_t IRBank::lookup_async_state_id(void *ptr, AsyncState::Type type) { + auto h = AsyncState::perfect_hash(ptr, type); + if (async_state_to_unique_id_.find(h) == async_state_to_unique_id_.end()) { + async_state_to_unique_id_.insert( + std::make_pair(h, async_state_to_unique_id_.size())); + } + return async_state_to_unique_id_[h]; +} + TLANG_NAMESPACE_END diff --git a/taichi/program/ir_bank.h b/taichi/program/ir_bank.h index e79072553ee42..2b6c6f2bd2799 100644 --- a/taichi/program/ir_bank.h +++ b/taichi/program/ir_bank.h @@ -86,14 +86,7 @@ class IRBank { optimize_dse_bank_; std::unordered_map async_state_to_unique_id_; - std::size_t lookup_async_state_id(void *ptr, AsyncState::Type type) { - auto h = AsyncState::perfect_hash(ptr, type); - if (async_state_to_unique_id_.find(h) == async_state_to_unique_id_.end()) { - async_state_to_unique_id_.insert( - std::make_pair(h, async_state_to_unique_id_.size())); - } - return async_state_to_unique_id_[h]; - } + std::size_t lookup_async_state_id(void *ptr, AsyncState::Type type); }; TLANG_NAMESPACE_END