diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp
index 2c2ecfddee2d9..e8f96c339a52c 100644
--- a/taichi/transforms/scalarize.cpp
+++ b/taichi/transforms/scalarize.cpp
@@ -652,6 +652,96 @@ class ScalarizePointers : public BasicStmtVisitor {
   using BasicStmtVisitor::visit;
 };
 
+// The ExtractPointers pass removes redundant ConstStmts and MatrixPtrStmts
+// generated for any (AllocaStmt, integer) pair by extracting a unique copy
+// for any future usage.
+//
+// Example for redundant stmts:
+//   $0 = const 0
+//   $1 = const 1
+//   ...
+//   <[Tensor (3, 3) f32]> $47738 = alloca
+//   $47739 = const 0  [REDUNDANT]
+//   <*f32> $47740 = shift ptr [$47738 + $47739]
+//   $47741 : local store [$47740 <- $47713]
+//   $47742 = const 1  [REDUNDANT]
+//   <*f32> $47743 = shift ptr [$47738 + $47742]
+//   $47744 : local store [$47743 <- $47716]
+//   ...
+//   $47812 = const 1  [REDUNDANT]
+//   <*f32> $47813 = shift ptr [$47738 + $47812]  [REDUNDANT]
+//   $47814 = local load [$47813]
+class ExtractPointers : public BasicStmtVisitor {
+ public:
+  ImmediateIRModifier immediate_modifier_;
+  DelayedIRModifier delayed_modifier_;
+
+  // Maps an (AllocaStmt, integer) pair to the first MatrixPtrStmt
+  // representing it.
+  std::unordered_map<std::pair<Stmt *, int>,
+                     Stmt *,
+                     hashing::Hasher<std::pair<Stmt *, int>>>
+      first_matrix_ptr_;
+  // Maps an integer to the first ConstStmt representing it.
+  std::unordered_map<int, Stmt *> first_const_;
+  // Root block; the ConstStmts recorded in first_const_ are extracted to its
+  // front.
+  Block *top_level_;
+
+  // Runs the pass over |root|, which must be a Block. Usages are rewritten
+  // while visiting; the extractions, insertions and erasures requested on
+  // delayed_modifier_ are applied by modify_ir() after the traversal.
+  explicit ExtractPointers(IRNode *root) : immediate_modifier_(root) {
+    TI_ASSERT(root->is<Block>());
+    top_level_ = root->as<Block>();
+    root->accept(this);
+    delayed_modifier_.modify_ir();
+  }
+
+  // Handles a MatrixPtrStmt whose origin is an AllocaStmt and whose offset is
+  // a ConstStmt; any other MatrixPtrStmt is left untouched.
+  void visit(MatrixPtrStmt *stmt) override {
+    if (!stmt->origin->is<AllocaStmt>()) {
+      return;
+    }
+    auto alloca_stmt = stmt->origin->cast<AllocaStmt>();
+    auto tensor_type =
+        alloca_stmt->ret_type.ptr_removed()->cast<TensorType>();
+    TI_ASSERT(tensor_type != nullptr);
+    if (!stmt->offset->is<ConstStmt>()) {
+      return;
+    }
+    const int offset = stmt->offset->cast<ConstStmt>()->val.val_int32();
+
+    // The first ConstStmt seen for an offset value becomes the shared one and
+    // is extracted to the front of the top-level block.
+    auto const_it = first_const_.find(offset);
+    if (const_it == first_const_.end()) {
+      const_it = first_const_.emplace(offset, stmt->offset).first;
+      delayed_modifier_.extract_to_block_front(stmt->offset, top_level_);
+    }
+
+    // One MatrixPtrStmt per (alloca, offset) pair is created right after the
+    // alloca, built on the shared ConstStmt.
+    const std::pair<Stmt *, int> key(alloca_stmt, offset);
+    auto ptr_it = first_matrix_ptr_.find(key);
+    if (ptr_it == first_matrix_ptr_.end()) {
+      auto extracted =
+          std::make_unique<MatrixPtrStmt>(alloca_stmt, const_it->second);
+      ptr_it = first_matrix_ptr_.emplace(key, extracted.get()).first;
+      delayed_modifier_.insert_after(alloca_stmt, std::move(extracted));
+    }
+
+    // Usages of every matched MatrixPtrStmt, including the first one seen,
+    // are redirected to the extracted copy; the original is then erased.
+    immediate_modifier_.replace_usages_with(stmt, ptr_it->second);
+    delayed_modifier_.erase(stmt);
+  }
+
+ private:
+  using BasicStmtVisitor::visit;
+};
+
 namespace irpass {
 
 void scalarize(IRNode *root, const CompileConfig &config) {
@@ -659,6 +749,8 @@ void scalarize(IRNode *root, const CompileConfig &config) {
   Scalarize scalarize_pass(root);
   if (!config.dynamic_index) {
     ScalarizePointers scalarize_pointers_pass(root);
+  } else {
+    ExtractPointers extract_pointers_pass(root);
   }
 }
 