Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[async] Compact AsyncState id for faster SFG rebuilding #2071

Merged
merged 4 commits into from
Nov 30, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions taichi/analysis/cfg_analysis.cpp
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
#include "taichi/ir/analysis.h"
#include "taichi/ir/control_flow_graph.h"
#include "taichi/program/async_utils.h"
#include "taichi/program/ir_bank.h"

TLANG_NAMESPACE_BEGIN

namespace irpass::analysis {
void get_meta_input_value_states(IRNode *root, TaskMeta *meta) {
void get_meta_input_value_states(IRNode *root,
TaskMeta *meta,
IRBank *ir_bank) {
auto cfg = analysis::build_cfg(root);
auto snodes = cfg->gather_loaded_snodes();
for (auto &snode : snodes) {
meta->input_states.emplace(snode, AsyncState::Type::value);
meta->input_states.insert(
ir_bank->get_async_state(snode, AsyncState::Type::value));
}
}
} // namespace irpass::analysis
Expand Down
3 changes: 2 additions & 1 deletion taichi/ir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ enum AliasResult { same, uncertain, different };
class ControlFlowGraph;

struct TaskMeta;
class IRBank;

// IR Analysis
namespace irpass::analysis {
Expand All @@ -75,7 +76,7 @@ std::unordered_map<SNode *, GlobalPtrStmt *> gather_uniquely_accessed_pointers(
std::unique_ptr<std::unordered_set<AtomicOpStmt *>> gather_used_atomics(
IRNode *root);
std::vector<Stmt *> get_load_pointers(Stmt *load_stmt);
void get_meta_input_value_states(IRNode *root, TaskMeta *meta);
void get_meta_input_value_states(IRNode *root, TaskMeta *meta, IRBank *ir_bank);
Stmt *get_store_data(Stmt *store_stmt);
std::vector<Stmt *> get_store_destination(Stmt *store_stmt);
bool has_store_or_atomic(IRNode *root, const std::vector<Stmt *> &vars);
Expand Down
1 change: 1 addition & 0 deletions taichi/program/async_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ AsyncEngine::AsyncEngine(Program *program,
: queue(&ir_bank_, compile_to_backend),
program(program),
sfg(std::make_unique<StateFlowGraph>(this, &ir_bank_)) {
ir_bank_.set_sfg(sfg.get());
}

void AsyncEngine::launch(Kernel *kernel, Context &context) {
Expand Down
66 changes: 43 additions & 23 deletions taichi/program/async_utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ std::string AsyncState::name() const {
return prefix + "_" + type_name;
}

std::size_t AsyncState::get_unique_id(void *ptr, AsyncState::Type type) {
std::size_t AsyncState::perfect_hash(void *ptr, AsyncState::Type type) {
static_assert((int)Type::undefined < 8);
static_assert(std::alignment_of<SNode>() % 8 == 0);
static_assert(std::alignment_of<Kernel>() % 8 == 0);
Expand Down Expand Up @@ -145,7 +145,7 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
meta.name =
t.kernel->name + "_" + offloaded_task_type_name(root_stmt->task_type);
meta.type = root_stmt->task_type;
get_meta_input_value_states(root_stmt, &meta);
get_meta_input_value_states(root_stmt, &meta, ir_bank);
meta.loop_unique = gather_uniquely_accessed_pointers(root_stmt);

std::unordered_set<SNode *> activates, deactivates;
Expand All @@ -157,7 +157,8 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
if (auto global_store = stmt->cast<GlobalStoreStmt>()) {
if (auto ptr = global_store->ptr->cast<GlobalPtrStmt>()) {
for (auto &snode : ptr->snodes.data) {
meta.output_states.emplace(snode, AsyncState::Type::value);
meta.output_states.insert(
yuanming-hu marked this conversation as resolved.
Show resolved Hide resolved
ir_bank->get_async_state(snode, AsyncState::Type::value));
}
}
}
Expand All @@ -166,7 +167,8 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
for (auto &snode : ptr->snodes.data) {
// input_state is already handled in
// get_meta_input_value_states().
meta.output_states.emplace(snode, AsyncState::Type::value);
meta.output_states.insert(
ir_bank->get_async_state(snode, AsyncState::Type::value));
}
}
}
Expand All @@ -182,12 +184,15 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
activates.insert(sn);
for (auto &child : sn->ch) {
TI_ASSERT(child->type == SNodeType::place);
meta.input_states.emplace(child.get(), AsyncState::Type::value);
meta.output_states.emplace(child.get(), AsyncState::Type::value);
meta.input_states.insert(
ir_bank->get_async_state(child.get(), AsyncState::Type::value));
meta.output_states.insert(
ir_bank->get_async_state(child.get(), AsyncState::Type::value));
}
} else if (snode_op->op_type == SNodeOpType::is_active ||
snode_op->op_type == SNodeOpType::length) {
meta.input_states.emplace(sn, AsyncState::Type::mask);
meta.input_states.insert(
ir_bank->get_async_state(sn, AsyncState::Type::mask));
} else {
TI_NOT_IMPLEMENTED
}
Expand All @@ -210,12 +215,13 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
}
}
if (stmt->is<GlobalTemporaryStmt>()) {
auto as = AsyncState(t.kernel);
auto as = ir_bank->get_async_state(t.kernel);
meta.input_states.insert(as);
meta.output_states.insert(as);
}
if (auto clear_list = stmt->cast<ClearListStmt>()) {
meta.output_states.emplace(clear_list->snode, AsyncState::Type::list);
meta.output_states.insert(
ir_bank->get_async_state(clear_list->snode, AsyncState::Type::list));
}
return false;
});
Expand All @@ -237,11 +243,15 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {

// Do not record dense SNodes' mask states.
if (s->need_activation()) {
meta.input_states.emplace(s, AsyncState::Type::mask);
meta.output_states.emplace(s, AsyncState::Type::mask);
meta.input_states.insert(
ir_bank->get_async_state(s, AsyncState::Type::mask));
meta.output_states.insert(
ir_bank->get_async_state(s, AsyncState::Type::mask));
if (is_gc_able(s->type)) {
meta.input_states.emplace(s, AsyncState::Type::allocator);
meta.output_states.emplace(s, AsyncState::Type::allocator);
meta.input_states.insert(
ir_bank->get_async_state(s, AsyncState::Type::allocator));
meta.output_states.insert(
ir_bank->get_async_state(s, AsyncState::Type::allocator));
}
}
s = s->parent;
Expand All @@ -268,7 +278,8 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
continue;
}
if (s->type == SNodeType::place) {
meta.output_states.emplace(s, AsyncState::Type::value);
meta.output_states.insert(
ir_bank->get_async_state(s, AsyncState::Type::value));
} else {
for (auto &child : s->ch) {
if (deactivates.count(child.get()) == 0) {
Expand All @@ -288,22 +299,31 @@ TaskMeta *get_task_meta(IRBank *ir_bank, const TaskLaunchRecord &t) {
if (root_stmt->task_type == OffloadedTaskType::listgen) {
TI_ASSERT(root_stmt->snode->parent);
meta.snode = root_stmt->snode;
meta.input_states.emplace(root_stmt->snode->parent, AsyncState::Type::list);
meta.input_states.emplace(root_stmt->snode, AsyncState::Type::list);
meta.input_states.insert(ir_bank->get_async_state(root_stmt->snode->parent,
AsyncState::Type::list));
meta.input_states.insert(
ir_bank->get_async_state(root_stmt->snode, AsyncState::Type::list));
if (root_stmt->snode->need_activation()) {
meta.input_states.emplace(root_stmt->snode, AsyncState::Type::mask);
meta.input_states.insert(
ir_bank->get_async_state(root_stmt->snode, AsyncState::Type::mask));
}
meta.output_states.emplace(root_stmt->snode, AsyncState::Type::list);
meta.output_states.insert(
ir_bank->get_async_state(root_stmt->snode, AsyncState::Type::list));
} else if (root_stmt->task_type == OffloadedTaskType::struct_for) {
meta.snode = root_stmt->snode;
meta.input_states.emplace(root_stmt->snode, AsyncState::Type::list);
meta.input_states.insert(
ir_bank->get_async_state(root_stmt->snode, AsyncState::Type::list));
} else if ((root_stmt->task_type == OffloadedTaskType::gc) &&
(is_gc_able(root_stmt->snode->type))) {
meta.snode = root_stmt->snode;
meta.input_states.emplace(root_stmt->snode, AsyncState::Type::mask);
meta.input_states.emplace(root_stmt->snode, AsyncState::Type::allocator);
meta.output_states.emplace(root_stmt->snode, AsyncState::Type::mask);
meta.output_states.emplace(root_stmt->snode, AsyncState::Type::allocator);
meta.input_states.insert(
ir_bank->get_async_state(root_stmt->snode, AsyncState::Type::mask));
meta.input_states.insert(ir_bank->get_async_state(
root_stmt->snode, AsyncState::Type::allocator));
meta.output_states.insert(
ir_bank->get_async_state(root_stmt->snode, AsyncState::Type::mask));
meta.output_states.insert(ir_bank->get_async_state(
root_stmt->snode, AsyncState::Type::allocator));
insert_value_states_top_down(root_stmt->snode);
}

Expand Down
14 changes: 5 additions & 9 deletions taichi/program/async_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -89,17 +89,13 @@ struct AsyncState {
AsyncState() = default;

// For SNode
AsyncState(SNode *snode, Type type)
: snode_or_global_tmp(snode),
type(type),
unique_id(get_unique_id(snode, type)) {
AsyncState(SNode *snode, Type type, std::size_t unique_id)
: snode_or_global_tmp(snode), type(type), unique_id(unique_id) {
}

// For global temporaries
AsyncState(Kernel *kernel)
: snode_or_global_tmp(kernel),
type(Type::value),
unique_id(get_unique_id(kernel, type)) {
AsyncState(Kernel *kernel, std::size_t unique_id)
: snode_or_global_tmp(kernel), type(Type::value), unique_id(unique_id) {
}

bool operator<(const AsyncState &other) const {
Expand All @@ -124,7 +120,7 @@ struct AsyncState {
return std::get<SNode *>(snode_or_global_tmp);
}

static std::size_t get_unique_id(void *ptr, Type type);
static std::size_t perfect_hash(void *ptr, Type type);
};

struct TaskFusionMeta {
Expand Down
26 changes: 26 additions & 0 deletions taichi/program/ir_bank.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "taichi/ir/transforms.h"
#include "taichi/ir/analysis.h"
#include "taichi/program/kernel.h"
#include "taichi/program/state_flow_graph.h"

TLANG_NAMESPACE_BEGIN

Expand Down Expand Up @@ -194,4 +195,29 @@ std::pair<IRHandle, bool> IRBank::optimize_dse(
return std::make_pair(ret_handle, false);
}

AsyncState IRBank::get_async_state(SNode *snode, AsyncState::Type type) {
auto id = lookup_async_state_id(snode, type);
sfg_->populate_latest_state_owner(id);
return AsyncState(snode, type, id);
}

AsyncState IRBank::get_async_state(Kernel *kernel) {
auto id = lookup_async_state_id(kernel, AsyncState::Type::value);
sfg_->populate_latest_state_owner(id);
return AsyncState(kernel, id);
}

void IRBank::set_sfg(StateFlowGraph *sfg) {
sfg_ = sfg;
}

std::size_t IRBank::lookup_async_state_id(void *ptr, AsyncState::Type type) {
auto h = AsyncState::perfect_hash(ptr, type);
if (async_state_to_unique_id_.find(h) == async_state_to_unique_id_.end()) {
async_state_to_unique_id_.insert(
std::make_pair(h, async_state_to_unique_id_.size()));
}
return async_state_to_unique_id_[h];
}

TLANG_NAMESPACE_END
13 changes: 13 additions & 0 deletions taichi/program/ir_bank.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

TLANG_NAMESPACE_BEGIN

class StateFlowGraph;

class IRBank {
public:
uint64 get_hash(IRNode *ir);
Expand Down Expand Up @@ -33,7 +35,14 @@ class IRBank {
std::unordered_map<IRHandle, TaskMeta> meta_bank_;
std::unordered_map<IRHandle, TaskFusionMeta> fusion_meta_bank_;

void set_sfg(StateFlowGraph *sfg);

AsyncState get_async_state(SNode *snode, AsyncState::Type type);

AsyncState get_async_state(Kernel *kernel);

private:
StateFlowGraph *sfg_;
std::unordered_map<IRNode *, uint64> hash_bank_;
std::unordered_map<IRHandle, std::unique_ptr<IRNode>> ir_bank_;
std::vector<std::unique_ptr<IRNode>> trash_bin; // prevent IR from deleted
Expand Down Expand Up @@ -72,8 +81,12 @@ class IRBank {
}
};
};

std::unordered_map<OptimizeDseKey, IRHandle, OptimizeDseKey::Hash>
optimize_dse_bank_;
std::unordered_map<std::size_t, std::size_t> async_state_to_unique_id_;

std::size_t lookup_async_state_id(void *ptr, AsyncState::Type type);
};

TLANG_NAMESPACE_END
Loading