From 9444de6074d55d737c0b887c27c86c621711ca06 Mon Sep 17 00:00:00 2001 From: Lin Jiang Date: Mon, 12 Jun 2023 11:04:28 +0800 Subject: [PATCH] [opt] Treat FuncCallStmt better in store-to-load forwarding in CFG (#8155) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Issue: #602 Pass `gather_func_store_dests` gathers all destinations whose content may change after a real function is called. The change may happen in the real function or in another real function that the real function calls. This pass uses Tarjan's strongly connected components algorithm to find the store destinations for all real functions a kernel calls, and store them in `store_dests` of the respective function. The global pointers are lowered in `lower_access`, so we need to gather the store destinations twice: before and after pass `lower_access`. ### 🤖 Generated by Copilot at 2c5586e ### Summary 📝🛠️🚀 This pull request introduces a new analysis pass `gather_func_store_dests` that can handle function calls in the IR and optimize their memory access and aliasing. It updates the `Function`, `FuncCallStmt`, and `ControlFlowGraph` classes and the `compile_function` and `compile_taichi_functions` transforms to use a new enum type `IRStage` and a new parameter `target_stage` to track and control the IR stage of each function. It also modifies some existing analysis functions and adds some include directives and forward declarations to support the new pass. > _To optimize function calls in the IR_ > _We need a new pass to infer_ > _The store destinations_ > _At different stages_ > _And use `IRStage` instead of `IRType` for sure_ ### Walkthrough * Add a new analysis pass `gather_func_store_dests` to collect the store destinations of each function in the IR ([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-0bfbe49ff08844a76d5d2e1c5b81c2cf813be4a9089422b997bc380ec9a68eadR1-R103), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-f6bc75768d2e24c782fefa45a7232d0e2b2bae091e697040e7f442a77d80ad45L216-R216)) * Modify the `FuncCallStmt` class to inherit from the `Store` trait and implement the `get_store_destination` method, using the arguments of the function call and the `store_dests` set of the called function ([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-05e2a2d0a9c9879a4fb5fde9baf5a43738c7601fc53e234a40ab9bc27d1512a5R277-R289), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-917d9436dcaafa0f1e41ae9bad90273a303f036f00da94e417788a7fa1dc5260L1062-R1062), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-917d9436dcaafa0f1e41ae9bad90273a303f036f00da94e417788a7fa1dc5260R1074-R1080)) * Remove or modify the checks for `FuncCallStmt` in the `ControlFlowGraph` class, and use the `store_dests` set of the called function to update the reaching definition analysis ([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fL164-L167), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fL219-R216), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR695), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fL982-R977), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR988-R990)) * Add a new member variable `func_store_dests` to the `ControlFlowGraph` class, which is a map from `Function` pointers to sets of `Stmt` pointers, representing the store destinations of each function in the IR ([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-67e7205404aa056a1553f930af38b359e460f98a4ec335faec7d54aaf9df727fR117-R118)) * Replace the old enum type `IRType` with the new enum type `IRStage`, which has more values to indicate different IR stages of function compilation ([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-aa860f71a793b08676a24cab247b43f5ed8d105a6493eeb1a035369b916bddc2L17-R17), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-aa860f71a793b08676a24cab247b43f5ed8d105a6493eeb1a035369b916bddc2L32-R32), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-af3316673541832f351d12d7c2f45b3c49ba5caeafdad3a6356cb13d2524be3dL9-R20), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-af3316673541832f351d12d7c2f45b3c49ba5caeafdad3a6356cb13d2524be3dL31-R50), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-f78d8ce92dcf8a10d2a446d35cc26f47fd2a42314b0799d263196b6eb858fe76L13-R33), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-f78d8ce92dcf8a10d2a446d35cc26f47fd2a42314b0799d263196b6eb858fe76L39-R48), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL330-R390)) * Modify the signature of the `compile_function` function to use the new parameter `target_stage` instead of the old parameter `start_from_ast`, to indicate the desired IR stage of the function compilation ([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-448ac6e85e192a27e5ec7c54cd8a91545dc7c83f62d030eafb9c190383cfe934L199-R200), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL330-R390)) * Modify the definition of the `compile_to_offloads` function to add two calls to the new analysis pass `gather_func_store_dests`, before and after the call to the `compile_taichi_functions` function, and to pass different `target_stage` parameters to the `compile_taichi_functions` function ([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL47-R51)) * Add or modify the include directives and forward declarations for the header files `function.h`, `statements.h`, and `unordered_set` in the source files and header files that use the `Function` class, the `FuncCallStmt` class, or the `std::unordered_set` container ([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-837b90142d1730f6a3ab20c91f1f35c95335ef82a021c74fd4dbdb05ff0e164fR9), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-67e7205404aa056a1553f930af38b359e460f98a4ec335faec7d54aaf9df727fR10), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-05e2a2d0a9c9879a4fb5fde9baf5a43738c7601fc53e234a40ab9bc27d1512a5R5), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-448ac6e85e192a27e5ec7c54cd8a91545dc7c83f62d030eafb9c190383cfe934R20), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-af3316673541832f351d12d7c2f45b3c49ba5caeafdad3a6356cb13d2524be3dR3)) * Modify some comments in the header file `transforms.h` to remove the mentions of not demoting dense struct fors or reducing the number of statements before inlining, since these are no longer relevant or necessary after the new analysis pass ([link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-448ac6e85e192a27e5ec7c54cd8a91545dc7c83f62d030eafb9c190383cfe934L160-R161), [link](https://github.com/taichi-dev/taichi/pull/8155/files?diff=unified&w=0#diff-448ac6e85e192a27e5ec7c54cd8a91545dc7c83f62d030eafb9c190383cfe934L192-R190)) --- taichi/analysis/gather_func_store_dests.cpp | 103 ++++++++++++++++++ taichi/ir/analysis.h | 2 +- taichi/ir/control_flow_graph.cpp | 17 ++- taichi/ir/control_flow_graph.h | 3 + taichi/ir/statements.cpp | 14 +++ taichi/ir/statements.h | 9 +- taichi/ir/transforms.h | 15 ++- taichi/program/function.cpp | 4 +- taichi/program/function.h | 22 +++- .../transforms/compile_taichi_functions.cpp | 27 +++-- taichi/transforms/compile_to_offloads.cpp | 81 ++++++++------ 11 files changed, 223 insertions(+), 74 deletions(-) create mode 100644 taichi/analysis/gather_func_store_dests.cpp diff --git a/taichi/analysis/gather_func_store_dests.cpp b/taichi/analysis/gather_func_store_dests.cpp new file mode 100644 index 00000000000000..bdf7211fa0c9b7 --- /dev/null +++ b/taichi/analysis/gather_func_store_dests.cpp @@ -0,0 +1,103 @@ +#include +#include "taichi/ir/analysis.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/visitors.h" +#include "taichi/ir/control_flow_graph.h" +#include "taichi/program/function.h" + +namespace taichi::lang { + +class GatherFuncStoreDests : public BasicStmtVisitor { + private: + std::unordered_set results_; + Function *current_func_; + struct TarjanData { + std::unordered_map func_dfn; + std::unordered_map func_low; + std::unordered_set func_in_stack; + std::stack func_stack; + }; + TarjanData &tarjan_data_; + + static std::unordered_set run(Function *func, + TarjanData &tarjan_data) { + TI_ASSERT(tarjan_data.func_dfn.count(func) == 0); + tarjan_data.func_dfn[func] = tarjan_data.func_low[func] = + tarjan_data.func_dfn.size(); + tarjan_data.func_in_stack.insert(func); + tarjan_data.func_stack.push(func); + GatherFuncStoreDests searcher(func, tarjan_data); + func->ir->accept(&searcher); + if (tarjan_data.func_low[func] == tarjan_data.func_dfn[func]) { + while (true) { + auto top = tarjan_data.func_stack.top(); + tarjan_data.func_stack.pop(); + tarjan_data.func_in_stack.erase(top); + top->store_dests.insert(searcher.results_.begin(), + searcher.results_.end()); + if (top == func) { + break; + } + } + } + return searcher.results_; + } + + static void run(IRNode *ir, TarjanData &tarjan_data) { + GatherFuncStoreDests searcher(nullptr, tarjan_data); + ir->accept(&searcher); + } + + public: + using BasicStmtVisitor::visit; + + GatherFuncStoreDests(Function *func, TarjanData &tarjan_data) + : current_func_(func), tarjan_data_(tarjan_data) { + allow_undefined_visitor = true; + invoke_default_visitor = true; + } + + void visit(Stmt *stmt) override { + if (!current_func_) { + return; + } + auto result = irpass::analysis::get_store_destination(stmt); + results_.insert(result.begin(), result.end()); + } + + void visit(FuncCallStmt *stmt) override { + auto func = stmt->func; + if (!current_func_) { + if (!tarjan_data_.func_dfn.count(func)) { + run(func, tarjan_data_); + } + return; + } + if (!tarjan_data_.func_dfn.count(func)) { + auto result = run(func, tarjan_data_); + results_.merge(result); + tarjan_data_.func_low[current_func_] = std::min( + tarjan_data_.func_low[current_func_], tarjan_data_.func_low[func]); + } else if (tarjan_data_.func_in_stack.count(func)) { + tarjan_data_.func_low[current_func_] = std::min( + tarjan_data_.func_low[current_func_], tarjan_data_.func_dfn[func]); + } else { + const auto &dests = func->store_dests; + results_.insert(dests.begin(), dests.end()); + } + } + + static void run(IRNode *ir) { + TarjanData tarjan_data; + run(ir, tarjan_data); + } +}; + +namespace irpass::analysis { +void gather_func_store_dests(IRNode *ir) { + GatherFuncStoreDests::run(ir); +} + +} // namespace irpass::analysis + +} // namespace taichi::lang diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index 48eb0b4cae1281..996f63c14825c0 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -213,7 +213,7 @@ std::unique_ptr initialize_mesh_local_attribute( OffloadedStmt *offload, bool auto_mesh_local, const CompileConfig &config); - +void gather_func_store_dests(IRNode *ir); } // namespace analysis } // namespace irpass } // namespace taichi::lang diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index 7f777f547a3015..5f8bd03fa22e99 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -6,6 +6,7 @@ #include "taichi/ir/analysis.h" #include "taichi/ir/statements.h" #include "taichi/system/profiler.h" +#include "taichi/program/function.h" namespace taichi::lang { @@ -161,10 +162,6 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { // [Intra-block Search] int last_def_position = -1; for (int i = position - 1; i >= begin_location; i--) { - if (block->statements[i]->is()) { - return nullptr; - } - // Find previous store stmt to the same dest_addr, stop at the closest one. // store_ptr: prev-store dest_addr for (auto store_ptr : @@ -216,10 +213,7 @@ Stmt *CFGNode::get_store_forwarding_data(Stmt *var, int position) const { } // Check if store_stmt will ever influence the value of var - auto may_contain_address = [](Stmt *store_stmt, Stmt *var) { - if (store_stmt->is()) { - return true; - } + auto may_contain_address = [&](Stmt *store_stmt, Stmt *var) { for (auto store_ptr : irpass::analysis::get_store_destination(store_stmt)) { if (var->is() && !store_ptr->is()) { // check for aliased address with var @@ -698,6 +692,7 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { if (stmt->is()) { killed_in_this_node.clear(); live_load_in_this_node.clear(); + continue; } auto store_ptrs = irpass::analysis::get_store_destination(stmt); @@ -979,8 +974,7 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { for (int i = 0; i < num_nodes; i++) { for (int j = nodes[i]->begin_location; j < nodes[i]->end_location; j++) { auto stmt = nodes[i]->block->statements[j].get(); - if (stmt->is() || - (stmt->is() && + if ((stmt->is() && stmt->as()->origin->is()) || (!after_lower_access && (stmt->is() || stmt->is() || @@ -991,6 +985,9 @@ void ControlFlowGraph::reaching_definition_analysis(bool after_lower_access) { // TODO: unify them // A global pointer that may contain some data before this kernel. nodes[start_node]->reach_gen.insert(stmt); + } else if (auto func_call = stmt->cast()) { + const auto &dests = func_call->func->store_dests; + nodes[start_node]->reach_gen.insert(dests.begin(), dests.end()); } } } diff --git a/taichi/ir/control_flow_graph.h b/taichi/ir/control_flow_graph.h index f51ec732e2b176..324e9a93313936 100644 --- a/taichi/ir/control_flow_graph.h +++ b/taichi/ir/control_flow_graph.h @@ -7,6 +7,7 @@ namespace taichi::lang { +class Function; /** * A basic block in control-flow graph. * A CFGNode contains a reference to a part of the CHI IR, or more precisely, @@ -113,6 +114,8 @@ class ControlFlowGraph { const int start_node = 0; int final_node{0}; + std::unordered_map> func_store_dests; + template CFGNode *push_back(Args &&...args) { nodes.emplace_back(std::make_unique(std::forward(args)...)); diff --git a/taichi/ir/statements.cpp b/taichi/ir/statements.cpp index 3bbe33893dab58..bec9808a7a7f15 100644 --- a/taichi/ir/statements.cpp +++ b/taichi/ir/statements.cpp @@ -2,6 +2,7 @@ #include "taichi/ir/statements.h" #include "taichi/util/bit.h" #include "taichi/program/kernel.h" +#include "taichi/program/function.h" namespace taichi::lang { @@ -278,6 +279,19 @@ FuncCallStmt::FuncCallStmt(Function *func, const std::vector &args) TI_STMT_REG_FIELDS; } +stmt_refs FuncCallStmt::get_store_destination() const { + std::vector ret; + for (auto &arg : args) { + if (auto ref = arg->cast()) { + ret.push_back(ref->var); + } else if (arg->ret_type.is_pointer()) { + ret.push_back(arg); + } + } + ret.insert(ret.end(), func->store_dests.begin(), func->store_dests.end()); + return ret; +} + WhileStmt::WhileStmt(std::unique_ptr &&body) : mask(nullptr), body(std::move(body)) { this->body->set_parent_stmt(this); diff --git a/taichi/ir/statements.h b/taichi/ir/statements.h index 57280c28a6c750..2de685936bf182 100644 --- a/taichi/ir/statements.h +++ b/taichi/ir/statements.h @@ -1062,7 +1062,7 @@ class MeshForStmt : public Stmt { /** * Call an inline Taichi function. */ -class FuncCallStmt : public Stmt { +class FuncCallStmt : public Stmt, public ir_traits::Store { public: Function *func; std::vector args; @@ -1074,6 +1074,13 @@ class FuncCallStmt : public Stmt { return global_side_effect; } + // IR Trait: Store + stmt_refs get_store_destination() const override; + + Stmt *get_store_data() const override { + return nullptr; + } + TI_STMT_DEF_FIELDS(ret_type, func, args); TI_DEFINE_ACCEPT_AND_CLONE }; diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index fede01a04b80ec..51030dc1e7aa56 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -17,6 +17,7 @@ #include "taichi/transforms/demote_mesh_statements.h" #include "taichi/transforms/simplify.h" #include "taichi/common/trait.h" +#include "taichi/program/function.h" namespace taichi::lang { @@ -158,10 +159,7 @@ std::unordered_map detect_external_ptr_access_in_task( OffloadedStmt *offload); // compile_to_offloads does the basic compilation to create all the offloaded -// tasks of a Taichi kernel. It's worth pointing out that this doesn't demote -// dense struct fors. This is a necessary workaround to prevent the async -// engine from fusing incompatible offloaded tasks. TODO(Lin): check this -// comment +// tasks of a Taichi kernel. void compile_to_offloads(IRNode *ir, const CompileConfig &config, const Kernel *kernel, @@ -190,16 +188,17 @@ void compile_to_executable(IRNode *ir, bool make_thread_local = false, bool make_block_local = false, bool start_from_ast = true); -// Compile a function with some basic optimizations, so that the number of -// statements is reduced before inlining. +// Compile a function with some basic optimizations void compile_function(IRNode *ir, const CompileConfig &config, Function *func, AutodiffMode autodiff_mode, bool verbose, - bool start_from_ast); + Function::IRStage target_stage); -void compile_taichi_functions(IRNode *ir, const CompileConfig &compile_config); +void compile_taichi_functions(IRNode *ir, + const CompileConfig &compile_config, + Function::IRStage target_stage); } // namespace irpass } // namespace taichi::lang diff --git a/taichi/program/function.cpp b/taichi/program/function.cpp index c4de4c4056aa14..b06370089e0ca1 100644 --- a/taichi/program/function.cpp +++ b/taichi/program/function.cpp @@ -14,7 +14,7 @@ Function::Function(Program *program, const FunctionKey &func_key) void Function::set_function_body(const std::function &func) { context = std::make_unique(program->compile_config().arch); ir = context->get_root(); - ir_type_ = IRType::AST; + ir_stage_ = IRStage::AST; func(); finalize_params(); @@ -29,7 +29,7 @@ void Function::set_function_body(const std::function &func) { void Function::set_function_body(std::unique_ptr func_body) { ir = std::move(func_body); - ir_type_ = IRType::InitialIR; + ir_stage_ = IRStage::InitialIR; } std::string Function::get_name() const { diff --git a/taichi/program/function.h b/taichi/program/function.h index 712f79f57dba32..15d3ba57cdd948 100644 --- a/taichi/program/function.h +++ b/taichi/program/function.h @@ -1,15 +1,23 @@ #pragma once +#include #include "taichi/program/callable.h" #include "taichi/program/function_key.h" namespace taichi::lang { class Program; +class Stmt; class Function : public Callable { public: - enum class IRType { None, AST, InitialIR, OptimizedIR }; + enum class IRStage : int { + None = 0, + AST = 1, + InitialIR = 2, + BeforeLowerAccess = 3, + OptimizedIR = 4 + }; FunctionKey func_key; @@ -28,16 +36,18 @@ class Function : public Callable { return ast_serialization_data_; } - void set_ir_type(IRType type) { - ir_type_ = type; + void set_ir_stage(IRStage type) { + ir_stage_ = type; } - IRType ir_type() const { - return ir_type_; + IRStage ir_stage() const { + return ir_stage_; } + std::unordered_set store_dests; + private: - IRType ir_type_{IRType::None}; + IRStage ir_stage_{IRStage::None}; std::optional ast_serialization_data_; // For generating AST-Key }; diff --git a/taichi/transforms/compile_taichi_functions.cpp b/taichi/transforms/compile_taichi_functions.cpp index 25d9e1d28df6c7..180d191f49dd02 100644 --- a/taichi/transforms/compile_taichi_functions.cpp +++ b/taichi/transforms/compile_taichi_functions.cpp @@ -10,39 +10,42 @@ class CompileTaichiFunctions : public BasicStmtVisitor { public: using BasicStmtVisitor::visit; - explicit CompileTaichiFunctions(const CompileConfig &compile_config) - : compile_config_(compile_config) { + CompileTaichiFunctions(const CompileConfig &compile_config, + Function::IRStage target_stage) + : compile_config_(compile_config), target_stage_(target_stage) { } void visit(FuncCallStmt *stmt) override { - using IRType = Function::IRType; auto *func = stmt->func; - const auto ir_type = func->ir_type(); - if (ir_type != IRType::OptimizedIR) { - TI_ASSERT(ir_type == IRType::AST || ir_type == IRType::InitialIR); - func->set_ir_type(IRType::OptimizedIR); + const auto ir_type = func->ir_stage(); + if (ir_type < target_stage_) { irpass::compile_function(func->ir.get(), compile_config_, func, /*autodiff_mode=*/AutodiffMode::kNone, /*verbose=*/compile_config_.print_ir, - /*start_from_ast=*/ir_type == IRType::AST); + target_stage_); func->ir->accept(this); } } - static void run(IRNode *ir, const CompileConfig &compile_config) { - CompileTaichiFunctions ctf{compile_config}; + static void run(IRNode *ir, + const CompileConfig &compile_config, + Function::IRStage target_stage) { + CompileTaichiFunctions ctf{compile_config, target_stage}; ir->accept(&ctf); } private: const CompileConfig &compile_config_; + Function::IRStage target_stage_; }; namespace irpass { -void compile_taichi_functions(IRNode *ir, const CompileConfig &compile_config) { +void compile_taichi_functions(IRNode *ir, + const CompileConfig &compile_config, + Function::IRStage target_stage) { TI_AUTO_PROF; - CompileTaichiFunctions::run(ir, compile_config); + CompileTaichiFunctions::run(ir, compile_config, target_stage); } } // namespace irpass diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 64d03ac686cdcb..4c770881706ed1 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -44,7 +44,11 @@ void compile_to_offloads(IRNode *ir, print("Lowered"); } - irpass::compile_taichi_functions(ir, config); + irpass::compile_taichi_functions(ir, config, + Function::IRStage::BeforeLowerAccess); + irpass::analysis::gather_func_store_dests(ir); + irpass::compile_taichi_functions(ir, config, Function::IRStage::OptimizedIR); + irpass::analysis::gather_func_store_dests(ir); irpass::eliminate_immutable_local_vars(ir); print("Immutable local vars eliminated"); @@ -330,54 +334,63 @@ void compile_function(IRNode *ir, Function *func, AutodiffMode autodiff_mode, bool verbose, - bool start_from_ast) { + Function::IRStage target_stage) { TI_AUTO_PROF; + auto current_stage = func->ir_stage(); auto print = make_pass_printer(verbose, func->get_name(), ir); print("Initial IR"); - if (autodiff_mode == AutodiffMode::kReverse) { - irpass::reverse_segments(ir); - print("Segment reversed (for autodiff)"); - } + if (target_stage >= Function::IRStage::BeforeLowerAccess && + current_stage < Function::IRStage::BeforeLowerAccess) { + if (autodiff_mode == AutodiffMode::kReverse) { + irpass::reverse_segments(ir); + print("Segment reversed (for autodiff)"); + } - if (start_from_ast) { - irpass::frontend_type_check(ir); - irpass::lower_ast(ir); - print("Lowered"); - } + if (current_stage < Function::IRStage::InitialIR) { + irpass::frontend_type_check(ir); + irpass::lower_ast(ir); + print("Lowered"); + } - if (config.real_matrix_scalarize) { - if (irpass::scalarize(ir)) { - // Remove redundant MatrixInitStmt inserted during scalarization - irpass::die(ir); - print("Scalarized"); + if (config.real_matrix_scalarize) { + if (irpass::scalarize(ir)) { + // Remove redundant MatrixInitStmt inserted during scalarization + irpass::die(ir); + print("Scalarized"); + } } + func->set_ir_stage(Function::IRStage::BeforeLowerAccess); } - irpass::lower_access(ir, config, {{}, true}); - print("Access lowered"); - irpass::analysis::verify(ir); + if (target_stage >= Function::IRStage::OptimizedIR && + current_stage < Function::IRStage::OptimizedIR) { + irpass::lower_access(ir, config, {{}, true}); + print("Access lowered"); + irpass::analysis::verify(ir); - irpass::die(ir); - print("DIE"); - irpass::analysis::verify(ir); + irpass::die(ir); + print("DIE"); + irpass::analysis::verify(ir); - irpass::flag_access(ir); - print("Access flagged III"); - irpass::analysis::verify(ir); + irpass::flag_access(ir); + print("Access flagged III"); + irpass::analysis::verify(ir); - irpass::type_check(ir, config); - print("Typechecked"); + irpass::type_check(ir, config); + print("Typechecked"); - irpass::demote_operations(ir, config); - print("Operations demoted"); + irpass::demote_operations(ir, config); + print("Operations demoted"); - irpass::full_simplify( - ir, config, - {false, autodiff_mode != AutodiffMode::kNone, func->get_name(), verbose}); - print("Simplified"); - irpass::analysis::verify(ir); + irpass::full_simplify(ir, config, + {true, autodiff_mode != AutodiffMode::kNone, + func->get_name(), verbose}); + print("Simplified"); + irpass::analysis::verify(ir); + func->set_ir_stage(Function::IRStage::OptimizedIR); + } } } // namespace irpass