Skip to content

Commit

Permalink
[ir] [bug] Make control-flow graph take function call into account (#…
Browse files Browse the repository at this point in the history
…2448)

* [ir] [bug] Make control-flow graph take function call into account

* format

* Add todo

* Use Python assertions

* Remove unused import
  • Loading branch information
xumingkuan authored Jun 22, 2021
1 parent 2daa097 commit 32732ac
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 22 deletions.
101 changes: 87 additions & 14 deletions taichi/analysis/build_cfg.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,36 @@
#include "taichi/ir/control_flow_graph.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/program/function.h"

TLANG_NAMESPACE_BEGIN
namespace taichi {
namespace lang {

struct CFGFuncKey {
FunctionKey func_key{"", -1, -1};
bool in_parallel_for{false};

bool operator==(const CFGFuncKey &other_key) const {
return func_key == other_key.func_key &&
in_parallel_for == other_key.in_parallel_for;
}
};

} // namespace lang
} // namespace taichi

namespace std {
template <>
struct hash<taichi::lang::CFGFuncKey> {
std::size_t operator()(const taichi::lang::CFGFuncKey &key) const noexcept {
return std::hash<taichi::lang::FunctionKey>()(key.func_key) ^
((std::size_t)key.in_parallel_for << 32);
}
};
} // namespace std

namespace taichi {
namespace lang {

/**
* Build a control-flow graph. The resulting graph is guaranteed to have an
Expand All @@ -28,20 +56,10 @@ TLANG_NAMESPACE_BEGIN
*
* When there can be many CFGNodes in a Block, internal nodes are omitted for
* simplicity.
*
* TODO(#2193): Make sure ReturnStmt is handled properly.
*/
class CFGBuilder : public IRVisitor {
private:
std::unique_ptr<ControlFlowGraph> graph;
Block *current_block;
CFGNode *last_node_in_current_block;
std::vector<CFGNode *> continues_in_current_loop;
std::vector<CFGNode *> breaks_in_current_loop;
int current_stmt_id;
int begin_location;
std::vector<CFGNode *> prev_nodes;
OffloadedStmt *current_offload;
bool in_parallel_for;

public:
CFGBuilder()
: current_block(nullptr),
Expand Down Expand Up @@ -132,6 +150,46 @@ class CFGBuilder : public IRVisitor {
prev_nodes.push_back(node);
}

/**
* Structure:
*
* block {
* node {
* ...
* } -> node_func_begin;
* foo();
* (next node) {
* ...
* }
* }
*
* foo() {
* node_func_begin {
* ...
* } -> ... -> node_func_end;
* node_func_end {
* ...
* } -> (next node);
* }
*/
void visit(FuncCallStmt *stmt) override {
auto node_before_func_call = new_node(-1);
CFGFuncKey func_key = {stmt->func->func_key, in_parallel_for};
if (node_func_begin.count(func_key) == 0) {
// Generate CFG for the function.
TI_ASSERT(stmt->func->ir->is<Block>());
auto func_begin_index = graph->size();
stmt->func->ir->accept(this);
node_func_begin[func_key] = graph->nodes[func_begin_index].get();
node_func_end[func_key] = graph->nodes.back().get();
}
CFGNode::add_edge(node_before_func_call, node_func_begin[func_key]);
prev_nodes.push_back(node_func_end[func_key]);

// Don't put FuncCallStmt in any CFGNodes.
begin_location = current_stmt_id + 1;
}

/**
* Structure:
*
Expand Down Expand Up @@ -393,6 +451,20 @@ class CFGBuilder : public IRVisitor {
}
return std::move(builder.graph);
}

private:
std::unique_ptr<ControlFlowGraph> graph;
Block *current_block;
CFGNode *last_node_in_current_block;
std::vector<CFGNode *> continues_in_current_loop;
std::vector<CFGNode *> breaks_in_current_loop;
int current_stmt_id;
int begin_location;
std::vector<CFGNode *> prev_nodes;
OffloadedStmt *current_offload;
bool in_parallel_for;
std::unordered_map<CFGFuncKey, CFGNode *> node_func_begin;
std::unordered_map<CFGFuncKey, CFGNode *> node_func_end;
};

namespace irpass::analysis {
Expand All @@ -401,4 +473,5 @@ std::unique_ptr<ControlFlowGraph> build_cfg(IRNode *root) {
}
} // namespace irpass::analysis

TLANG_NAMESPACE_END
} // namespace lang
} // namespace taichi
21 changes: 13 additions & 8 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pytest

import taichi as ti


Expand Down Expand Up @@ -150,11 +148,11 @@ def run(self) -> ti.i32:
assert x[None] == 0


@pytest.mark.skip(reason='https://github.com/taichi-dev/taichi/issues/2442')
@ti.test(experimental_real_function=True, debug=True)
@ti.test(experimental_real_function=True)
def test_templates():
x = ti.field(ti.i32, shape=())
y = ti.field(ti.i32, shape=())
answer = ti.field(ti.i32, shape=8)

@ti.kernel
def kernel_inc(x: ti.template()):
Expand All @@ -179,11 +177,18 @@ def run_func():
x[None] = 10
y[None] = 20
inc(x)
assert x[None] == 11
assert y[None] == 20
answer[0] = x[None]
answer[1] = y[None]
inc(y)
assert x[None] == 11
assert y[None] == 21
answer[2] = x[None]
answer[3] = y[None]

def verify():
assert answer[0] == 11
assert answer[1] == 20
assert answer[2] == 11
assert answer[3] == 21

run_kernel()
run_func()
verify()

0 comments on commit 32732ac

Please sign in to comment.