Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ir] [bug] Make control-flow graph take function call into account #2448

Merged
merged 5 commits into from
Jun 22, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 87 additions & 14 deletions taichi/analysis/build_cfg.cpp
Original file line number Diff line number Diff line change
@@ -1,8 +1,36 @@
#include "taichi/ir/control_flow_graph.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/statements.h"
#include "taichi/program/function.h"

TLANG_NAMESPACE_BEGIN
namespace taichi {
namespace lang {

struct CFGFuncKey {
FunctionKey func_key{"", -1, -1};
bool in_parallel_for{false};

bool operator==(const CFGFuncKey &other_key) const {
return func_key == other_key.func_key &&
in_parallel_for == other_key.in_parallel_for;
}
};

} // namespace lang
} // namespace taichi

namespace std {
template <>
struct hash<taichi::lang::CFGFuncKey> {
std::size_t operator()(const taichi::lang::CFGFuncKey &key) const noexcept {
return std::hash<taichi::lang::FunctionKey>()(key.func_key) ^
((std::size_t)key.in_parallel_for << 32);
}
};
} // namespace std

namespace taichi {
namespace lang {

/**
* Build a control-flow graph. The resulting graph is guaranteed to have an
Expand All @@ -28,20 +56,10 @@ TLANG_NAMESPACE_BEGIN
*
* When there can be many CFGNodes in a Block, internal nodes are omitted for
* simplicity.
*
* TODO(#2193): Make sure ReturnStmt is handled properly.
*/
class CFGBuilder : public IRVisitor {
private:
std::unique_ptr<ControlFlowGraph> graph;
Block *current_block;
CFGNode *last_node_in_current_block;
std::vector<CFGNode *> continues_in_current_loop;
std::vector<CFGNode *> breaks_in_current_loop;
int current_stmt_id;
int begin_location;
std::vector<CFGNode *> prev_nodes;
OffloadedStmt *current_offload;
bool in_parallel_for;

public:
CFGBuilder()
: current_block(nullptr),
Expand Down Expand Up @@ -132,6 +150,46 @@ class CFGBuilder : public IRVisitor {
prev_nodes.push_back(node);
}

/**
* Structure:
*
* block {
* node {
* ...
* } -> node_func_begin;
* foo();
* (next node) {
* ...
* }
* }
*
* foo() {
* node_func_begin {
* ...
* } -> ... -> node_func_end;
* node_func_end {
* ...
* } -> (next node);
* }
*/
void visit(FuncCallStmt *stmt) override {
auto node_before_func_call = new_node(-1);
CFGFuncKey func_key = {stmt->func->func_key, in_parallel_for};
if (node_func_begin.count(func_key) == 0) {
// Generate CFG for the function.
TI_ASSERT(stmt->func->ir->is<Block>());
auto func_begin_index = graph->size();
stmt->func->ir->accept(this);
node_func_begin[func_key] = graph->nodes[func_begin_index].get();
node_func_end[func_key] = graph->nodes.back().get();
}
CFGNode::add_edge(node_before_func_call, node_func_begin[func_key]);
prev_nodes.push_back(node_func_end[func_key]);

// Don't put FuncCallStmt in any CFGNodes.
begin_location = current_stmt_id + 1;
}

/**
* Structure:
*
Expand Down Expand Up @@ -393,6 +451,20 @@ class CFGBuilder : public IRVisitor {
}
return std::move(builder.graph);
}

private:
std::unique_ptr<ControlFlowGraph> graph;
Block *current_block;
CFGNode *last_node_in_current_block;
std::vector<CFGNode *> continues_in_current_loop;
std::vector<CFGNode *> breaks_in_current_loop;
int current_stmt_id;
int begin_location;
std::vector<CFGNode *> prev_nodes;
OffloadedStmt *current_offload;
bool in_parallel_for;
std::unordered_map<CFGFuncKey, CFGNode *> node_func_begin;
std::unordered_map<CFGFuncKey, CFGNode *> node_func_end;
};

namespace irpass::analysis {
Expand All @@ -401,4 +473,5 @@ std::unique_ptr<ControlFlowGraph> build_cfg(IRNode *root) {
}
} // namespace irpass::analysis

TLANG_NAMESPACE_END
} // namespace lang
} // namespace taichi
21 changes: 13 additions & 8 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pytest

import taichi as ti


Expand Down Expand Up @@ -150,11 +148,11 @@ def run(self) -> ti.i32:
assert x[None] == 0


@pytest.mark.skip(reason='https://github.com/taichi-dev/taichi/issues/2442')
@ti.test(experimental_real_function=True, debug=True)
@ti.test(experimental_real_function=True)
def test_templates():
xumingkuan marked this conversation as resolved.
Show resolved Hide resolved
x = ti.field(ti.i32, shape=())
y = ti.field(ti.i32, shape=())
answer = ti.field(ti.i32, shape=8)

@ti.kernel
def kernel_inc(x: ti.template()):
Expand All @@ -179,11 +177,18 @@ def run_func():
x[None] = 10
y[None] = 20
inc(x)
assert x[None] == 11
assert y[None] == 20
answer[0] = x[None]
answer[1] = y[None]
inc(y)
assert x[None] == 11
assert y[None] == 21
answer[2] = x[None]
answer[3] = y[None]

def verify():
assert answer[0] == 11
assert answer[1] == 20
assert answer[2] == 11
assert answer[3] == 21

run_kernel()
run_func()
verify()