[opt] Cache loop-invariant global vars to local vars (#6072)
Related issue = fixes #5350 

Global variables can't be store-to-load forwarded after the `lower_access`
pass, so we need to run `full_simplify` before it. This should speed up the
program in all circumstances.
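
For readers unfamiliar with the term, here is a minimal sketch of store-to-load forwarding in plain C++ (an illustration only, not Taichi IR): the optimization replaces a load with the value most recently stored to the same address, which is only possible while the compiler can still prove the two accesses refer to the same location.

```cpp
// Minimal illustration of store-to-load forwarding (not Taichi IR).
int before(int *p, int x) {
  *p = x;      // store
  int y = *p;  // load from the same address
  return y + 1;
}

// After forwarding: the load is replaced by the stored value, eliminating
// the second memory access.
int after(int *p, int x) {
  *p = x;
  int y = x;   // forwarded value
  return y + 1;
}
```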

Caching loop-invariant global variables to local variables sometimes speeds
the program up but sometimes slows it down, so the optimization is controlled
by a compiler config flag.
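
As a rough before/after sketch of what the caching transform aims for (plain C++ rather than Taichi IR; the function and variable names here are made up), a loop that repeatedly accesses one loop-invariant, uniquely-accessed global location is rewritten to load the value once before the loop, work on a local copy inside it, and write back once afterwards:

```cpp
#include <cstddef>

// Before: every iteration goes through the loop-invariant global pointer g.
void accumulate_before(float *g, const float *data, std::size_t n) {
  for (std::size_t i = 0; i < n; i++) {
    *g += data[i];  // global load + global store on every iteration
  }
}

// After: the global value is cached in a local, mirroring what
// set_init_value / add_writeback in the new pass insert around the loop.
void accumulate_after(float *g, const float *data, std::size_t n) {
  float cached = *g;  // initial value loaded once before the loop
  for (std::size_t i = 0; i < n; i++) {
    cached += data[i];  // only local accesses inside the loop
  }
  *g = cached;  // single write-back after the loop
}
```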

FPS of Yu's program on an RTX 3080 with the Vulkan backend:
- Original: 19 FPS
- Simplified before lower access: 30 FPS
- Cached loop-invariant global vars to local vars: 41 FPS

**This PR does the following:**
1. Extract a base class `LoopInvariantDetector` from
`LoopInvariantCodeMotion`. This class maintains information to detect
whether a statement is loop-invariant.
2. Let LICM move `GlobalPtrStmt`, `ArgLoadStmt` and `ExternalPtrStmt`
out of the loop so that they become loop-invariant.
3. Add `CacheLoopInvariantGlobalVars` to cache loop-invariant global
variables that are uniquely accessed within the offloaded task.
4. Add the `cache_loop_invariant_global_vars` pass after `demote_atomics` and
before `demote_dense_struct_fors` (loop-uniqueness can't be correctly
detected after `demote_dense_struct_fors`), and add a compiler config flag
to control it (a short toggle sketch follows this list).
5. Add a `full_simplify` pass before `lower_access` to enable
store-to-load forwarding for GlobalPtrs.
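
As referenced in item 4, here is a small sketch of how the flag could be flipped on the C++ side. The `CompileConfig` field and its default of `true` come from this PR; the header path is taken from the diff, the surrounding helper function is purely illustrative, and the flag should also be reachable from Python through the binding added in `export_lang.cpp`.

```cpp
#include "taichi/program/compile_config.h"

// Illustrative helper (not part of this PR): build a config with the new
// caching pass disabled, e.g. to compare kernel performance with and
// without it.
taichi::lang::CompileConfig make_config_without_caching() {
  taichi::lang::CompileConfig config;
  config.cache_loop_invariant_global_vars = false;  // defaults to true
  return config;
}
```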

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
lin-hitonami and pre-commit-ci[bot] authored Sep 23, 2022
1 parent b272c96 commit 8e9d978
Showing 7 changed files with 357 additions and 110 deletions.
2 changes: 2 additions & 0 deletions taichi/ir/transforms.h
@@ -47,6 +47,8 @@ bool whole_kernel_cse(IRNode *root);
bool extract_constant(IRNode *root, const CompileConfig &config);
bool unreachable_code_elimination(IRNode *root);
bool loop_invariant_code_motion(IRNode *root, const CompileConfig &config);
bool cache_loop_invariant_global_vars(IRNode *root,
                                      const CompileConfig &config);
void full_simplify(IRNode *root,
                   const CompileConfig &config,
                   const FullSimplifyPass::Args &args);
1 change: 1 addition & 0 deletions taichi/program/compile_config.h
@@ -28,6 +28,7 @@ struct CompileConfig {
  bool lower_access;
  bool simplify_after_lower_access;
  bool move_loop_invariant_outside_if;
  bool cache_loop_invariant_global_vars{true};
  bool demote_dense_struct_fors;
  bool advanced_optimization;
  bool constant_folding;
2 changes: 2 additions & 0 deletions taichi/python/export_lang.cpp
@@ -166,6 +166,8 @@ void export_lang(py::module &m) {
      .def_readwrite("lower_access", &CompileConfig::lower_access)
      .def_readwrite("move_loop_invariant_outside_if",
                     &CompileConfig::move_loop_invariant_outside_if)
      .def_readwrite("cache_loop_invariant_global_vars",
                     &CompileConfig::cache_loop_invariant_global_vars)
      .def_readwrite("default_cpu_block_dim",
                     &CompileConfig::default_cpu_block_dim)
      .def_readwrite("cpu_block_dim_adaptive",
180 changes: 180 additions & 0 deletions taichi/transforms/cache_loop_invariant_global_vars.cpp
@@ -0,0 +1,180 @@
#include "taichi/transforms/loop_invariant_detector.h"
#include "taichi/ir/analysis.h"

namespace taichi::lang {

class CacheLoopInvariantGlobalVars : public LoopInvariantDetector {
public:
using LoopInvariantDetector::visit;

enum class CacheStatus {
None = 0,
Read = 1,
Write = 2,
ReadWrite = 3,
};

typedef std::unordered_map<Stmt *, std::pair<CacheStatus, AllocaStmt *>>
CacheMap;
std::stack<CacheMap> cached_maps;

DelayedIRModifier modifier;
std::unordered_map<const SNode *, GlobalPtrStmt *> loop_unique_ptr_;
std::unordered_map<int, ExternalPtrStmt *> loop_unique_arr_ptr_;

OffloadedStmt *current_offloaded;

explicit CacheLoopInvariantGlobalVars(const CompileConfig &config)
: LoopInvariantDetector(config) {
}

void visit(OffloadedStmt *stmt) override {
if (stmt->task_type == OffloadedTaskType::range_for ||
stmt->task_type == OffloadedTaskType::mesh_for ||
stmt->task_type == OffloadedTaskType::struct_for) {
auto uniquely_accessed_pointers =
irpass::analysis::gather_uniquely_accessed_pointers(stmt);
loop_unique_ptr_ = std::move(uniquely_accessed_pointers.first);
loop_unique_arr_ptr_ = std::move(uniquely_accessed_pointers.second);
}
current_offloaded = stmt;
// We don't need to visit TLS/BLS prologues/epilogues.
if (stmt->body) {
if (stmt->task_type == OffloadedStmt::TaskType::range_for ||
stmt->task_type == OffloadedTaskType::mesh_for ||
stmt->task_type == OffloadedStmt::TaskType::struct_for)
visit_loop(stmt->body.get());
else
stmt->body->accept(this);
}
current_offloaded = nullptr;
}

bool is_offload_unique(Stmt *stmt) {
if (current_offloaded->task_type == OffloadedTaskType::serial) {
return true;
}
if (auto global_ptr = stmt->cast<GlobalPtrStmt>()) {
auto snode = global_ptr->snode;
if (loop_unique_ptr_[snode] == nullptr ||
loop_unique_ptr_[snode]->indices.empty()) {
// not uniquely accessed
return false;
}
if (current_offloaded->mem_access_opt.has_flag(
snode, SNodeAccessFlag::block_local) ||
current_offloaded->mem_access_opt.has_flag(
snode, SNodeAccessFlag::mesh_local)) {
// BLS does not support write access yet so we keep atomic_adds.
return false;
}
return true;
} else if (stmt->is<ExternalPtrStmt>()) {
ExternalPtrStmt *dest_ptr = stmt->as<ExternalPtrStmt>();
if (dest_ptr->indices.empty()) {
return false;
}
ArgLoadStmt *arg_load_stmt = dest_ptr->base_ptr->as<ArgLoadStmt>();
int arg_id = arg_load_stmt->arg_id;
if (loop_unique_arr_ptr_[arg_id] == nullptr) {
// Not loop unique
return false;
}
return true;
// TODO: Is BLS / Mem Access Opt a thing for any_arr?
}
return false;
}

void visit_loop(Block *body) override {
cached_maps.emplace();
LoopInvariantDetector::visit_loop(body);
cached_maps.pop();
}

void add_writeback(AllocaStmt *alloca_stmt, Stmt *global_var) {
auto final_value = std::make_unique<LocalLoadStmt>(alloca_stmt);
auto global_store =
std::make_unique<GlobalStoreStmt>(global_var, final_value.get());
modifier.insert_after(current_loop_stmt(), std::move(global_store));
modifier.insert_after(current_loop_stmt(), std::move(final_value));
}

void set_init_value(AllocaStmt *alloca_stmt, Stmt *global_var) {
auto new_global_load = std::make_unique<GlobalLoadStmt>(global_var);
auto local_store =
std::make_unique<LocalStoreStmt>(alloca_stmt, new_global_load.get());
modifier.insert_before(current_loop_stmt(), std::move(new_global_load));
modifier.insert_before(current_loop_stmt(), std::move(local_store));
}

AllocaStmt *cache_global_to_local(Stmt *dest, CacheStatus status) {
if (auto &[cached_status, alloca_stmt] = cached_maps.top()[dest];
cached_status != CacheStatus::None) {
// The global variable has already been cached.
if (cached_status == CacheStatus::Read && status == CacheStatus::Write) {
add_writeback(alloca_stmt, dest);
cached_status = CacheStatus::ReadWrite;
}
return alloca_stmt;
}
auto alloca_unique =
std::make_unique<AllocaStmt>(dest->ret_type.ptr_removed());
auto alloca_stmt = alloca_unique.get();
modifier.insert_before(current_loop_stmt(), std::move(alloca_unique));
if (status == CacheStatus::Read) {
set_init_value(alloca_stmt, dest);
} else if (status == CacheStatus::Write) {
add_writeback(alloca_stmt, dest);
}
cached_maps.top()[dest] = {status, alloca_stmt};
return alloca_stmt;
}

void visit(GlobalLoadStmt *stmt) override {
if (is_offload_unique(stmt->src) &&
is_operand_loop_invariant(stmt->src, stmt->parent)) {
auto alloca_stmt = cache_global_to_local(stmt->src, CacheStatus::Read);
auto local_load = std::make_unique<LocalLoadStmt>(alloca_stmt);
stmt->replace_usages_with(local_load.get());
modifier.insert_before(stmt, std::move(local_load));
modifier.erase(stmt);
}
}

void visit(GlobalStoreStmt *stmt) override {
if (is_offload_unique(stmt->dest) &&
is_operand_loop_invariant(stmt->dest, stmt->parent)) {
auto alloca_stmt = cache_global_to_local(stmt->dest, CacheStatus::Write);
auto local_store =
std::make_unique<LocalStoreStmt>(alloca_stmt, stmt->val);
stmt->replace_usages_with(local_store.get());
modifier.insert_before(stmt, std::move(local_store));
modifier.erase(stmt);
}
}

static bool run(IRNode *node, const CompileConfig &config) {
bool modified = false;

while (true) {
CacheLoopInvariantGlobalVars eliminator(config);
node->accept(&eliminator);
if (eliminator.modifier.modify_ir())
modified = true;
else
break;
};

return modified;
}
};

namespace irpass {
bool cache_loop_invariant_global_vars(IRNode *root,
const CompileConfig &config) {
TI_AUTO_PROF;
return CacheLoopInvariantGlobalVars::run(root, config);
}
} // namespace irpass
} // namespace taichi::lang
7 changes: 7 additions & 0 deletions taichi/transforms/compile_to_offloads.cpp
@@ -186,6 +186,10 @@ void offload_to_executable(IRNode *ir,
  irpass::demote_atomics(ir, config);
  print("Atomics demoted I");
  irpass::analysis::verify(ir);
  if (config.cache_loop_invariant_global_vars) {
    irpass::cache_loop_invariant_global_vars(ir, config);
    print("Cache loop-invariant global vars");
  }

  if (config.demote_dense_struct_fors) {
    irpass::demote_dense_struct_fors(ir, config.packed);
@@ -246,6 +250,9 @@ void offload_to_executable(IRNode *ir,
  irpass::analysis::verify(ir);

  if (lower_global_access) {
    irpass::full_simplify(ir, config,
                          {false, /*autodiff_enabled*/ false, kernel->program});
    print("Simplified before lower access");
    irpass::lower_access(ir, config, {kernel->no_activate, true});
    print("Access lowered");
    irpass::analysis::verify(ir);