diff --git a/taichi/analysis/gather_immutable_local_vars.cpp b/taichi/analysis/gather_immutable_local_vars.cpp new file mode 100644 index 0000000000000..82234a87999af --- /dev/null +++ b/taichi/analysis/gather_immutable_local_vars.cpp @@ -0,0 +1,92 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/transforms.h" +#include "taichi/ir/visitors.h" + +namespace taichi::lang { + +// The GatherImmutableLocalVars pass gathers all immutable local vars as input +// to the EliminateImmutableLocalVars pass. An immutable local var is an alloca +// which is stored only once (in the same block) and only loaded after that +// store. +class GatherImmutableLocalVars : public BasicStmtVisitor { + private: + using BasicStmtVisitor::visit; + + enum class AllocaStatus { kCreated = 0, kStoredOnce = 1, kInvalid = 2 }; + std::unordered_map alloca_status_; + + public: + explicit GatherImmutableLocalVars() { + invoke_default_visitor = true; + } + + void visit(AllocaStmt *stmt) override { + TI_ASSERT(alloca_status_.find(stmt) == alloca_status_.end()); + alloca_status_[stmt] = AllocaStatus::kCreated; + } + + void visit(LocalLoadStmt *stmt) override { + if (stmt->src->is()) { + auto status_iter = alloca_status_.find(stmt->src); + TI_ASSERT(status_iter != alloca_status_.end()); + if (status_iter->second == AllocaStatus::kCreated) { + status_iter->second = AllocaStatus::kInvalid; + } + } + } + + void visit(LocalStoreStmt *stmt) override { + if (stmt->dest->is()) { + auto status_iter = alloca_status_.find(stmt->dest); + TI_ASSERT(status_iter != alloca_status_.end()); + if (stmt->parent != stmt->dest->parent || + status_iter->second == AllocaStatus::kStoredOnce || + stmt->val->ret_type != stmt->dest->ret_type.ptr_removed()) { + status_iter->second = AllocaStatus::kInvalid; + } else if (status_iter->second == AllocaStatus::kCreated) { + status_iter->second = AllocaStatus::kStoredOnce; + } + } + } + + void default_visit(Stmt *stmt) { + for (auto &op : stmt->get_operands()) { + if (op != nullptr && op->is()) { + auto status_iter = alloca_status_.find(op); + TI_ASSERT(status_iter != alloca_status_.end()); + status_iter->second = AllocaStatus::kInvalid; + } + } + } + + void visit(Stmt *stmt) override { + default_visit(stmt); + } + + void preprocess_container_stmt(Stmt *stmt) override { + default_visit(stmt); + } + + static std::unordered_set run(IRNode *node) { + GatherImmutableLocalVars pass; + node->accept(&pass); + std::unordered_set result; + for (auto &[k, v] : pass.alloca_status_) { + if (v == AllocaStatus::kStoredOnce) { + result.insert(k); + } + } + return result; + } +}; + +namespace irpass::analysis { + +std::unordered_set gather_immutable_local_vars(IRNode *root) { + return GatherImmutableLocalVars::run(root); +} + +} // namespace irpass::analysis + +} // namespace taichi::lang diff --git a/taichi/ir/analysis.h b/taichi/ir/analysis.h index dc2159d453316..69d68ee4a32d1 100644 --- a/taichi/ir/analysis.h +++ b/taichi/ir/analysis.h @@ -95,6 +95,7 @@ bool definitely_same_address(Stmt *var1, Stmt *var2); std::unordered_set detect_fors_with_break(IRNode *root); std::unordered_set detect_loops_with_continue(IRNode *root); +std::unordered_set gather_immutable_local_vars(IRNode *root); std::unordered_set gather_deactivations(IRNode *root); std::pair, std::unordered_set> gather_snode_read_writes(IRNode *root); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 630bef729602f..97ecf385604de 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -670,9 +670,12 @@ Stmt *make_tensor_access(Expression::FlattenContext *ctx, } void MatrixExpression::type_check(CompileConfig *config) { - // TODO: typecheck matrix for (auto &arg : elements) { TI_ASSERT_TYPE_CHECKED(arg); + if (arg->ret_type != dt.get_element_type()) { + arg = cast(arg, dt.get_element_type()); + arg->type_check(config); + } } ret_type = dt; } @@ -1569,8 +1572,9 @@ Stmt *flatten_global_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) { } Stmt *flatten_local_load(Stmt *ptr_stmt, Expression::FlattenContext *ctx) { - ctx->push_back(ptr_stmt); - return ctx->back_stmt(); + auto local_load = ctx->push_back(ptr_stmt); + local_load->ret_type = local_load->src->ret_type.ptr_removed(); + return local_load; } Stmt *flatten_rvalue(Expr ptr, Expression::FlattenContext *ctx) { diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index ed8571fe02825..20bbb0533c841 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -29,6 +29,7 @@ namespace irpass { void re_id(IRNode *root); void flag_access(IRNode *root); +void eliminate_immutable_local_vars(IRNode *root); void scalarize(IRNode *root, const CompileConfig &config); void lower_matrix_ptr(IRNode *root); bool die(IRNode *root); diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 847b376801dbe..cf3a31e13a61d 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -52,6 +52,9 @@ void compile_to_offloads(IRNode *ir, print("Lowered"); } + irpass::eliminate_immutable_local_vars(ir); + print("Immutable local vars eliminated"); + if (config.real_matrix_scalarize) { irpass::scalarize(ir, config); diff --git a/taichi/transforms/eliminate_immutable_local_vars.cpp b/taichi/transforms/eliminate_immutable_local_vars.cpp new file mode 100644 index 0000000000000..036e96459f574 --- /dev/null +++ b/taichi/transforms/eliminate_immutable_local_vars.cpp @@ -0,0 +1,67 @@ +#include "taichi/ir/ir.h" +#include "taichi/ir/statements.h" +#include "taichi/ir/analysis.h" +#include "taichi/ir/visitors.h" +#include "taichi/system/profiler.h" + +namespace taichi::lang { + +// The EliminateImmutableLocalVars pass eliminates all immutable local vars +// calculated from the GatherImmutableLocalVars pass. An immutable local var +// can be eliminated by forwarding the value of its only store to all loads +// after that store. See https://github.com/taichi-dev/taichi/pull/6926 for the +// background of this optimization. +class EliminateImmutableLocalVars : public BasicStmtVisitor { + private: + using BasicStmtVisitor::visit; + + DelayedIRModifier modifier_; + std::unordered_set immutable_local_vars_; + std::unordered_map immutable_local_var_to_value_; + + public: + explicit EliminateImmutableLocalVars( + const std::unordered_set &immutable_local_vars) + : immutable_local_vars_(immutable_local_vars) { + } + + void visit(AllocaStmt *stmt) override { + if (immutable_local_vars_.find(stmt) != immutable_local_vars_.end()) { + modifier_.erase(stmt); + } + } + + void visit(LocalLoadStmt *stmt) override { + if (immutable_local_vars_.find(stmt->src) != immutable_local_vars_.end()) { + stmt->replace_usages_with(immutable_local_var_to_value_[stmt->src]); + modifier_.erase(stmt); + } + } + + void visit(LocalStoreStmt *stmt) override { + if (immutable_local_vars_.find(stmt->dest) != immutable_local_vars_.end()) { + TI_ASSERT(immutable_local_var_to_value_.find(stmt->dest) == + immutable_local_var_to_value_.end()); + immutable_local_var_to_value_[stmt->dest] = stmt->val; + modifier_.erase(stmt); + } + } + + static void run(IRNode *node) { + EliminateImmutableLocalVars pass( + irpass::analysis::gather_immutable_local_vars(node)); + node->accept(&pass); + pass.modifier_.modify_ir(); + } +}; + +namespace irpass { + +void eliminate_immutable_local_vars(IRNode *root) { + TI_AUTO_PROF; + EliminateImmutableLocalVars::run(root); +} + +} // namespace irpass + +} // namespace taichi::lang