Skip to content

Commit

Permalink
[opt] Treat FuncCallStmt better in store-to-load forwarding in CFG (#…
Browse files Browse the repository at this point in the history
…8155)

Issue: #602
Pass `gather_func_store_dests` gathers all destinations whose content
may change after a real function is called. The change may happen in the
real function or in another real function that the real function calls.
This pass uses Tarjan's strongly connected components algorithm to find
the store destinations for all real functions a kernel calls, and store
them in `store_dests` of the respective function.

The global pointers are lowered in `lower_access`, so we need to gather
the store destinations twice: before and after pass `lower_access`.

<!--
copilot:all
-->
### <samp>🤖 Generated by Copilot at 2c5586e</samp>

### Summary
📝🛠️🚀

<!--
1. 📝 This emoji represents the addition of a new file and a new analysis
pass declaration, which are documentation-related changes.
2. 🛠️ This emoji represents the update of the `ControlFlowGraph` class
and the removal of some redundant or incorrect checks, which are
bug-fixing or improvement-related changes.
3. 🚀 This emoji represents the introduction of a new enum type, a new
method, and a new parameter, which are feature-related changes.
-->
This pull request introduces a new analysis pass
`gather_func_store_dests` that can handle function calls in the IR and
optimize their memory access and aliasing. It updates the `Function`,
`FuncCallStmt`, and `ControlFlowGraph` classes and the
`compile_function` and `compile_taichi_functions` transforms to use a
new enum type `IRStage` and a new parameter `target_stage` to track and
control the IR stage of each function. It also modifies some existing
analysis functions and adds some include directives and forward
declarations to support the new pass.

> _To optimize function calls in the IR_
> _We need a new pass to infer_
> _The store destinations_
> _At different stages_
> _And use `IRStage` instead of `IRType` for sure_

### Walkthrough
* Add a new analysis pass `gather_func_store_dests` to collect the store
destinations of each function in the IR
([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-0bfbe49ff08844a76d5d2e1c5b81c2cf813be4a9089422b997bc380ec9a68eadR1-R103),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-f6bc75768d2e24c782fefa45a7232d0e2b2bae091e697040e7f442a77d80ad45L216-R216))
* Modify the `FuncCallStmt` class to inherit from the `Store` trait and
implement the `get_store_destination` method, using the arguments of the
function call and the `store_dests` set of the called function
([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-05e2a2d0a9c9879a4fb5fde9baf5a43738c7601fc53e234a40ab9bc27d1512a5R277-R289),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-917d9436dcaafa0f1e41ae9bad90273a303f036f00da94e417788a7fa1dc5260L1062-R1062),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-917d9436dcaafa0f1e41ae9bad90273a303f036f00da94e417788a7fa1dc5260R1074-R1080))
* Remove or modify the checks for `FuncCallStmt` in the
`ControlFlowGraph` class, and use the `store_dests` set of the called
function to update the reaching definition analysis
([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fL164-L167),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fL219-R216),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR695),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fL982-R977),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR988-R990))
* Add a new member variable `func_store_dests` to the `ControlFlowGraph`
class, which is a map from `Function` pointers to sets of `Stmt`
pointers, representing the store destinations of each function in the IR
([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-67e7205404aa056a1553f930af38b359e460f98a4ec335faec7d54aaf9df727fR117-R118))
* Replace the old enum type `IRType` with the new enum type `IRStage`,
which has more values to indicate different IR stages of function
compilation
([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-aa860f71a793b08676a24cab247b43f5ed8d105a6493eeb1a035369b916bddc2L17-R17),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-aa860f71a793b08676a24cab247b43f5ed8d105a6493eeb1a035369b916bddc2L32-R32),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-af3316673541832f351d12d7c2f45b3c49ba5caeafdad3a6356cb13d2524be3dL9-R20),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-af3316673541832f351d12d7c2f45b3c49ba5caeafdad3a6356cb13d2524be3dL31-R50),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-f78d8ce92dcf8a10d2a446d35cc26f47fd2a42314b0799d263196b6eb858fe76L13-R33),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-f78d8ce92dcf8a10d2a446d35cc26f47fd2a42314b0799d263196b6eb858fe76L39-R48),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL330-R390))
* Modify the signature of the `compile_function` function to use the new
parameter `target_stage` instead of the old parameter `start_from_ast`,
to indicate the desired IR stage of the function compilation
([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-448ac6e85e192a27e5ec7c54cd8a91545dc7c83f62d030eafb9c190383cfe934L199-R200),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL330-R390))
* Modify the definition of the `compile_to_offloads` function to add two
calls to the new analysis pass `gather_func_store_dests`, before and
after the call to the `compile_taichi_functions` function, and to pass
different `target_stage` parameters to the `compile_taichi_functions`
function
([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL47-R51))
* Add or modify the include directives and forward declarations for the
header files `function.h`, `statements.h`, and `unordered_set` in the
source files and header files that use the `Function` class, the
`FuncCallStmt` class, or the `std::unordered_set` container
([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR9),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-67e7205404aa056a1553f930af38b359e460f98a4ec335faec7d54aaf9df727fR10),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-05e2a2d0a9c9879a4fb5fde9baf5a43738c7601fc53e234a40ab9bc27d1512a5R5),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-448ac6e85e192a27e5ec7c54cd8a91545dc7c83f62d030eafb9c190383cfe934R20),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-af3316673541832f351d12d7c2f45b3c49ba5caeafdad3a6356cb13d2524be3dR3))
* Modify some comments in the header file `transforms.h` to remove the
mentions of not demoting dense struct fors or reducing the number of
statements before inlining, since these are no longer relevant or
necessary after the new analysis pass
([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-448ac6e85e192a27e5ec7c54cd8a91545dc7c83f62d030eafb9c190383cfe934L160-R161),
[link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-448ac6e85e192a27e5ec7c54cd8a91545dc7c83f62d030eafb9c190383cfe934L192-R190))
  • Loading branch information
