-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[opt] Cache loop-invariant global vars to local vars (#6072)
Related issue = fixes #5350 Global variables can't be store-to-load forwarded after the `lower_access` pass, so we need to run `simplify` before it. This should speed up the program in all circumstances. Caching loop-invariant global vars to local vars sometimes speeds up the program but sometimes slows it down, so it is controlled by a compiler config flag. FPS of Yu's program on RTX3080 on Vulkan: Original: 19 fps; simplified before lower access: 30 fps; cached loop-invariant global vars to local vars: 41 fps. **This PR does things as follows:** 1. Extract a base class `LoopInvariantDetector` from `LoopInvariantCodeMotion`. This class maintains information to detect whether a statement is loop-invariant. 2. Let LICM move `GlobalPtrStmt`, `ArgLoadStmt` and `ExternalPtrStmt` out of the loop so that they become loop-invariant. 3. Add `CacheLoopInvariantGlobalVars` to move out loop-invariant global variables that are loop-unique in the offloaded task. 4. Add pass `cache_loop_invariant_global_vars` after `demote_atomics` before `demote_dense_struct_fors` (because loop-uniqueness can't be correctly detected after `demote_dense_struct_fors`) and add a compiler config flag to control it. 5. Add pass `full_simplify` before `lower_access` to enable store-to-load forwarding for GlobalPtrs. <!-- Thank you for your contribution! If it is your first time contributing to Taichi, please read our Contributor Guidelines: https://docs.taichi-lang.org/docs/contributor_guide - Please always prepend your PR title with tags such as [CUDA], [Lang], [Doc], [Example]. For a complete list of valid PR tags, please check out https://github.com/taichi-dev/taichi/blob/master/misc/prtags.json. - Use upper-case tags (e.g., [Metal]) for PRs that change public APIs. Otherwise, please use lower-case tags (e.g., [metal]). - More details: https://docs.taichi-lang.org/docs/contributor_guide#pr-title-format-and-tags - Please fill in the issue number that this PR relates to. 
- If your PR fixes the issue **completely**, use the `close` or `fixes` prefix so that GitHub automatically closes the issue when the PR is merged. For example, Related issue = close #2345 - If the PR does not belong to any existing issue, feel free to leave it blank. --> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
b272c96
commit 8e9d978
Showing
7 changed files
with
357 additions
and
110 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,180 @@ | ||
#include "taichi/transforms/loop_invariant_detector.h" | ||
#include "taichi/ir/analysis.h" | ||
|
||
namespace taichi::lang { | ||
|
||
class CacheLoopInvariantGlobalVars : public LoopInvariantDetector { | ||
public: | ||
using LoopInvariantDetector::visit; | ||
|
||
enum class CacheStatus { | ||
None = 0, | ||
Read = 1, | ||
Write = 2, | ||
ReadWrite = 3, | ||
}; | ||
|
||
typedef std::unordered_map<Stmt *, std::pair<CacheStatus, AllocaStmt *>> | ||
CacheMap; | ||
std::stack<CacheMap> cached_maps; | ||
|
||
DelayedIRModifier modifier; | ||
std::unordered_map<const SNode *, GlobalPtrStmt *> loop_unique_ptr_; | ||
std::unordered_map<int, ExternalPtrStmt *> loop_unique_arr_ptr_; | ||
|
||
OffloadedStmt *current_offloaded; | ||
|
||
explicit CacheLoopInvariantGlobalVars(const CompileConfig &config) | ||
: LoopInvariantDetector(config) { | ||
} | ||
|
||
void visit(OffloadedStmt *stmt) override { | ||
if (stmt->task_type == OffloadedTaskType::range_for || | ||
stmt->task_type == OffloadedTaskType::mesh_for || | ||
stmt->task_type == OffloadedTaskType::struct_for) { | ||
auto uniquely_accessed_pointers = | ||
irpass::analysis::gather_uniquely_accessed_pointers(stmt); | ||
loop_unique_ptr_ = std::move(uniquely_accessed_pointers.first); | ||
loop_unique_arr_ptr_ = std::move(uniquely_accessed_pointers.second); | ||
} | ||
current_offloaded = stmt; | ||
// We don't need to visit TLS/BLS prologues/epilogues. | ||
if (stmt->body) { | ||
if (stmt->task_type == OffloadedStmt::TaskType::range_for || | ||
stmt->task_type == OffloadedTaskType::mesh_for || | ||
stmt->task_type == OffloadedStmt::TaskType::struct_for) | ||
visit_loop(stmt->body.get()); | ||
else | ||
stmt->body->accept(this); | ||
} | ||
current_offloaded = nullptr; | ||
} | ||
|
||
bool is_offload_unique(Stmt *stmt) { | ||
if (current_offloaded->task_type == OffloadedTaskType::serial) { | ||
return true; | ||
} | ||
if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) { | ||
auto snode = global_ptr->snode; | ||
if (loop_unique_ptr_[snode] == nullptr || | ||
loop_unique_ptr_[snode]->indices.empty()) { | ||
// not uniquely accessed | ||
return false; | ||
} | ||
if (current_offloaded->mem_access_opt.has_flag( | ||
snode, SNodeAccessFlag::block_local) || | ||
current_offloaded->mem_access_opt.has_flag( | ||
snode, SNodeAccessFlag::mesh_local)) { | ||
// BLS does not support write access yet so we keep atomic_adds. | ||
return false; | ||
} | ||
return true; | ||
} else if (stmt->is<ExternalPtrStmt>()) { | ||
ExternalPtrStmt *dest_ptr = stmt->as<ExternalPtrStmt>(); | ||
if (dest_ptr->indices.empty()) { | ||
return false; | ||
} | ||
ArgLoadStmt *arg_load_stmt = dest_ptr->base_ptr->as<ArgLoadStmt>(); | ||
int arg_id = arg_load_stmt->arg_id; | ||
if (loop_unique_arr_ptr_[arg_id] == nullptr) { | ||
// Not loop unique | ||
return false; | ||
} | ||
return true; | ||
// TODO: Is BLS / Mem Access Opt a thing for any_arr? | ||
} | ||
return false; | ||
} | ||
|
||
void visit_loop(Block *body) override { | ||
cached_maps.emplace(); | ||
LoopInvariantDetector::visit_loop(body); | ||
cached_maps.pop(); | ||
} | ||
|
||
void add_writeback(AllocaStmt *alloca_stmt, Stmt *global_var) { | ||
auto final_value = std::make_unique<LocalLoadStmt>(alloca_stmt); | ||
auto global_store = | ||
std::make_unique<GlobalStoreStmt>(global_var, final_value.get()); | ||
modifier.insert_after(current_loop_stmt(), std::move(global_store)); | ||
modifier.insert_after(current_loop_stmt(), std::move(final_value)); | ||
} | ||
|
||
void set_init_value(AllocaStmt *alloca_stmt, Stmt *global_var) { | ||
auto new_global_load = std::make_unique<GlobalLoadStmt>(global_var); | ||
auto local_store = | ||
std::make_unique<LocalStoreStmt>(alloca_stmt, new_global_load.get()); | ||
modifier.insert_before(current_loop_stmt(), std::move(new_global_load)); | ||
modifier.insert_before(current_loop_stmt(), std::move(local_store)); | ||
} | ||
|
||
AllocaStmt *cache_global_to_local(Stmt *dest, CacheStatus status) { | ||
if (auto &[cached_status, alloca_stmt] = cached_maps.top()[dest]; | ||
cached_status != CacheStatus::None) { | ||
// The global variable has already been cached. | ||
if (cached_status == CacheStatus::Read && status == CacheStatus::Write) { | ||
add_writeback(alloca_stmt, dest); | ||
cached_status = CacheStatus::ReadWrite; | ||
} | ||
return alloca_stmt; | ||
} | ||
auto alloca_unique = | ||
std::make_unique<AllocaStmt>(dest->ret_type.ptr_removed()); | ||
auto alloca_stmt = alloca_unique.get(); | ||
modifier.insert_before(current_loop_stmt(), std::move(alloca_unique)); | ||
if (status == CacheStatus::Read) { | ||
set_init_value(alloca_stmt, dest); | ||
} else if (status == CacheStatus::Write) { | ||
add_writeback(alloca_stmt, dest); | ||
} | ||
cached_maps.top()[dest] = {status, alloca_stmt}; | ||
return alloca_stmt; | ||
} | ||
|
||
void visit(GlobalLoadStmt *stmt) override { | ||
if (is_offload_unique(stmt->src) && | ||
is_operand_loop_invariant(stmt->src, stmt->parent)) { | ||
auto alloca_stmt = cache_global_to_local(stmt->src, CacheStatus::Read); | ||
auto local_load = std::make_unique<LocalLoadStmt>(alloca_stmt); | ||
stmt->replace_usages_with(local_load.get()); | ||
modifier.insert_before(stmt, std::move(local_load)); | ||
modifier.erase(stmt); | ||
} | ||
} | ||
|
||
void visit(GlobalStoreStmt *stmt) override { | ||
if (is_offload_unique(stmt->dest) && | ||
is_operand_loop_invariant(stmt->dest, stmt->parent)) { | ||
auto alloca_stmt = cache_global_to_local(stmt->dest, CacheStatus::Write); | ||
auto local_store = | ||
std::make_unique<LocalStoreStmt>(alloca_stmt, stmt->val); | ||
stmt->replace_usages_with(local_store.get()); | ||
modifier.insert_before(stmt, std::move(local_store)); | ||
modifier.erase(stmt); | ||
} | ||
} | ||
|
||
static bool run(IRNode *node, const CompileConfig &config) { | ||
bool modified = false; | ||
|
||
while (true) { | ||
CacheLoopInvariantGlobalVars eliminator(config); | ||
node->accept(&eliminator); | ||
if (eliminator.modifier.modify_ir()) | ||
modified = true; | ||
else | ||
break; | ||
}; | ||
|
||
return modified; | ||
} | ||
}; | ||
|
||
namespace irpass { | ||
// Entry point of the pass: profiles the call via TI_AUTO_PROF and delegates to
// CacheLoopInvariantGlobalVars::run. Returns that call's result, i.e. whether
// the IR under `root` was modified.
bool cache_loop_invariant_global_vars(IRNode *root,
                                      const CompileConfig &config) {
  TI_AUTO_PROF;
  const bool ir_changed = CacheLoopInvariantGlobalVars::run(root, config);
  return ir_changed;
}
} // namespace irpass | ||
} // namespace taichi::lang |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.