diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index c93fd9298b1d8..5913b64aceedc 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -127,14 +127,6 @@ void compile_to_offloads(IRNode *ir, print("Access flagged I"); irpass::analysis::verify(ir); - if (config.real_matrix_scalarize) { - irpass::scalarize(ir); - - // Remove redundant MatrixInitStmt inserted during scalarization - irpass::die(ir); - print("Scalarized"); - } - irpass::full_simplify(ir, config, {false, /*autodiff_enabled*/ false}); print("Simplified II"); irpass::analysis::verify(ir); @@ -143,6 +135,14 @@ void compile_to_offloads(IRNode *ir, print("Offloaded"); irpass::analysis::verify(ir); + if (config.real_matrix_scalarize) { + irpass::scalarize(ir); + + // Remove redundant MatrixInitStmt inserted during scalarization + irpass::die(ir); + print("Scalarized"); + } + // TODO: This pass may be redundant as cfg_optimization() is already called // in full_simplify(). if (config.opt_level > 0 && config.cfg_optimization) { diff --git a/taichi/transforms/offload.cpp b/taichi/transforms/offload.cpp index 33e4305ca37be..d9f904b443caf 100644 --- a/taichi/transforms/offload.cpp +++ b/taichi/transforms/offload.cpp @@ -533,24 +533,19 @@ class FixCrossOffloadReferences : public BasicStmtVisitor { local_to_global_offset_.at(stmt), ret_type); auto offloaded = stmt_to_offloaded_[stmt]; stmt_to_offloaded_[ptr] = offloaded; + + TypedConstant zero(stmt->ret_type.get_element_type()); + auto const_zero_stmt = replacement.push_back(zero); if (auto tensor_type = stmt->ret_type->cast()) { - TypedConstant zero(tensor_type->get_element_type()); - auto const_zero_stmt = replacement.push_back(zero); - stmt_to_offloaded_[const_zero_stmt] = offloaded; - for (int i = 0; i < tensor_type->get_num_elements(); ++i) { - auto const_offset_stmt = - replacement.push_back(TypedConstant(i)); - auto ptr_offset_stmt = - replacement.push_back(ptr, const_offset_stmt); - auto global_store_stmt = replacement.push_back( - ptr_offset_stmt, const_zero_stmt); - stmt_to_offloaded_[const_offset_stmt] = offloaded; - stmt_to_offloaded_[ptr_offset_stmt] = offloaded; - stmt_to_offloaded_[global_store_stmt] = offloaded; - } + std::vector zero_values(tensor_type->get_num_elements(), + const_zero_stmt); + auto zero_matrix_init_stmt = + replacement.push_back(zero_values); + zero_matrix_init_stmt->ret_type = stmt->ret_type.ptr_removed(); + auto global_store_stmt = + replacement.push_back(ptr, zero_matrix_init_stmt); + stmt_to_offloaded_[global_store_stmt] = offloaded; } else { - TypedConstant zero(stmt->ret_type); - auto const_zero_stmt = replacement.push_back(zero); auto global_store_stmt = replacement.push_back(ptr, const_zero_stmt); stmt_to_offloaded_[global_store_stmt] = offloaded; diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index 40fa40f8ba9f6..444c63097f770 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -881,7 +881,7 @@ class GatherScalarizableLocalPointers : public BasicStmtVisitor { } }; -class ScalarizeLocalPointers : public BasicStmtVisitor { +class ScalarizePointers : public BasicStmtVisitor { public: ImmediateIRModifier immediate_modifier_; DelayedIRModifier delayed_modifier_; @@ -890,7 +890,7 @@ class ScalarizeLocalPointers : public BasicStmtVisitor { // { original_alloca_stmt : [scalarized_alloca_stmt0, ...] } std::unordered_map> scalarized_local_tensor_map_; - explicit ScalarizeLocalPointers( + explicit ScalarizePointers( IRNode *node, const std::unordered_set &scalarizable_allocas) : immediate_modifier_(node), scalarizable_allocas_(scalarizable_allocas) { @@ -948,16 +948,16 @@ class ScalarizeLocalPointers : public BasicStmtVisitor { } } - /* - Before: - MatrixPtrStmt(TensorType<4 x i32>* alloca_stmt, int offset) - - After: - scalarized_alloca_stmt = - scalarized_local_tensor_map_[alloca_stmt][offset] - stmt->replace_all_usages_with(scalarized_alloca_stmt) - */ void visit(MatrixPtrStmt *stmt) override { + /* + Before: + MatrixPtrStmt(TensorType<4 x i32>* alloca_stmt, int offset) + + After: + scalarized_alloca_stmt = + scalarized_local_tensor_map_[alloca_stmt][offset] + stmt->replace_all_usages_with(scalarized_alloca_stmt) + */ if (stmt->origin->is() && scalarizable_allocas_.count(stmt->origin) == 1) { auto alloca_stmt = stmt->origin->cast(); @@ -979,6 +979,34 @@ class ScalarizeLocalPointers : public BasicStmtVisitor { immediate_modifier_.replace_usages_with(stmt, new_stmt); delayed_modifier_.erase(stmt); + return; + } + + /* + Before: + TensorType<4 x i32>* ptr = GlobalTempStmt(offset_0) + i32* ptr_1 = MatrixPtrStmt(ptr, offset_1) + + After: + i32* $1 = GlobalTempStmt(offset_0 + offset_1 * sizeof(i32)) + replace_all_usages_with(ptr_1, $1) + */ + if (stmt->origin->is() && + stmt->offset->is()) { + auto global_temp_stmt = stmt->origin->as(); + auto offset_0 = global_temp_stmt->offset; + auto offset_1 = stmt->offset->as()->val.val_int32(); + auto new_offset = + offset_0 + offset_1 * data_type_size(stmt->ret_type.ptr_removed()); + + auto new_global_temp_stmt = std::make_unique( + new_offset, stmt->ret_type.ptr_removed().get_element_type()); + new_global_temp_stmt->ret_type.set_is_pointer(true); + + stmt->replace_usages_with(new_global_temp_stmt.get()); + delayed_modifier_.insert_before(stmt, std::move(new_global_temp_stmt)); + delayed_modifier_.erase(stmt); + return; } } @@ -1027,6 +1055,14 @@ class ExtractLocalPointers : public BasicStmtVisitor { delayed_modifier_.modify_ir(); } + void visit(OffloadedStmt *stmt) override { + // Extract to OffloadStmt + Block *orig_top_level = top_level_; + top_level_ = stmt->body.get(); + stmt->all_blocks_accept(this); + top_level_ = orig_top_level; + } + void visit(MatrixPtrStmt *stmt) override { if (stmt->origin->is()) { auto alloca_stmt = stmt->origin->cast(); @@ -1118,7 +1154,7 @@ void scalarize(IRNode *root) { TI_AUTO_PROF; Scalarize scalarize_pass(root); auto scalarizable_allocas = GatherScalarizableLocalPointers::run(root); - ScalarizeLocalPointers scalarize_pointers_pass(root, scalarizable_allocas); + ScalarizePointers scalarize_pointers_pass(root, scalarizable_allocas); ExtractLocalPointers extract_pointers_pass(root); MergeExternalAndMatrixPtr::run(root); }