lin-hitonami authored Jun 12, 2023
1 parent 294a189 commit d24c7c7
Show file tree
Hide file tree
Showing 11 changed files with 223 additions and 74 deletions.
103 changes: 103 additions & 0 deletions taichi/analysis/gather_func_store_dests.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
#include <stack>
#include "taichi/ir/analysis.h"
#include "taichi/ir/statements.h"
#include "taichi/ir/visitors.h"
#include "taichi/ir/control_flow_graph.h"
#include "taichi/program/function.h"

namespace taichi::lang {

class GatherFuncStoreDests : public BasicStmtVisitor {
private:
std::unordered_set<Stmt *> results_;
Function *current_func_;
struct TarjanData {
std::unordered_map<Function *, int> func_dfn;
std::unordered_map<Function *, int> func_low;
std::unordered_set<Function *> func_in_stack;
std::stack<Function *> func_stack;
};
TarjanData &tarjan_data_;

static std::unordered_set<Stmt *> run(Function *func,
TarjanData &tarjan_data) {
TI_ASSERT(tarjan_data.func_dfn.count(func) == 0);
tarjan_data.func_dfn[func] = tarjan_data.func_low[func] =
tarjan_data.func_dfn.size();
tarjan_data.func_in_stack.insert(func);
tarjan_data.func_stack.push(func);
GatherFuncStoreDests searcher(func, tarjan_data);
func->ir->accept(&searcher);
if (tarjan_data.func_low[func] == tarjan_data.func_dfn[func]) {
while (true) {
auto top = tarjan_data.func_stack.top();
tarjan_data.func_stack.pop();
tarjan_data.func_in_stack.erase(top);
top->store_dests.insert(searcher.results_.begin(),
searcher.results_.end());
if (top == func) {
break;
}
}
}
return searcher.results_;
}

static void run(IRNode *ir, TarjanData &tarjan_data) {
GatherFuncStoreDests searcher(nullptr, tarjan_data);
ir->accept(&searcher);
}

public:
using BasicStmtVisitor::visit;

GatherFuncStoreDests(Function *func, TarjanData &tarjan_data)
: current_func_(func), tarjan_data_(tarjan_data) {
allow_undefined_visitor = true;
invoke_default_visitor = true;
}

void visit(Stmt *stmt) override {
if (!current_func_) {
return;
}
auto result = irpass::analysis::get_store_destination(stmt);
results_.insert(result.begin(), result.end());
}

void visit(FuncCallStmt *stmt) override {
auto func = stmt->func;
if (!current_func_) {
if (!tarjan_data_.func_dfn.count(func)) {
run(func, tarjan_data_);
}
return;
}
if (!tarjan_data_.func_dfn.count(func)) {
auto result = run(func, tarjan_data_);
results_.merge(result);
tarjan_data_.func_low[current_func_] = std::min(
tarjan_data_.func_low[current_func_], tarjan_data_.func_low[func]);
} else if (tarjan_data_.func_in_stack.count(func)) {
tarjan_data_.func_low[current_func_] = std::min(
tarjan_data_.func_low[current_func_], tarjan_data_.func_dfn[func]);
} else {
const auto &dests = func->store_dests;
results_.insert(dests.begin(), dests.end());
}
}

static void run(IRNode *ir) {
TarjanData tarjan_data;
run(ir, tarjan_data);
}
};

namespace irpass::analysis {
void gather_func_store_dests(IRNode *ir) {
GatherFuncStoreDests::run(ir);
}

} // namespace irpass::analysis

} // namespace taichi::lang
2 changes: 1 addition & 1 deletion taichi/ir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,7 @@ std::unique_ptr<MeshBLSCaches> initialize_mesh_local_attribute(
OffloadedStmt *offload,
bool auto_mesh_local,
const CompileConfig &config);

