From 19fce81e0de83162d9d95bd67ed4c34a72bcce6f Mon Sep 17 00:00:00 2001
From: Yi Xu
Date: Wed, 21 Dec 2022 11:14:04 +0800
Subject: [PATCH] [opt] Add pass eliminate_immutable_local_vars (#6926)

Issue: #6933

### Brief Summary

There are many redundant copies of local vars in the initial IR:

```
<[Tensor (3, 3) f32]> $128 = [$103, $106, $109, $112, $115, $118, $121, $124, $127]
$129 : local store [$100 <- $128]
<[Tensor (3, 3) f32]> $130 = alloca
$131 = local load [$100]
$132 : local store [$130 <- $131]
<[Tensor (3, 3) f32]> $133 = alloca
$134 = local load [$130]
$135 : local store [$133 <- $134]
<[Tensor (3, 3) f32]> $136 = alloca
$137 = local load [$133]
$138 : local store [$136 <- $137]
// In fact, `$128` can be used wherever `$136` is loaded.
```

These can come from many places; one of the main sources is the pass-by-value convention of `ti.func`. The consequence is that the number of instructions is unnecessarily large, which significantly slows down compilation.

My solution here is to identify and eliminate such redundant instructions up front, so that all later passes take a much smaller number of instructions as input. These redundant local vars are essentially immutable: each one is assigned only once and is only loaded after that assignment. In this PR, I add an optimization pass `eliminate_immutable_local_vars` as the first pass. (P.S. The type checking of `MatrixExpression` and `LocalLoadStmt` is also fixed along the way so that the pass works properly.)

Let's study the effects in two cases: #6933 and [voxel-rt2](https://github.com/taichi-dev/voxel-rt2/blob/main/example7.py).

First, let's compare the number of instructions after the `scalarization` pass (which happens immediately after the first pass).

| Kernel | Before this PR | After this PR | Rate of decrease |
| ------ | ------ | ------ | ------ |
| `test` (#6933) | 45859 | 26452 | 42% |
| `spatial_GRIS` (voxel-rt2) | 48519 | 17713 | 63% |

Then, let's compare the total time of `compile()`.

| Case | Before this PR | After this PR | Rate of decrease | | ------ | ------ | ------ | ------ | | #6933 | 20.622s | 8.550s | 59% | | voxel-rt2 | 27.676s | 9.495s | 66% | Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../analysis/gather_immutable_local_vars.cpp | 92 +++++++++++++++++++ taichi/ir/analysis.h | 1 + taichi/ir/frontend_ir.cpp | 10 +- taichi/ir/transforms.h | 1 + taichi/transforms/compile_to_offloads.cpp | 3 + .../eliminate_immutable_local_vars.cpp | 67 ++++++++++++++ 6 files changed, 171 insertions(+), 3 deletions(-) create mode 100644 taichi/analysis/gather_immutable_local_vars.cpp create mode 100644 taichi/transforms/eliminate_immutable_local_vars.cpp diff --git a/taichi/analysis/gather_immutable_local_vars.cpp b/taichi/analysis/gather_immutable_local_vars.cpp new file mode 100644 index 0000000000000..82234a87999af --- /dev/null +++ b/taichi/analysis/gather_immutable_local_vars.cpp @@ -0,0 +1,92 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" + +namespace taichi::lang { + +// The GatherImmutableLocalVars pass gathers all immutable local vars as input +// to the EliminateImmutableLocalVars pass. An immutable local var is an alloca +// which is stored only once (in the same block) and only loaded after that +// store. 
+class GatherImmutableLocalVars : public BasicStmtVisitor { + private: + using BasicStmtVisitor::visit; + + enum class AllocaStatus { kCreated = 0, kStoredOnce = 1, kInvalid = 2 }; + std::unordered_map alloca_status_; + + public: + explicit GatherImmutableLocalVars() { + invoke_default_visitor = true; + } + + void visit(AllocaStmt *stmt) override { + TI_ASSERT(alloca_status_.find(stmt) == alloca_status_.end()); + alloca_status_[stmt] = AllocaStatus::kCreated; + } + + void visit(LocalLoadStmt *stmt) override { + if (stmt->src->is()) { + auto status_iter = alloca_status_.find(stmt->src); + TI_ASSERT(status_iter != alloca_status_.end()); + if (status_iter->second == AllocaStatus::kCreated) { + status_iter->second = AllocaStatus::kInvalid; + } + } + } + + void visit(LocalStoreStmt *stmt) override { + if (stmt->dest->is()) { + auto status_iter = alloca_status_.find(stmt->dest); + TI_ASSERT(status_iter != alloca_status_.end()); + if (stmt->parent != stmt->dest->parent || + status_iter->second == AllocaStatus::kStoredOnce || + stmt->val->ret_type != stmt->dest->ret_type.ptr_removed()) { + status_iter->second = AllocaStatus::kInvalid; + } else if (status_iter->second == AllocaStatus::kCreated) { + status_iter->second = AllocaStatus::kStoredOnce; + } + } + } + + void default_visit(Stmt *stmt) { + for (auto &op : stmt->get_operands()) { + if (op != nullptr && op->is()) { + auto status_iter = alloca_status_.find(op); + TI_ASSERT(status_iter != alloca_status_.end()); + status_iter->second = AllocaStatus::kInvalid; + } + } + } + + void visit(Stmt *stmt) override { + default_visit(stmt); + } + + void preprocess_container_stmt(Stmt *stmt) override { + default_visit(stmt); + } + + static std::unordered_set run(IRNode *node) { + GatherImmutableLocalVars pass; + node->accept(&pass); + std::unordered_set result; + for (auto &[k, v] : pass.alloca_status_) { + if (v == AllocaStatus::kStoredOnce) { + result.insert(k); + } + } + return result; + } +}; + +namespace irpass::analysis { + 
+std::unordered_set gather_immutable_local_vars(IRNode *root) { + return GatherImmutableLocalVars::run(root); +} + +} // namespace irpass::analysis + +} // namespace taichi::lang diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index dc2159d453316..69d68ee4a32d1 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -95,6 +95,7 @@ bool definitely_same_address(Stmt *var1, Stmt *var2); std::unordered_set detect_fors_with_break(IRNode *root); std::unordered_set detect_loops_with_continue(IRNode *root); +std::unordered_set gather_immutable_local_vars(IRNode *root); std::unordered_set gather_deactivations(IRNode *root); std::pair, std::unordered_set> gather_snode_read_writes(IRNode *root); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 630bef729602f..97ecf385604de 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -670,9 +670,12 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, } void MatrixExpression::type_check(CompileConfig *config) { - // TODO: typecheck matrix for (auto &arg : elements) { TI_ASSERT_TYPE_CHECKED(arg); + if (arg->ret_type != dt.get_element_type()) { + arg = cast(arg, dt.get_element_type()); + arg->type_check(config); + } } ret_type = dt; } @@ -1569,8 +1572,9 @@ Stmt *flatten_global_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) { } Stmt *flatten_local_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) { - ctx->push_back(ptr_stmt); - return ctx->back_stmt(); + auto local_load = ctx->push_back(ptr_stmt); + local_load->ret_type = local_load->src->ret_type.ptr_removed(); + return local_load; } Stmt *flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) { diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index ed8571fe02825..20bbb0533c841 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -29,6 +29,7 @@ namespace irpass { void re_id(IRNode *root); void flag_access(IRNode *root); +void eliminate_immutable_local_vars(IRNode *root); void 
scalarize(IRNode *root, const CompileConfig &config); void lower_matrix_ptr(IRNode *root); bool die(IRNode *root); diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 847b376801dbe..cf3a31e13a61d 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -52,6 +52,9 @@ void compile_to_offloads(IRNode *ir, print("Lowered"); } + irpass::eliminate_immutable_local_vars(ir); + print("Immutable local vars eliminated"); + if (config.real_matrix_scalarize) { irpass::scalarize(ir, config); diff --git a/taichi/transforms/eliminate_immutable_local_vars.cpp b/taichi/transforms/eliminate_immutable_local_vars.cpp new file mode 100644 index 0000000000000..036e96459f574 --- /dev/null +++ b/taichi/transforms/eliminate_immutable_local_vars.cpp @@ -0,0 +1,67 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" +#include "taichi/system/profiler.h" + +namespace taichi::lang { + +// The EliminateImmutableLocalVars pass eliminates all immutable local vars +// calculated from the GatherImmutableLocalVars pass. An immutable local var +// can be eliminated by forwarding the value of its only store to all loads +// after that store. See https://github.com/taichi-dev/taichi/pull/6926 for the +// background of this optimization. 
+class EliminateImmutableLocalVars : public BasicStmtVisitor { + private: + using BasicStmtVisitor::visit; + + DelayedIRModifier modifier_; + std::unordered_set immutable_local_vars_; + std::unordered_map immutable_local_var_to_value_; + + public: + explicit EliminateImmutableLocalVars( + const std::unordered_set &immutable_local_vars) + : immutable_local_vars_(immutable_local_vars) { + } + + void visit(AllocaStmt *stmt) override { + if (immutable_local_vars_.find(stmt) != immutable_local_vars_.end()) { + modifier_.erase(stmt); + } + } + + void visit(LocalLoadStmt *stmt) override { + if (immutable_local_vars_.find(stmt->src) != immutable_local_vars_.end()) { + stmt->replace_usages_with(immutable_local_var_to_value_[stmt->src]); + modifier_.erase(stmt); + } + } + + void visit(LocalStoreStmt *stmt) override { + if (immutable_local_vars_.find(stmt->dest) != immutable_local_vars_.end()) { + TI_ASSERT(immutable_local_var_to_value_.find(stmt->dest) == + immutable_local_var_to_value_.end()); + immutable_local_var_to_value_[stmt->dest] = stmt->val; + modifier_.erase(stmt); + } + } + + static void run(IRNode *node) { + EliminateImmutableLocalVars pass( + irpass::analysis::gather_immutable_local_vars(node)); + node->accept(&pass); + pass.modifier_.modify_ir(); + } +}; + +namespace irpass { + +void eliminate_immutable_local_vars(IRNode *root) { + TI_AUTO_PROF; + EliminateImmutableLocalVars::run(root); +} + +} // namespace irpass + +} // namespace taichi::lang