-
Notifications
You must be signed in to change notification settings - Fork 2.3k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[opt] Add pass eliminate_immutable_local_vars (#6926)
Issue: #6933 ### Brief Summary There are many redundant copies of local vars in the initial IR: ``` <[Tensor (3, 3) f32]> $128 = [$103, $106, $109, $112, $115, $118, $121, $124, $127] $129 : local store [$100 <- $128] <[Tensor (3, 3) f32]> $130 = alloca $131 = local load [$100] $132 : local store [$130 <- $131] <[Tensor (3, 3) f32]> $133 = alloca $134 = local load [$130] $135 : local store [$133 <- $134] <[Tensor (3, 3) f32]> $136 = alloca $137 = local load [$133] $138 : local store [$136 <- $137] // In fact, `$128` can be used wherever `$136` is loaded. ``` These can come from many places; one of the main sources is the pass-by-value convention of `ti.func`. The consequence is that the number of instructions is unnecessarily large, which significantly slows down compilation. My solution here is to identify and eliminate such redundant instructions in the first place so all later passes can take a much smaller number of instructions as input. These redundant local vars are essentially immutable ones - they are assigned only once and only loaded after the assignment. In this PR, I add an optimization pass `eliminate_immutable_local_vars` as the first pass. (P.S. The type check processes of `MatrixExpression` and `LocalLoadStmt` are fixed by the way to make the pass work properly.) Let's study the effects in two cases: #6933 and [voxel-rt2](https://github.com/taichi-dev/voxel-rt2/blob/main/example7.py). First, let's compare the number of instructions after `scalarization` pass (which happens immediately after the first pass). | Kernel | Before this PR | After this PR | Rate of decrease | | ------ | ------ | ------ | ------ | | `test` (#6933) | 45859 | 26452 | 42% | | `spatial_GRIS` (voxel-rt2) | 48519 | 17713 | 63% | Then, let's compare the total time of `compile()`. | Case | Before this PR | After this PR | Rate of decrease | | ------ | ------ | ------ | ------ | | #6933 | 20.622s | 8.550s | 59% | | voxel-rt2 | 27.676s | 9.495s | 66% | Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
- Loading branch information
1 parent
8501dcf
commit 19fce81
Showing
6 changed files
with
171 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
#include "taichi/ir/ir.h" | ||
#include "taichi/ir/statements.h" | ||
#include "taichi/ir/transforms.h" | ||
#include "taichi/ir/visitors.h" | ||
|
||
namespace taichi::lang { | ||
|
||
// The GatherImmutableLocalVars pass gathers all immutable local vars as input | ||
// to the EliminateImmutableLocalVars pass. An immutable local var is an alloca | ||
// which is stored only once (in the same block) and only loaded after that | ||
// store. | ||
class GatherImmutableLocalVars : public BasicStmtVisitor { | ||
private: | ||
using BasicStmtVisitor::visit; | ||
|
||
enum class AllocaStatus { kCreated = 0, kStoredOnce = 1, kInvalid = 2 }; | ||
std::unordered_map<Stmt *, AllocaStatus> alloca_status_; | ||
|
||
public: | ||
explicit GatherImmutableLocalVars() { | ||
invoke_default_visitor = true; | ||
} | ||
|
||
void visit(AllocaStmt *stmt) override { | ||
TI_ASSERT(alloca_status_.find(stmt) == alloca_status_.end()); | ||
alloca_status_[stmt] = AllocaStatus::kCreated; | ||
} | ||
|
||
void visit(LocalLoadStmt *stmt) override { | ||
if (stmt->src->is<AllocaStmt>()) { | ||
auto status_iter = alloca_status_.find(stmt->src); | ||
TI_ASSERT(status_iter != alloca_status_.end()); | ||
if (status_iter->second == AllocaStatus::kCreated) { | ||
status_iter->second = AllocaStatus::kInvalid; | ||
} | ||
} | ||
} | ||
|
||
void visit(LocalStoreStmt *stmt) override { | ||
if (stmt->dest->is<AllocaStmt>()) { | ||
auto status_iter = alloca_status_.find(stmt->dest); | ||
TI_ASSERT(status_iter != alloca_status_.end()); | ||
if (stmt->parent != stmt->dest->parent || | ||
status_iter->second == AllocaStatus::kStoredOnce || | ||
stmt->val->ret_type != stmt->dest->ret_type.ptr_removed()) { | ||
status_iter->second = AllocaStatus::kInvalid; | ||
} else if (status_iter->second == AllocaStatus::kCreated) { | ||
status_iter->second = AllocaStatus::kStoredOnce; | ||
} | ||
} | ||
} | ||
|
||
void default_visit(Stmt *stmt) { | ||
for (auto &op : stmt->get_operands()) { | ||
if (op != nullptr && op->is<AllocaStmt>()) { | ||
auto status_iter = alloca_status_.find(op); | ||
TI_ASSERT(status_iter != alloca_status_.end()); | ||
status_iter->second = AllocaStatus::kInvalid; | ||
} | ||
} | ||
} | ||
|
||
void visit(Stmt *stmt) override { | ||
default_visit(stmt); | ||
} | ||
|
||
void preprocess_container_stmt(Stmt *stmt) override { | ||
default_visit(stmt); | ||
} | ||
|
||
static std::unordered_set<Stmt *> run(IRNode *node) { | ||
GatherImmutableLocalVars pass; | ||
node->accept(&pass); | ||
std::unordered_set<Stmt *> result; | ||
for (auto &[k, v] : pass.alloca_status_) { | ||
if (v == AllocaStatus::kStoredOnce) { | ||
result.insert(k); | ||
} | ||
} | ||
return result; | ||
} | ||
}; | ||
|
||
namespace irpass::analysis { | ||
|
||
std::unordered_set<Stmt *> gather_immutable_local_vars(IRNode *root) { | ||
return GatherImmutableLocalVars::run(root); | ||
} | ||
|
||
} // namespace irpass::analysis | ||
|
||
} // namespace taichi::lang |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
#include "taichi/ir/ir.h" | ||
#include "taichi/ir/statements.h" | ||
#include "taichi/ir/analysis.h" | ||
#include "taichi/ir/visitors.h" | ||
#include "taichi/system/profiler.h" | ||
|
||
namespace taichi::lang { | ||
|
||
// The EliminateImmutableLocalVars pass eliminates all immutable local vars | ||
// calculated from the GatherImmutableLocalVars pass. An immutable local var | ||
// can be eliminated by forwarding the value of its only store to all loads | ||
// after that store. See https://github.com/taichi-dev/taichi/pull/6926 for the | ||
// background of this optimization. | ||
class EliminateImmutableLocalVars : public BasicStmtVisitor { | ||
private: | ||
using BasicStmtVisitor::visit; | ||
|
||
DelayedIRModifier modifier_; | ||
std::unordered_set<Stmt *> immutable_local_vars_; | ||
std::unordered_map<Stmt *, Stmt *> immutable_local_var_to_value_; | ||
|
||
public: | ||
explicit EliminateImmutableLocalVars( | ||
const std::unordered_set<Stmt *> &immutable_local_vars) | ||
: immutable_local_vars_(immutable_local_vars) { | ||
} | ||
|
||
void visit(AllocaStmt *stmt) override { | ||
if (immutable_local_vars_.find(stmt) != immutable_local_vars_.end()) { | ||
modifier_.erase(stmt); | ||
} | ||
} | ||
|
||
void visit(LocalLoadStmt *stmt) override { | ||
if (immutable_local_vars_.find(stmt->src) != immutable_local_vars_.end()) { | ||
stmt->replace_usages_with(immutable_local_var_to_value_[stmt->src]); | ||
modifier_.erase(stmt); | ||
} | ||
} | ||
|
||
void visit(LocalStoreStmt *stmt) override { | ||
if (immutable_local_vars_.find(stmt->dest) != immutable_local_vars_.end()) { | ||
TI_ASSERT(immutable_local_var_to_value_.find(stmt->dest) == | ||
immutable_local_var_to_value_.end()); | ||
immutable_local_var_to_value_[stmt->dest] = stmt->val; | ||
modifier_.erase(stmt); | ||
} | ||
} | ||
|
||
static void run(IRNode *node) { | ||
EliminateImmutableLocalVars pass( | ||
irpass::analysis::gather_immutable_local_vars(node)); | ||
node->accept(&pass); | ||
pass.modifier_.modify_ir(); | ||
} | ||
}; | ||
|
||
namespace irpass { | ||
|
||
void eliminate_immutable_local_vars(IRNode *root) { | ||
TI_AUTO_PROF; | ||
EliminateImmutableLocalVars::run(root); | ||
} | ||
|
||
} // namespace irpass | ||
|
||
} // namespace taichi::lang |