void gather_func_store_dests(IRNode *ir);
} // namespace analysis
} // namespace irpass
} // namespace taichi::lang
17 changes: 7 additions & 10 deletions taichi/ir/control_flow_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "taichi/ir/analysis.h"
#include "taichi/ir/statements.h"
#include "taichi/system/profiler.h"
#include "taichi/program/function.h"

namespace taichi::lang {

Expand Down Expand Up @@ -161,10 +162,6 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {
// [Intra-block Search]
int last_def_position = -1;
for (int i = position - 1; i >= begin_location; i--) {
if (block->statements[i]->is<FuncCallStmt>()) {
return nullptr;
}

// Find previous store stmt to the same dest_addr, stop at the closest one.
// store_ptr: prev-store dest_addr
for (auto store_ptr :
Expand Down Expand Up @@ -216,10 +213,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const {
}

// Check if store_stmt will ever influence the value of var
auto may_contain_address = [](Stmt *store_stmt, Stmt *var) {
if (store_stmt->is<FuncCallStmt>()) {
return true;
}
auto may_contain_address = [&](Stmt *store_stmt, Stmt *var) {
for (auto store_ptr : irpass::analysis::get_store_destination(store_stmt)) {
if (var->is<MatrixPtrStmt>() && !store_ptr->is<MatrixPtrStmt>()) {
// check for aliased address with var
Expand Down Expand Up @@ -698,6 +692,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) {
if (stmt->is<FuncCallStmt>()) {
killed_in_this_node.clear();
live_load_in_this_node.clear();
continue;
}
auto store_ptrs = irpass::analysis::get_store_destination(stmt);

Expand Down Expand Up @@ -979,8 +974,7 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) {
for (int i = 0; i < num_nodes; i++) {
for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) {
auto stmt = nodes[i]->block->statements[j].get();
if (stmt->is<FuncCallStmt>() ||
(stmt->is<MatrixPtrStmt>() &&
if ((stmt->is<MatrixPtrStmt>() &&
stmt->as<MatrixPtrStmt>()->origin->is<AllocaStmt>()) ||
(!after_lower_access &&
(stmt->is<GlobalPtrStmt>() || stmt->is<ExternalPtrStmt>() ||
Expand All @@ -991,6 +985,9 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) {
// TODO: unify them
// A global pointer that may contain some data before this kernel.
nodes[start_node]->reach_gen.insert(stmt);
} else if (auto func_call = stmt->cast<FuncCallStmt>()) {
const auto &dests = func_call->func->store_dests;
nodes[start_node]->reach_gen.insert(dests.begin(), dests.end());
}
}
}
Expand Down
3 changes: 3 additions & 0 deletions taichi/ir/control_flow_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

namespace taichi::lang {

class Function;
/**
* A basic block in control-flow graph.
* A CFGNode contains a reference to a part of the CHI IR, or more precisely,
Expand Down Expand Up @@ -113,6 +114,8 @@ class ControlFlowGraph {
const int start_node = 0;
int final_node{0};

std::unordered_map<Function *, std::unordered_set<Stmt *>> func_store_dests;

template <typename... Args>
CFGNode *push_back(Args &&...args) {
nodes.emplace_back(std::make_unique<CFGNode>(std::forward<Args>(args)...));
Expand Down
14 changes: 14 additions & 0 deletions taichi/ir/statements.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "taichi/ir/statements.h"
#include "taichi/util/bit.h"
#include "taichi/program/kernel.h"
#include "taichi/program/function.h"

namespace taichi::lang {

Expand Down Expand Up @@ -278,6 +279,19 @@ FuncCallStmt::FuncCallStmt(Function *func, const std::vector<Stmt *> &args)
TI_STMT_REG_FIELDS;
}

stmt_refs FuncCallStmt::get_store_destination() const {
std::vector<Stmt *> ret;
for (auto &arg : args) {
if (auto ref = arg->cast<ReferenceStmt>()) {
ret.push_back(ref->var);
} else if (arg->ret_type.is_pointer()) {
ret.push_back(arg);
}
}
ret.insert(ret.end(), func->store_dests.begin(), func->store_dests.end());
return ret;
}

WhileStmt::WhileStmt(std::unique_ptr<Block> &&body)
: mask(nullptr), body(std::move(body)) {
this->body->set_parent_stmt(this);
Expand Down
9 changes: 8 additions & 1 deletion taichi/ir/statements.h
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,7 @@ class MeshForStmt : public Stmt {
/**
* Call an inline Taichi function.
*/
class FuncCallStmt : public Stmt {
class FuncCallStmt : public Stmt, public ir_traits::Store {
public:
Function *func;
std::vector<Stmt *> args;
Expand All @@ -1074,6 +1074,13 @@ class FuncCallStmt : public Stmt {
return global_side_effect;
}

// IR Trait: Store
stmt_refs get_store_destination() const override;

Stmt *get_store_data() const override {
return nullptr;
}

TI_STMT_DEF_FIELDS(ret_type, func, args);
TI_DEFINE_ACCEPT_AND_CLONE
};
Expand Down
15 changes: 7 additions & 8 deletions taichi/ir/transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "taichi/transforms/demote_mesh_statements.h"
#include "taichi/transforms/simplify.h"
#include "taichi/common/trait.h"
#include "taichi/program/function.h"

namespace taichi::lang {

Expand Down Expand Up @@ -158,10 +159,7 @@ std::unordered_map<int, ExternalPtrAccess> detect_external_ptr_access_in_task(
OffloadedStmt *offload);

// compile_to_offloads does the basic compilation to create all the offloaded
// tasks of a Taichi kernel. It's worth pointing out that this doesn't demote
// dense struct fors. This is a necessary workaround to prevent the async
// engine from fusing incompatible offloaded tasks. TODO(Lin): check this
// comment
// tasks of a Taichi kernel.
void compile_to_offloads(IRNode *ir,
const CompileConfig &config,
const Kernel *kernel,
Expand Down Expand Up @@ -190,16 +188,17 @@ void compile_to_executable(IRNode *ir,
bool make_thread_local = false,
bool make_block_local = false,
bool start_from_ast = true);
// Compile a function with some basic optimizations, so that the number of
// statements is reduced before inlining.
// Compile a function with some basic optimizations
void compile_function(IRNode *ir,
const CompileConfig &config,
Function *func,
AutodiffMode autodiff_mode,
bool verbose,
bool start_from_ast);
Function::IRStage target_stage);

void compile_taichi_functions(IRNode *ir, const CompileConfig &compile_config);
void compile_taichi_functions(IRNode *ir,
const CompileConfig &compile_config,
Function::IRStage target_stage);
} // namespace irpass

} // namespace taichi::lang
4 changes: 2 additions & 2 deletions taichi/program/function.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Function::Function(Program *program, const FunctionKey &func_key)
void Function::set_function_body(const std::function<void()> &func) {
context = std::make_unique<FrontendContext>(program->compile_config().arch);
ir = context->get_root();
ir_type_ = IRType::AST;
ir_stage_ = IRStage::AST;

func();
finalize_params();
Expand All @@ -29,7 +29,7 @@ void Function::set_function_body(const std::function<void()> &func) {

void Function::set_function_body(std::unique_ptr<IRNode> func_body) {
ir = std::move(func_body);
ir_type_ = IRType::InitialIR;
ir_stage_ = IRStage::InitialIR;
}

std::string Function::get_name() const {
Expand Down
22 changes: 16 additions & 6 deletions taichi/program/function.h
Original file line number Diff line number Diff line change
@@ -1,15 +1,23 @@
#pragma once

#include <unordered_set>
#include "taichi/program/callable.h"
#include "taichi/program/function_key.h"

namespace taichi::lang {

class Program;
class Stmt;

class Function : public Callable {
public:
enum class IRType { None, AST, InitialIR, OptimizedIR };
enum class IRStage : int {
None = 0,
AST = 1,
InitialIR = 2,
BeforeLowerAccess = 3,
OptimizedIR = 4
};

FunctionKey func_key;

Expand All @@ -28,16 +36,18 @@ class Function : public Callable {
return ast_serialization_data_;
}

void set_ir_type(IRType type) {
ir_type_ = type;
void set_ir_stage(IRStage type) {
ir_stage_ = type;
}

IRType ir_type() const {
return ir_type_;
IRStage ir_stage() const {
return ir_stage_;
}

std::unordered_set<Stmt *> store_dests;

private:
IRType ir_type_{IRType::None};
IRStage ir_stage_{IRStage::None};
std::optional<std::string> ast_serialization_data_; // For generating AST-Key
};

Expand Down
Loading

0 comments on commit d24c7c7

Please sign in to comment.