Skip to content

Commit

Permalink
[Opt] [bug] Better aliasing analysis for dead store elimination (#1432)
Browse files Browse the repository at this point in the history
* [Opt] [bug] Better aliasing analysis for dead store elimination

* [skip ci] enforce code format

Co-authored-by: Taichi Gardener <[email protected]>
  • Loading branch information
xumingkuan and taichi-gardener authored Jul 7, 2020
1 parent c75e5ed commit 8f5c84d
Show file tree
Hide file tree
Showing 6 changed files with 109 additions and 15 deletions.
6 changes: 6 additions & 0 deletions taichi/analysis/data_source_analysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ std::vector<Stmt *> get_load_pointers(Stmt *load_stmt) {
} else if (auto stack_acc_adj = load_stmt->cast<StackAccAdjointStmt>()) {
// This statement loads and stores the adjoint data.
return std::vector<Stmt *>(1, stack_acc_adj->stack);
} else if (auto stack_push = load_stmt->cast<StackPushStmt>()) {
// This is to make dead store elimination not eliminate consequent pushes.
return std::vector<Stmt *>(1, stack_push->stack);
} else if (auto stack_pop = load_stmt->cast<StackPopStmt>()) {
// This is to make dead store elimination not eliminate consequent pops.
return std::vector<Stmt *>(1, stack_pop->stack);
} else {
return std::vector<Stmt *>();
}
Expand Down
41 changes: 32 additions & 9 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,8 +82,27 @@ bool CFGNode::contain_variable(const std::unordered_set<Stmt *> &var_set,
return var_set.find(var) != var_set.end();
} else {
// TODO: How to optimize this?
if (var_set.find(var) != var_set.end())
return true;
for (auto set_var : var_set) {
if (irpass::analysis::same_statements(var, set_var)) {
if (definitely_same_address(var, set_var)) {
return true;
}
}
return false;
}
}

bool CFGNode::may_contain_variable(const std::unordered_set<Stmt *> &var_set,
Stmt *var) {
if (var->is<AllocaStmt>() || var->is<StackAllocaStmt>()) {
return var_set.find(var) != var_set.end();
} else {
// TODO: How to optimize this?
if (var_set.find(var) != var_set.end())
return true;
for (auto set_var : var_set) {
if (maybe_same_address(var, set_var)) {
return true;
}
}
Expand Down Expand Up @@ -290,16 +309,18 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
store_ptr = stack_push->stack;
} else if (auto stack_acc_adj = stmt->cast<StackAccAdjointStmt>()) {
store_ptr = stack_acc_adj->stack;
} else if (stmt->is<StackAllocaStmt>()) {
store_ptr = stmt;
}
if (store_ptr) {
if (!after_lower_access ||
(store_ptr->is<AllocaStmt>() || store_ptr->is<StackAllocaStmt>())) {
// After lower_access, we only analyze local variables and stacks.
// Do not eliminate AllocaStmt here.
if (!stmt->is<AllocaStmt>() &&
// Do not eliminate AllocaStmt and StackAllocaStmt here.
if (!stmt->is<AllocaStmt>() && !stmt->is<StackAllocaStmt>() &&
!may_contain_variable(live_in_this_node, store_ptr) &&
(contain_variable(killed_in_this_node, store_ptr) ||
(!contain_variable(live_out, store_ptr) &&
!contain_variable(live_in_this_node, store_ptr)))) {
!may_contain_variable(live_out, store_ptr))) {
// Neither used in other nodes nor used in this node.
if (auto atomic = stmt->cast<AtomicOpStmt>()) {
// Weaken the atomic operation to a load.
Expand All @@ -309,7 +330,6 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
local_load->ret_type = atomic->ret_type;
replace_with(i, std::move(local_load), true);
// Notice that we have a load here.
killed_in_this_node.erase(atomic->dest);
live_in_this_node.insert(atomic->dest);
modified = true;
continue;
Expand All @@ -322,7 +342,6 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
global_load->ret_type = atomic->ret_type;
replace_with(i, std::move(global_load), true);
// Notice that we have a load here.
killed_in_this_node.erase(atomic->dest);
live_in_this_node.insert(atomic->dest);
modified = true;
continue;
Expand All @@ -335,7 +354,12 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
} else {
// A non-eliminated store.
killed_in_this_node.insert(store_ptr);
live_in_this_node.erase(store_ptr);
auto old_live_in_this_node = std::move(live_in_this_node);
live_in_this_node.clear();
for (auto &var : old_live_in_this_node) {
if (!definitely_same_address(store_ptr, var))
live_in_this_node.insert(var);
}
}
}
}
Expand All @@ -344,7 +368,6 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
if (!after_lower_access ||
(load_ptr->is<AllocaStmt>() || load_ptr->is<StackAllocaStmt>())) {
// After lower_access, we only analyze local variables and stacks.
killed_in_this_node.erase(load_ptr);
live_in_this_node.insert(load_ptr);
}
}
Expand Down
2 changes: 2 additions & 0 deletions taichi/ir/control_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class CFGNode {

static bool contain_variable(const std::unordered_set<Stmt *> &var_set,
Stmt *var);
static bool may_contain_variable(const std::unordered_set<Stmt *> &var_set,
Stmt *var);
void reaching_definition_analysis(bool after_lower_access);
bool reach_kill_variable(Stmt *var) const;
Stmt *get_store_forwarding_data(Stmt *var, int position) const;
Expand Down
62 changes: 62 additions & 0 deletions taichi/ir/ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,51 @@ CompileConfig &IRNode::get_config() const {
return get_kernel()->program.config;
}

bool definitely_same_address(Stmt *var1, Stmt *var2) {
// Return true when two statements must be the same address;
// false when two statements can be different addresses.

// If both stmts are allocas, they have the same address iff var1 == var2.
// If only one of them is an alloca, they can never share the same address.
if (var1 == var2)
return true;
if (!var1 || !var2)
return false;
if (var1->is<AllocaStmt>() || var2->is<AllocaStmt>())
return false;
if (var1->is<StackAllocaStmt>() || var2->is<StackAllocaStmt>())
return false;

// TODO(xumingkuan): Put GlobalTemporaryStmt, ThreadLocalPtrStmt and
// BlockLocalPtrStmt into GlobalPtrStmt.
// If both statements are global temps, they have the same address iff they
// have the same offset. If only one of them is a global temp, they can never
// share the same address.
if (var1->is<GlobalTemporaryStmt>() || var2->is<GlobalTemporaryStmt>()) {
if (!var1->is<GlobalTemporaryStmt>() || !var2->is<GlobalTemporaryStmt>())
return false;
return var1->as<GlobalTemporaryStmt>()->offset ==
var2->as<GlobalTemporaryStmt>()->offset;
}

if (var1->is<ThreadLocalPtrStmt>() || var2->is<ThreadLocalPtrStmt>()) {
if (!var1->is<ThreadLocalPtrStmt>() || !var2->is<ThreadLocalPtrStmt>())
return false;
return var1->as<ThreadLocalPtrStmt>()->offset ==
var2->as<ThreadLocalPtrStmt>()->offset;
}

if (var1->is<BlockLocalPtrStmt>() || var2->is<BlockLocalPtrStmt>()) {
if (!var1->is<BlockLocalPtrStmt>() || !var2->is<BlockLocalPtrStmt>())
return false;
return irpass::analysis::same_statements(
var1->as<BlockLocalPtrStmt>()->offset,
var2->as<BlockLocalPtrStmt>()->offset);
}

return irpass::analysis::same_statements(var1, var2);
}

bool maybe_same_address(Stmt *var1, Stmt *var2) {
// Return true when two statements might be the same address;
// false when two statements cannot be the same address.
Expand All @@ -36,6 +81,8 @@ bool maybe_same_address(Stmt *var1, Stmt *var2) {
return false;
if (var1->is<AllocaStmt>() || var2->is<AllocaStmt>())
return false;
if (var1->is<StackAllocaStmt>() || var2->is<StackAllocaStmt>())
return false;

// If both statements are global temps, they have the same address iff they
// have the same offset. If only one of them is a global temp, they can never
Expand All @@ -47,6 +94,21 @@ bool maybe_same_address(Stmt *var1, Stmt *var2) {
var2->as<GlobalTemporaryStmt>()->offset;
}

if (var1->is<ThreadLocalPtrStmt>() || var2->is<ThreadLocalPtrStmt>()) {
if (!var1->is<ThreadLocalPtrStmt>() || !var2->is<ThreadLocalPtrStmt>())
return false;
return var1->as<ThreadLocalPtrStmt>()->offset ==
var2->as<ThreadLocalPtrStmt>()->offset;
}

if (var1->is<BlockLocalPtrStmt>() || var2->is<BlockLocalPtrStmt>()) {
if (!var1->is<BlockLocalPtrStmt>() || !var2->is<BlockLocalPtrStmt>())
return false;
return irpass::analysis::same_statements(
var1->as<BlockLocalPtrStmt>()->offset,
var2->as<BlockLocalPtrStmt>()->offset);
}

// If both statements are GlobalPtrStmts or GetChStmts, we can check by
// SNode::id.
TI_ASSERT(var1->width() == 1);
Expand Down
1 change: 1 addition & 0 deletions taichi/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ using ScratchPadOptions = std::vector<std::pair<int, SNode *>>;

IRBuilder &current_ast_builder();

bool definitely_same_address(Stmt *var1, Stmt *var2);
bool maybe_same_address(Stmt *var1, Stmt *var2);

struct VectorType {
Expand Down
12 changes: 6 additions & 6 deletions tests/python/test_ad_for.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_ad_sum():
p = ti.var(ti.f32, shape=N, needs_grad=True)

@ti.kernel
def comptue_sum():
def compute_sum():
for i in range(N):
ret = 1.0
for j in range(b[i]):
Expand All @@ -21,13 +21,13 @@ def comptue_sum():
a[i] = 3
b[i] = i

comptue_sum()
compute_sum()

for i in range(N):
assert p[i] == 3 * b[i] + 1
p.grad[i] = 1

comptue_sum.grad()
compute_sum.grad()

for i in range(N):
assert a.grad[i] == b[i]
Expand All @@ -43,7 +43,7 @@ def test_ad_sum_local_atomic():
p = ti.var(ti.f32, shape=N, needs_grad=True)

@ti.kernel
def comptue_sum():
def compute_sum():
for i in range(N):
ret = 1.0
for j in range(b[i]):
Expand All @@ -54,13 +54,13 @@ def comptue_sum():
a[i] = 3
b[i] = i

comptue_sum()
compute_sum()

for i in range(N):
assert p[i] == 3 * b[i] + 1
p.grad[i] = 1

comptue_sum.grad()
compute_sum.grad()

for i in range(N):
assert a.grad[i] == b[i]
Expand Down

0 comments on commit 8f5c84d

Please sign in to comment.