From 810c17d8c4405b79d21302e14e9de7797c4c448b Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 29 May 2023 10:37:36 +0800 Subject: [PATCH 01/11] [Lang] Migrate irpass::scalarize() after irpass::make_block_local() --- taichi/transforms/compile_to_offloads.cpp | 10 ++++---- taichi/transforms/make_block_local.cpp | 31 +++++++++++++++++++++++ 2 files changed, 36 insertions(+), 5 deletions(-) diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index f231167e952a3..5a95b78652ff6 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -226,6 +226,11 @@ void offload_to_executable(IRNode *ir, } } + if (make_block_local) { + irpass::make_block_local(ir, config, {kernel->get_name()}); + print("Make block local"); + } + if (config.real_matrix_scalarize) { irpass::scalarize(ir); @@ -234,11 +239,6 @@ void offload_to_executable(IRNode *ir, print("Scalarized"); } - if (make_block_local) { - irpass::make_block_local(ir, config, {kernel->get_name()}); - print("Make block local"); - } - if (is_extension_supported(config.arch, Extension::mesh)) { irpass::demote_mesh_statements(ir, config, {kernel->get_name()}); print("Demote mesh statements"); diff --git a/taichi/transforms/make_block_local.cpp b/taichi/transforms/make_block_local.cpp index 314b1b418cbf5..c569567cb1a41 100644 --- a/taichi/transforms/make_block_local.cpp +++ b/taichi/transforms/make_block_local.cpp @@ -15,6 +15,37 @@ void make_block_local_offload(OffloadedStmt *offload, if (offload->task_type != OffloadedStmt::TaskType::struct_for) return; + bool is_bls_applicable = + offload->mem_access_opt.get_snodes_with_flag(SNodeAccessFlag::block_local) + .size() > 0; + if (!is_bls_applicable) { + return; + } + + /* + [TensorType TODO #2] + In general, BLS is trying to analyze and replace load/store of + loop-specific GlobalPtrStmt(..., index) with load/store of a cross-loop + BlockLocalPtrStmt. 
This requires heavy analysis upon dependencies between + index of GlobalPtrStmt and the loop index. + + In case where GlobalPtrStmt's index being TensorType and stored in an + AllocaStmt, the analysis will fail due to the complexity of address + aliasing. Therefore we apply scalarize here to leverage this analysis + + [Example] + $1 = loop $0 index 0 + <[Tensor (1) i32]> $3 = [$1] + ... + <[Tensor (1) i32]> $12 = alloca + <[Tensor (1) i32]> $13 : local store [$12 <- $3] + <*i32> $14 = shift ptr [$12 + $4] + $15 = local load [$14] + <*i32> $16 = global ptr [S5place], index [$15] activate=true + */ + irpass::scalarize(offload); + irpass::full_simplify(offload, config, {false, /*autodiff_enabled*/ false}); + bool debug = config.debug; auto pads = irpass::initialize_scratch_pad(offload); From f4a45c08723bdb345f0f5703939539b866d5e683 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 29 May 2023 11:15:28 +0800 Subject: [PATCH 02/11] bug fix --- taichi/ir/transforms.h | 2 +- taichi/transforms/compile_to_offloads.cpp | 20 +++++----- taichi/transforms/make_block_local.cpp | 25 ++++++++++++- taichi/transforms/scalarize.cpp | 45 +++++++++++++++-------- 4 files changed, 64 insertions(+), 28 deletions(-) diff --git a/taichi/ir/transforms.h b/taichi/ir/transforms.h index 6bd06cb6df9ab..95b4d29cc5764 100644 --- a/taichi/ir/transforms.h +++ b/taichi/ir/transforms.h @@ -30,7 +30,7 @@ namespace irpass { void re_id(IRNode *root); void flag_access(IRNode *root); void eliminate_immutable_local_vars(IRNode *root); -void scalarize(IRNode *root); +bool scalarize(IRNode *root); void vectorize_half2(IRNode *root); void lower_matrix_ptr(IRNode *root); bool die(IRNode *root); diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 5a95b78652ff6..83acbcd3aa9c3 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -232,11 +232,11 @@ void offload_to_executable(IRNode *ir, } if 
(config.real_matrix_scalarize) { - irpass::scalarize(ir); - - // Remove redundant MatrixInitStmt inserted during scalarization - irpass::full_simplify(ir, config, {false, /*autodiff_enabled*/ false}); - print("Scalarized"); + if (irpass::scalarize(ir)) { + // Remove redundant MatrixInitStmt inserted during scalarization + irpass::full_simplify(ir, config, {false, /*autodiff_enabled*/ false}); + print("Scalarized"); + } } if (is_extension_supported(config.arch, Extension::mesh)) { @@ -356,11 +356,11 @@ void compile_function(IRNode *ir, } if (config.real_matrix_scalarize) { - irpass::scalarize(ir); - - // Remove redundant MatrixInitStmt inserted during scalarization - irpass::die(ir); - print("Scalarized"); + if (irpass::scalarize(ir)) { + // Remove redundant MatrixInitStmt inserted during scalarization + irpass::die(ir); + print("Scalarized"); + } } irpass::lower_access(ir, config, {{}, true}); diff --git a/taichi/transforms/make_block_local.cpp b/taichi/transforms/make_block_local.cpp index c569567cb1a41..f43a790ed7bd4 100644 --- a/taichi/transforms/make_block_local.cpp +++ b/taichi/transforms/make_block_local.cpp @@ -9,6 +9,24 @@ namespace taichi::lang { namespace { +std::function +make_pass_printer(bool verbose, const std::string &kernel_name, IRNode *ir) { + if (!verbose) { + return [](const std::string &) {}; + } + return [ir, kernel_name](const std::string &pass) { + TI_INFO("[{}] {}:", kernel_name, pass); + std::cout << std::flush; + irpass::re_id(ir); + irpass::print(ir); + std::cout << std::flush; + }; +} + +} // namespace + +namespace { + void make_block_local_offload(OffloadedStmt *offload, const CompileConfig &config, const std::string &kernel_name) { @@ -22,6 +40,8 @@ void make_block_local_offload(OffloadedStmt *offload, return; } + auto print = make_pass_printer(true, "asdasdasd", offload); + /* [TensorType TODO #2] In general, BLS is trying to analyze and replace load/store of @@ -43,8 +63,9 @@ void make_block_local_offload(OffloadedStmt *offload, $15 
= local load [$14] <*i32> $16 = global ptr [S5place], index [$15] activate=true */ - irpass::scalarize(offload); - irpass::full_simplify(offload, config, {false, /*autodiff_enabled*/ false}); + if (irpass::scalarize(offload)) { + irpass::full_simplify(offload, config, {false, /*autodiff_enabled*/ false}); + } bool debug = config.debug; diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index f1526ff14130b..d4719011e4f1f 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -18,9 +18,6 @@ class Scalarize : public BasicStmtVisitor { DelayedIRModifier delayed_modifier_; explicit Scalarize(IRNode *node) : immediate_modifier_(node) { - node->accept(this); - - delayed_modifier_.modify_ir(); } /* @@ -841,6 +838,12 @@ class Scalarize : public BasicStmtVisitor { } } + static bool run(IRNode *node) { + Scalarize pass(node); + node->accept(&pass); + return pass.delayed_modifier_.modify_ir(); + } + private: using BasicStmtVisitor::visit; std::unordered_map> scalarized_ad_stack_map_; @@ -898,9 +901,6 @@ class ScalarizePointers : public BasicStmtVisitor { IRNode *node, const std::unordered_set &scalarizable_allocas) : immediate_modifier_(node), scalarizable_allocas_(scalarizable_allocas) { - node->accept(this); - - delayed_modifier_.modify_ir(); } /* @@ -1041,6 +1041,13 @@ class ScalarizePointers : public BasicStmtVisitor { } } + static bool run(IRNode *node, + const std::unordered_set &scalarizable_allocas) { + ScalarizePointers pass(node, scalarizable_allocas); + node->accept(&pass); + return pass.delayed_modifier_.modify_ir(); + } + private: using BasicStmtVisitor::visit; }; @@ -1086,8 +1093,6 @@ class ExtractLocalPointers : public BasicStmtVisitor { TI_ASSERT(root->is()); top_level_ = root->as(); } - root->accept(this); - delayed_modifier_.modify_ir(); } void visit(OffloadedStmt *stmt) override { @@ -1124,6 +1129,12 @@ class ExtractLocalPointers : public BasicStmtVisitor { } } + static bool run(IRNode *node) { + 
ExtractLocalPointers pass(node); + node->accept(&pass); + return pass.delayed_modifier_.modify_ir(); + } + private: using BasicStmtVisitor::visit; }; @@ -1172,22 +1183,26 @@ class MergeExternalAndMatrixPtr : public BasicStmtVisitor { } } - static void run(IRNode *node) { + static bool run(IRNode *node) { MergeExternalAndMatrixPtr pass; node->accept(&pass); - pass.modifier_.modify_ir(); + return pass.modifier_.modify_ir(); } }; namespace irpass { -void scalarize(IRNode *root) { +bool scalarize(IRNode *root) { TI_AUTO_PROF; - Scalarize scalarize_pass(root); + bool modified = false; + + modified = Scalarize::run(root); auto scalarizable_allocas = GatherScalarizableLocalPointers::run(root); - ScalarizePointers scalarize_pointers_pass(root, scalarizable_allocas); - ExtractLocalPointers extract_pointers_pass(root); - MergeExternalAndMatrixPtr::run(root); + modified = ScalarizePointers::run(root, scalarizable_allocas); + modified = ExtractLocalPointers::run(root); + modified = MergeExternalAndMatrixPtr::run(root); + + return modified; } } // namespace irpass From d7541410f2247da2225ef06c356fdc72df6e14c9 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 29 May 2023 11:17:35 +0800 Subject: [PATCH 03/11] [Lang] Migrate irpass::scalarize() after irpass::lower_access() --- taichi/transforms/compile_to_offloads.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 83acbcd3aa9c3..48b6f086a0f78 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -231,14 +231,6 @@ void offload_to_executable(IRNode *ir, print("Make block local"); } - if (config.real_matrix_scalarize) { - if (irpass::scalarize(ir)) { - // Remove redundant MatrixInitStmt inserted during scalarization - irpass::full_simplify(ir, config, {false, /*autodiff_enabled*/ false}); - print("Scalarized"); - } - } - if 
(is_extension_supported(config.arch, Extension::mesh)) { irpass::demote_mesh_statements(ir, config, {kernel->get_name()}); print("Demote mesh statements"); @@ -276,6 +268,14 @@ void offload_to_executable(IRNode *ir, irpass::analysis::verify(ir); } + if (config.real_matrix_scalarize) { + if (irpass::scalarize(ir)) { + // Remove redundant MatrixInitStmt inserted during scalarization + irpass::full_simplify(ir, config, {false, /*autodiff_enabled*/ false}); + print("Scalarized"); + } + } + irpass::demote_operations(ir, config); print("Operations demoted"); @@ -363,7 +363,7 @@ void compile_function(IRNode *ir, } } - irpass::lower_access(ir, config, {{}, true}); + ipass::lower_access(ir, config, {{}, true}); print("Access lowered"); irpass::analysis::verify(ir); From 660819192b4a0e3719939d6df31bf274e2dcba26 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 29 May 2023 11:27:36 +0800 Subject: [PATCH 04/11] bug fix --- taichi/transforms/compile_to_offloads.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 48b6f086a0f78..db8b7f75e87e4 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -363,7 +363,7 @@ void compile_function(IRNode *ir, } } - ipass::lower_access(ir, config, {{}, true}); + irpass::lower_access(ir, config, {{}, true}); print("Access lowered"); irpass::analysis::verify(ir); From 7d6a954836340a351cd41642161139d4fd166ba5 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 29 May 2023 11:53:47 +0800 Subject: [PATCH 05/11] bug fix --- taichi/transforms/make_block_local.cpp | 20 -------------------- taichi/transforms/scalarize.cpp | 8 ++++---- 2 files changed, 4 insertions(+), 24 deletions(-) diff --git a/taichi/transforms/make_block_local.cpp b/taichi/transforms/make_block_local.cpp index f43a790ed7bd4..c2e89d6b2805f 100644 --- a/taichi/transforms/make_block_local.cpp +++ 
b/taichi/transforms/make_block_local.cpp @@ -9,24 +9,6 @@ namespace taichi::lang { namespace { -std::function -make_pass_printer(bool verbose, const std::string &kernel_name, IRNode *ir) { - if (!verbose) { - return [](const std::string &) {}; - } - return [ir, kernel_name](const std::string &pass) { - TI_INFO("[{}] {}:", kernel_name, pass); - std::cout << std::flush; - irpass::re_id(ir); - irpass::print(ir); - std::cout << std::flush; - }; -} - -} // namespace - -namespace { - void make_block_local_offload(OffloadedStmt *offload, const CompileConfig &config, const std::string &kernel_name) { @@ -40,8 +22,6 @@ void make_block_local_offload(OffloadedStmt *offload, return; } - auto print = make_pass_printer(true, "asdasdasd", offload); - /* [TensorType TODO #2] In general, BLS is trying to analyze and replace load/store of diff --git a/taichi/transforms/scalarize.cpp b/taichi/transforms/scalarize.cpp index d4719011e4f1f..7e84c6bd40121 100644 --- a/taichi/transforms/scalarize.cpp +++ b/taichi/transforms/scalarize.cpp @@ -1196,11 +1196,11 @@ bool scalarize(IRNode *root) { TI_AUTO_PROF; bool modified = false; - modified = Scalarize::run(root); + modified |= Scalarize::run(root); auto scalarizable_allocas = GatherScalarizableLocalPointers::run(root); - modified = ScalarizePointers::run(root, scalarizable_allocas); - modified = ExtractLocalPointers::run(root); - modified = MergeExternalAndMatrixPtr::run(root); + modified |= ScalarizePointers::run(root, scalarizable_allocas); + modified |= ExtractLocalPointers::run(root); + modified |= MergeExternalAndMatrixPtr::run(root); return modified; } From b211db53be3cfafbf332ebae9cdc95c007432d30 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 29 May 2023 12:37:58 +0800 Subject: [PATCH 06/11] bug fix --- taichi/transforms/compile_to_offloads.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 
db8b7f75e87e4..a5e686cad9386 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -271,7 +271,8 @@ void offload_to_executable(IRNode *ir, if (config.real_matrix_scalarize) { if (irpass::scalarize(ir)) { // Remove redundant MatrixInitStmt inserted during scalarization - irpass::full_simplify(ir, config, {false, /*autodiff_enabled*/ false}); + irpass::full_simplify(ir, config, + {lower_global_access, /*autodiff_enabled*/ false}); print("Scalarized"); } } From e852b32f2220a39d6ca54ec0dc9d07f096b8978b Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 29 May 2023 13:23:44 +0800 Subject: [PATCH 07/11] [Lang] Migrate irpass::scalarize() after irpass::demote_operations() --- taichi/codegen/llvm/codegen_llvm.cpp | 4 +- taichi/transforms/compile_to_offloads.cpp | 6 +- taichi/transforms/demote_operations.cpp | 87 ++++++++--------------- 3 files changed, 34 insertions(+), 63 deletions(-) diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 4e717e24ba369..03fa92e111edd 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -611,7 +611,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { if (is_real(stmt->ret_type.get_element_type())) { llvm_val[stmt] = builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); - } else if (is_signed(stmt->ret_type)) { + } else if (is_signed(stmt->ret_type.get_element_type())) { llvm_val[stmt] = builder->CreateSDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { @@ -658,7 +658,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); #endif } else if (op == BinaryOpType::bit_sar) { - if (is_signed(stmt->lhs->element_type())) { + if (is_signed(stmt->lhs->ret_type.get_element_type())) { llvm_val[stmt] = builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { diff --git a/taichi/transforms/compile_to_offloads.cpp 
b/taichi/transforms/compile_to_offloads.cpp index a5e686cad9386..f1f8742bde5aa 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -268,6 +268,9 @@ void offload_to_executable(IRNode *ir, irpass::analysis::verify(ir); } + irpass::demote_operations(ir, config); + print("Operations demoted"); + if (config.real_matrix_scalarize) { if (irpass::scalarize(ir)) { // Remove redundant MatrixInitStmt inserted during scalarization @@ -277,9 +280,6 @@ void offload_to_executable(IRNode *ir, } } - irpass::demote_operations(ir, config); - print("Operations demoted"); - irpass::full_simplify(ir, config, {lower_global_access, /*autodiff_enabled*/ false}); print("Simplified IV"); diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp index 5846d6e24b7c7..6a32d3bf3883d 100644 --- a/taichi/transforms/demote_operations.cpp +++ b/taichi/transforms/demote_operations.cpp @@ -21,6 +21,18 @@ class DemoteOperations : public BasicStmtVisitor { Stmt *rhs) { auto ret = Stmt::make(BinaryOpType::div, lhs, rhs); auto zero = Stmt::make(TypedConstant(0)); + + if (lhs->ret_type->is()) { + int num_elements = lhs->ret_type->cast()->get_num_elements(); + std::vector values(num_elements, zero.get()); + + auto matrix_zero = Stmt::make(values); + matrix_zero->ret_type = lhs->ret_type; + + modifier.insert_before(stmt, std::move(zero)); + zero = std::move(matrix_zero); + } + auto lhs_ltz = Stmt::make(BinaryOpType::cmp_lt, lhs, zero.get()); auto rhs_ltz = @@ -39,6 +51,7 @@ class DemoteOperations : public BasicStmtVisitor { cond12.get(), cond3.get()); auto real_ret = Stmt::make(BinaryOpType::sub, ret.get(), cond.get()); + modifier.insert_before(stmt, std::move(ret)); modifier.insert_before(stmt, std::move(zero)); modifier.insert_before(stmt, std::move(lhs_ltz)); @@ -64,9 +77,11 @@ class DemoteOperations : public BasicStmtVisitor { void visit(BinaryOpStmt *stmt) override { auto lhs = stmt->lhs; auto rhs = stmt->rhs; + + 
auto lhs_prim_type = lhs->ret_type.get_element_type(); + auto rhs_prim_type = rhs->ret_type.get_element_type(); if (stmt->op_type == BinaryOpType::floordiv) { - if (is_integral(rhs->element_type()) && - is_integral(lhs->element_type())) { + if (is_integral(rhs_prim_type) && is_integral(lhs_prim_type)) { // @ti.func // def ifloordiv(a, b): // r = ti.raw_div(a, b) @@ -96,7 +111,7 @@ class DemoteOperations : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(real_ret)); modifier.erase(stmt); - } else if (is_real(rhs->element_type()) || is_real(lhs->element_type())) { + } else if (is_real(rhs_prim_type) || is_real(lhs_prim_type)) { // @ti.func // def ffloordiv(a, b): // r = ti.raw_div(a, b) @@ -106,48 +121,6 @@ class DemoteOperations : public BasicStmtVisitor { stmt->replace_usages_with(floor.get()); modifier.insert_before(stmt, std::move(floor)); modifier.erase(stmt); - } else if (lhs->ret_type->is() && - rhs->ret_type->is()) { - bool use_integral = is_integral(lhs->ret_type.get_element_type()) && - is_integral(rhs->ret_type.get_element_type()); - std::vector ret_stmts; - auto lhs_tensor_ty = lhs->ret_type->cast(); - auto rhs_tensor_ty = rhs->ret_type->cast(); - auto lhs_alloca = Stmt::make(lhs_tensor_ty); - auto rhs_alloca = Stmt::make(rhs_tensor_ty); - auto lhs_store = - Stmt::make(lhs_alloca.get(), stmt->lhs); - auto rhs_store = - Stmt::make(rhs_alloca.get(), stmt->rhs); - auto lhs_ptr = lhs_alloca.get(); - auto rhs_ptr = rhs_alloca.get(); - modifier.insert_before(stmt, std::move(lhs_alloca)); - modifier.insert_before(stmt, std::move(rhs_alloca)); - modifier.insert_before(stmt, std::move(lhs_store)); - modifier.insert_before(stmt, std::move(rhs_store)); - for (int i = 0; i < lhs_tensor_ty->get_num_elements(); i++) { - auto idx = Stmt::make(TypedConstant(i)); - auto lhs_i = Stmt::make(lhs_ptr, idx.get()); - auto rhs_i = Stmt::make(rhs_ptr, idx.get()); - auto lhs_load = Stmt::make(lhs_i.get()); - auto rhs_load = Stmt::make(rhs_i.get()); - auto cur_lhs = 
lhs_load.get(); - auto cur_rhs = rhs_load.get(); - modifier.insert_before(stmt, std::move(idx)); - modifier.insert_before(stmt, std::move(lhs_i)); - modifier.insert_before(stmt, std::move(rhs_i)); - modifier.insert_before(stmt, std::move(lhs_load)); - modifier.insert_before(stmt, std::move(rhs_load)); - auto ret_i = use_integral ? demote_ifloordiv(stmt, cur_lhs, cur_rhs) - : demote_ffloor(stmt, cur_lhs, cur_rhs); - ret_stmts.push_back(ret_i.get()); - modifier.insert_before(stmt, std::move(ret_i)); - } - auto new_matrix = Stmt::make(ret_stmts); - new_matrix->ret_type = stmt->ret_type; - stmt->replace_usages_with(new_matrix.get()); - modifier.insert_before(stmt, std::move(new_matrix)); - modifier.erase(stmt); } } else if (stmt->op_type == BinaryOpType::bit_shr) { // @ti.func @@ -156,17 +129,15 @@ class DemoteOperations : public BasicStmtVisitor { // shifted = ti.bit_sar(unsigned_a, b) // ret = ti.cast(shifted, ti.iXX) // return ret - TI_ASSERT(is_integral(lhs->element_type()) && - is_integral(rhs->element_type())); + TI_ASSERT(is_integral(lhs_prim_type) && is_integral(rhs_prim_type)); auto unsigned_cast = Stmt::make(UnaryOpType::cast_bits, lhs); - auto lhs_type = lhs->element_type(); unsigned_cast->as()->cast_type = - is_signed(lhs_type) ? to_unsigned(lhs_type) : lhs_type; + is_signed(lhs_prim_type) ? 
to_unsigned(lhs_prim_type) : lhs_prim_type; auto shift = Stmt::make(BinaryOpType::bit_sar, unsigned_cast.get(), rhs); auto signed_cast = Stmt::make(UnaryOpType::cast_bits, shift.get()); - signed_cast->as()->cast_type = lhs->element_type(); + signed_cast->as()->cast_type = lhs_prim_type; signed_cast->ret_type = stmt->ret_type; stmt->replace_usages_with(signed_cast.get()); modifier.insert_before(stmt, std::move(unsigned_cast)); @@ -174,7 +145,7 @@ class DemoteOperations : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(signed_cast)); modifier.erase(stmt); } else if (stmt->op_type == BinaryOpType::pow && - is_integral(rhs->element_type())) { + is_integral(rhs_prim_type)) { // @ti.func // def pow(lhs, rhs): // a = lhs @@ -189,14 +160,14 @@ class DemoteOperations : public BasicStmtVisitor { // result = 1 / result # for real lhs // return result IRBuilder builder; - auto one_lhs = builder.get_constant(lhs->element_type(), 1); - auto one_rhs = builder.get_constant(rhs->element_type(), 1); - auto zero_rhs = builder.get_constant(rhs->element_type(), 0); - auto a = builder.create_local_var(lhs->element_type()); + auto one_lhs = builder.get_constant(lhs_prim_type, 1); + auto one_rhs = builder.get_constant(rhs_prim_type, 1); + auto zero_rhs = builder.get_constant(rhs_prim_type, 0); + auto a = builder.create_local_var(lhs_prim_type); builder.create_local_store(a, lhs); - auto b = builder.create_local_var(rhs->element_type()); + auto b = builder.create_local_var(rhs_prim_type); builder.create_local_store(b, builder.create_abs(rhs)); - auto result = builder.create_local_var(lhs->element_type()); + auto result = builder.create_local_var(lhs_prim_type); builder.create_local_store(result, one_lhs); auto loop = builder.create_while_true(); { @@ -222,7 +193,7 @@ class DemoteOperations : public BasicStmtVisitor { auto new_b = builder.create_sar(current_b, one_rhs); builder.create_local_store(b, new_b); } - if (is_real(lhs->element_type())) { + if 
(is_real(lhs_prim_type)) { auto if_stmt = builder.create_if(builder.create_cmp_le(rhs, zero_rhs)); { auto _ = builder.get_if_guard(if_stmt, true); From 332931af6bb1b5807e69537bd8500055b985ad3a Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 29 May 2023 15:21:17 +0800 Subject: [PATCH 08/11] bug fix --- taichi/ir/ir_builder.h | 3 +- taichi/transforms/demote_operations.cpp | 180 +++++++++++++++++------- 2 files changed, 131 insertions(+), 52 deletions(-) diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h index 5646d4665393e..f78cd9f274034 100644 --- a/taichi/ir/ir_builder.h +++ b/taichi/ir/ir_builder.h @@ -2,6 +2,7 @@ #include "taichi/ir/ir.h" #include "taichi/ir/mesh.h" +#include "taichi/ir/statements.h" namespace taichi::lang { @@ -137,7 +138,7 @@ class IRBuilder { ConstStmt *get_float64(float64 value); template - ConstStmt *get_constant(DataType dt, const T &value) { + Stmt *get_constant(DataType dt, const T &value) { return insert(Stmt::make_typed(TypedConstant(dt, value))); } diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp index 6a32d3bf3883d..27471de5ef257 100644 --- a/taichi/transforms/demote_operations.cpp +++ b/taichi/transforms/demote_operations.cpp @@ -16,6 +16,117 @@ class DemoteOperations : public BasicStmtVisitor { DemoteOperations() { } + Stmt *transform_pow_op_impl(IRBuilder &builder, Stmt *lhs, Stmt *rhs) { + auto lhs_type = lhs->ret_type.get_element_type(); + auto rhs_type = rhs->ret_type.get_element_type(); + + auto one_lhs = builder.get_constant(lhs_type, 1); + auto one_rhs = builder.get_constant(rhs_type, 1); + auto zero_rhs = builder.get_constant(rhs_type, 0); + auto a = builder.create_local_var(lhs_type); + builder.create_local_store(a, lhs); + auto b = builder.create_local_var(rhs_type); + builder.create_local_store(b, builder.create_abs(rhs)); + auto result = builder.create_local_var(lhs_type); + builder.create_local_store(result, one_lhs); + auto loop = 
builder.create_while_true(); + { + auto loop_guard = builder.get_loop_guard(loop); + auto current_a = builder.create_local_load(a); + auto current_b = builder.create_local_load(b); + auto if_stmt = + builder.create_if(builder.create_cmp_le(current_b, zero_rhs)); + { + auto _ = builder.get_if_guard(if_stmt, true); + builder.create_break(); + } + auto bit_and = builder.create_and(current_b, one_rhs); + if_stmt = builder.create_if(builder.create_cmp_ne(bit_and, zero_rhs)); + { + auto _ = builder.get_if_guard(if_stmt, true); + auto current_result = builder.create_local_load(result); + auto new_result = builder.create_mul(current_result, current_a); + builder.create_local_store(result, new_result); + } + auto new_a = builder.create_mul(current_a, current_a); + builder.create_local_store(a, new_a); + auto new_b = builder.create_sar(current_b, one_rhs); + builder.create_local_store(b, new_b); + } + if (is_real(lhs_type)) { + auto if_stmt = builder.create_if(builder.create_cmp_le(rhs, zero_rhs)); + { + auto _ = builder.get_if_guard(if_stmt, true); + auto current_result = builder.create_local_load(result); + auto new_result = builder.create_div(one_lhs, current_result); + builder.create_local_store(result, new_result); + } + } + auto final_result = builder.create_local_load(result); + return final_result; + } + + void transform_pow_op_scalar(BinaryOpStmt *stmt, Stmt *lhs, Stmt *rhs) { + IRBuilder builder; + + auto final_result = transform_pow_op_impl(builder, lhs, rhs); + + stmt->replace_usages_with(final_result); + modifier.insert_before( + stmt, VecStatement(std::move(builder.extract_ir()->statements))); + modifier.erase(stmt); + } + + void transform_pow_op_tensor(BinaryOpStmt *stmt, Stmt *lhs, Stmt *rhs) { + std::vector ret_stmts; + auto lhs_tensor_ty = lhs->ret_type->cast(); + auto rhs_tensor_ty = rhs->ret_type->cast(); + + auto lhs_prim_type = lhs_tensor_ty->get_element_type(); + auto rhs_prim_type = rhs_tensor_ty->get_element_type(); + + auto lhs_alloca = 
Stmt::make(lhs_tensor_ty); + auto rhs_alloca = Stmt::make(rhs_tensor_ty); + auto lhs_store = Stmt::make(lhs_alloca.get(), stmt->lhs); + auto rhs_store = Stmt::make(rhs_alloca.get(), stmt->rhs); + auto lhs_ptr = lhs_alloca.get(); + auto rhs_ptr = rhs_alloca.get(); + modifier.insert_before(stmt, std::move(lhs_alloca)); + modifier.insert_before(stmt, std::move(rhs_alloca)); + modifier.insert_before(stmt, std::move(lhs_store)); + modifier.insert_before(stmt, std::move(rhs_store)); + for (int i = 0; i < lhs_tensor_ty->get_num_elements(); i++) { + auto idx = Stmt::make(TypedConstant(i)); + auto lhs_i = Stmt::make(lhs_ptr, idx.get()); + auto rhs_i = Stmt::make(rhs_ptr, idx.get()); + auto lhs_load = Stmt::make(lhs_i.get()); + lhs_load->ret_type = lhs_prim_type; + + auto rhs_load = Stmt::make(rhs_i.get()); + rhs_load->ret_type = rhs_prim_type; + + auto cur_lhs = lhs_load.get(); + auto cur_rhs = rhs_load.get(); + modifier.insert_before(stmt, std::move(idx)); + modifier.insert_before(stmt, std::move(lhs_i)); + modifier.insert_before(stmt, std::move(rhs_i)); + modifier.insert_before(stmt, std::move(lhs_load)); + modifier.insert_before(stmt, std::move(rhs_load)); + + IRBuilder builder; + auto cur_result = transform_pow_op_impl(builder, cur_lhs, cur_rhs); + + modifier.insert_before( + stmt, VecStatement(std::move(builder.extract_ir()->statements))); + ret_stmts.push_back(cur_result); + } + auto new_matrix = Stmt::make(ret_stmts); + new_matrix->ret_type = stmt->ret_type; + stmt->replace_usages_with(new_matrix.get()); + modifier.insert_before(stmt, std::move(new_matrix)); + modifier.erase(stmt); + } + std::unique_ptr demote_ifloordiv(BinaryOpStmt *stmt, Stmt *lhs, Stmt *rhs) { @@ -78,8 +189,11 @@ class DemoteOperations : public BasicStmtVisitor { auto lhs = stmt->lhs; auto rhs = stmt->rhs; - auto lhs_prim_type = lhs->ret_type.get_element_type(); - auto rhs_prim_type = rhs->ret_type.get_element_type(); + auto lhs_type = lhs->ret_type; + auto rhs_type = rhs->ret_type; + + auto 
lhs_prim_type = lhs_type.get_element_type(); + auto rhs_prim_type = rhs_type.get_element_type(); if (stmt->op_type == BinaryOpType::floordiv) { if (is_integral(rhs_prim_type) && is_integral(lhs_prim_type)) { // @ti.func @@ -144,8 +258,11 @@ class DemoteOperations : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(shift)); modifier.insert_before(stmt, std::move(signed_cast)); modifier.erase(stmt); - } else if (stmt->op_type == BinaryOpType::pow && - is_integral(rhs_prim_type)) { + } else if (stmt->op_type == BinaryOpType::pow) { + // There's no direct support for Power operation in LLVM / SpirV IR. + // We need to manually transform it to make it work. + + // [Transform] // @ti.func // def pow(lhs, rhs): // a = lhs @@ -159,54 +276,15 @@ class DemoteOperations : public BasicStmtVisitor { // if rhs < 0: # for real lhs // result = 1 / result # for real lhs // return result - IRBuilder builder; - auto one_lhs = builder.get_constant(lhs_prim_type, 1); - auto one_rhs = builder.get_constant(rhs_prim_type, 1); - auto zero_rhs = builder.get_constant(rhs_prim_type, 0); - auto a = builder.create_local_var(lhs_prim_type); - builder.create_local_store(a, lhs); - auto b = builder.create_local_var(rhs_prim_type); - builder.create_local_store(b, builder.create_abs(rhs)); - auto result = builder.create_local_var(lhs_prim_type); - builder.create_local_store(result, one_lhs); - auto loop = builder.create_while_true(); - { - auto loop_guard = builder.get_loop_guard(loop); - auto current_a = builder.create_local_load(a); - auto current_b = builder.create_local_load(b); - auto if_stmt = - builder.create_if(builder.create_cmp_le(current_b, zero_rhs)); - { - auto _ = builder.get_if_guard(if_stmt, true); - builder.create_break(); - } - auto bit_and = builder.create_and(current_b, one_rhs); - if_stmt = builder.create_if(builder.create_cmp_ne(bit_and, zero_rhs)); - { - auto _ = builder.get_if_guard(if_stmt, true); - auto current_result = builder.create_local_load(result); - 
auto new_result = builder.create_mul(current_result, current_a); - builder.create_local_store(result, new_result); - } - auto new_a = builder.create_mul(current_a, current_a); - builder.create_local_store(a, new_a); - auto new_b = builder.create_sar(current_b, one_rhs); - builder.create_local_store(b, new_b); + if (is_integral(rhs_type)) { + transform_pow_op_scalar(stmt, lhs, rhs); + } else if (rhs_type->is() && lhs_type->is() && + is_integral(rhs_type.get_element_type())) { + // For Power with TensorType'd operands, since IfStmt and WhileStmt + // isn't compatible with TensorType'd condition statement, + // we have to perform immediate scalarization with help from AllocaStmt. + transform_pow_op_tensor(stmt, lhs, rhs); } - if (is_real(lhs_prim_type)) { - auto if_stmt = builder.create_if(builder.create_cmp_le(rhs, zero_rhs)); - { - auto _ = builder.get_if_guard(if_stmt, true); - auto current_result = builder.create_local_load(result); - auto new_result = builder.create_div(one_lhs, current_result); - builder.create_local_store(result, new_result); - } - } - auto final_result = builder.create_local_load(result); - stmt->replace_usages_with(final_result); - modifier.insert_before( - stmt, VecStatement(std::move(builder.extract_ir()->statements))); - modifier.erase(stmt); } } From 99130d01066c01f1e90597f25558aefaf965bf3e Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 29 May 2023 15:22:16 +0800 Subject: [PATCH 09/11] code adjustment --- taichi/transforms/compile_to_offloads.cpp | 8 -------- 1 file changed, 8 deletions(-) diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index 6c3a6acda3030..a5e686cad9386 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -231,14 +231,6 @@ void offload_to_executable(IRNode *ir, print("Make block local"); } - if (config.real_matrix_scalarize) { - if (irpass::scalarize(ir)) { - // Remove redundant MatrixInitStmt inserted during 
scalarization - irpass::full_simplify(ir, config, {false, /*autodiff_enabled*/ false}); - print("Scalarized"); - } - } - if (is_extension_supported(config.arch, Extension::mesh)) { irpass::demote_mesh_statements(ir, config, {kernel->get_name()}); print("Demote mesh statements"); From 5e824fde721781dfdaef93d68c23ebaf8c576232 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 29 May 2023 15:41:04 +0800 Subject: [PATCH 10/11] [Lang] Migrate irpass::scalarize() after optimize_bit_struct_stores & determine_ad_stack_size --- taichi/transforms/compile_to_offloads.cpp | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index f1f8742bde5aa..b7c83b6790d83 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -271,15 +271,6 @@ void offload_to_executable(IRNode *ir, irpass::demote_operations(ir, config); print("Operations demoted"); - if (config.real_matrix_scalarize) { - if (irpass::scalarize(ir)) { - // Remove redundant MatrixInitStmt inserted during scalarization - irpass::full_simplify(ir, config, - {lower_global_access, /*autodiff_enabled*/ false}); - print("Scalarized"); - } - } - irpass::full_simplify(ir, config, {lower_global_access, /*autodiff_enabled*/ false}); print("Simplified IV"); @@ -294,6 +285,15 @@ void offload_to_executable(IRNode *ir, print("Bit struct stores optimized"); } + if (config.real_matrix_scalarize) { + if (irpass::scalarize(ir)) { + // Remove redundant MatrixInitStmt inserted during scalarization + irpass::full_simplify(ir, config, + {lower_global_access, /*autodiff_enabled*/ false}); + print("Scalarized"); + } + } + if (config.arch == Arch::cuda && config.half2_vectorization && !get_custom_cuda_library_path().empty()) { irpass::vectorize_half2(ir); From b99b2d6154570e8488f83ab8862ed22efa5a3282 Mon Sep 17 00:00:00 2001 From: jim19930609 Date: Mon, 29 May 2023 16:21:02 +0800 
Subject: [PATCH 11/11] bug fix --- taichi/ir/control_flow_graph.cpp | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/taichi/ir/control_flow_graph.cpp b/taichi/ir/control_flow_graph.cpp index cc8e9472d4f70..d546900f83724 100644 --- a/taichi/ir/control_flow_graph.cpp +++ b/taichi/ir/control_flow_graph.cpp @@ -384,7 +384,10 @@ void CFGNode::reaching_definition_analysis(bool after_lower_access) { auto data_source_ptrs = irpass::analysis::get_store_destination(stmt); for (auto data_source_ptr : data_source_ptrs) { // stmt provides a data source - if (after_lower_access && !(data_source_ptr->is())) { + if (after_lower_access && + !((data_source_ptr->is() && + data_source_ptr->as()->origin->is()) || + data_source_ptr->is())) { // After lower_access, we only analyze local variables. continue; } @@ -552,6 +555,8 @@ void CFGNode::live_variable_analysis(bool after_lower_access) { irpass::analysis::get_load_pointers(stmt, true /*get_alias*/); for (auto &load_ptr : load_ptrs) { if (!after_lower_access || + (load_ptr->is() && + load_ptr->as()->origin->is()) || (load_ptr->is() || load_ptr->is())) { // After lower_access, we only analyze local variables and stacks. if (!contain_variable(live_kill, load_ptr)) { @@ -576,6 +581,8 @@ void CFGNode::live_variable_analysis(bool after_lower_access) { } for (auto store_ptr : store_ptrs) { if (!after_lower_access || + (store_ptr->is() && + store_ptr->as()->origin->is()) || (store_ptr->is() || store_ptr->is())) { // After lower_access, we only analyze local variables and stacks. 
live_kill.insert(store_ptr); @@ -707,6 +714,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { auto store_ptr = *store_ptrs.begin(); if (!after_lower_access || + (store_ptr->is() && + store_ptr->as()->origin->is()) || (store_ptr->is() || store_ptr->is())) { // !may_contain_variable(live_in_this_node, store_ptr): address is not // loaded after this store @@ -806,6 +815,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { auto load_ptr = load_ptrs.begin()[0]; if (!after_lower_access || + (load_ptr->is() && + load_ptr->as()->origin->is()) || (load_ptr->is() || load_ptr->is())) { // live_load_in_this_node[addr]: tracks the // next load to the same address @@ -832,6 +843,8 @@ bool CFGNode::dead_store_elimination(bool after_lower_access) { // Update live_in_this_node for (auto &load_ptr : load_ptrs) { if (!after_lower_access || + (load_ptr->is() && + load_ptr->as()->origin->is()) || (load_ptr->is() || load_ptr->is())) { // Addr is used in this node, so it's live in this node update_container_with_alias(tensor_to_matrix_ptrs_map,