From 67e90841cb85d3df67fac4939d6be714f9fc09fb Mon Sep 17 00:00:00 2001
From: Zhanlue Yang
Date: Wed, 31 May 2023 09:07:28 +0800
Subject: [PATCH] [Lang] Migrate irpass::scalarize() after
 irpass::demote_operations() (#8096)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Issue: #

### Brief Summary

### 🤖 Generated by Copilot at e852b32

This pull request fixes bugs and improves optimizations in the LLVM code generation and the IR transformation passes. It handles `TensorType` operands correctly in the `demote_operations` pass and in `codegen_llvm`. It also refines the scalarization and block-local optimization passes by applying them selectively and conditionally, followed by the appropriate simplification.
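
The core of the change is the new pass ordering and the fact that `irpass::scalarize()` now reports whether it modified the IR. The sketch below paraphrases the `compile_to_offloads.cpp` hunk further down (identifiers are the ones used in that file; this is a pipeline excerpt, not standalone code):

```cpp
// Scalarization now runs after operation demotion and is gated on its own
// return value, so the follow-up simplification only runs when the pass
// actually rewrote something.
irpass::demote_operations(ir, config);

if (config.real_matrix_scalarize) {
  if (irpass::scalarize(ir)) {
    // Remove redundant MatrixInitStmt inserted during scalarization
    irpass::full_simplify(ir, config,
                          {lower_global_access, /*autodiff_enabled*/ false});
  }
}
```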

### Walkthrough

### 🤖 Generated by Copilot at e852b32

* Fix a bug in codegen_llvm.cpp where the signedness of the return type of a binary division or bit-shift-right operation was not checked correctly ([link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-3c663c78745adcd3f6a7ac81fe99e628decc3040f292ea1e20ecd4b85a7f4313L614-R614), [link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-3c663c78745adcd3f6a7ac81fe99e628decc3040f292ea1e20ecd4b85a7f4313L661-R661))
* Modify the scalarize function in ir/transforms.h and scalarize.cpp to return a bool indicating whether the IR node was modified ([link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-448ac6e85e192a27e5ec7c54cd8a91545dc7c83f62d030eafb9c190383cfe934L33-R33), [link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L21-L23), [link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528R841-R846), [link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L1184-R1205))
* Remove the redundant calls to the delayed_modifier_.modify_ir() method in scalarize.cpp; they are superseded by the bool return values of the scalarize passes ([link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L901-L903), [link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528R1044-R1050), [link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L1089-L1090))
* Add a conditional call to the scalarize function in compile_to_offloads.cpp, applied to each offload node or function node that has the block_local flag set, and perform full simplification or dead-instruction elimination if the node was modified by the scalarize pass ([link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bR274-R282), [link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL359-R364))
* Remove the call to the scalarize function in compile_to_offloads.cpp that was applied to the whole IR node before the offload_to_executable function; it is unnecessary and could affect the performance or correctness of the code generation ([link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-8fde186587db97b3bbc8a856e59bc4467b30257335b0fad064b4eebd521a912bL229-L236))
* Add code to make_block_local.cpp that checks whether the offload node accesses any block-local SNodes and, if so, applies the scalarize pass followed by full simplification ([link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-b2368908e6bad68906f40866b6dece421190dfd428d63788f2f5143b14785a45R18-R49))
* Remove code from demote_operations.cpp that manually scalarized floor division when both operands are TensorTypes; this is unnecessary because the scalarize pass already handles TensorTypes, while the power demotion is factored into the new transform_pow_op_scalar() and transform_pow_op_tensor() helpers ([link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bR54), [link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL109-L150))
* Add code to demote_operations.cpp that handles a TensorType left-hand side in the floor-division demotion by broadcasting the scalar zero constant into a zero matrix of the same TensorType and using it in the sign comparisons ([link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bR24-R35))
* Modify demote_operations.cpp so that the floor-division and bit-shift-right handling uses the get_element_type() method to obtain the primitive type of the operands before passing it to is_integral(), is_signed(), or is_real() ([link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL67-R84), [link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL99-R114), [link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL159-R140), [link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL177-R148), [link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL192-R170), [link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-d217f2b07d4578612dc805b0f01e5dc1883be9acb906b222a8762313cfd0596bL225-R196))
* Modify the static run method of the MergeExternalAndMatrixPtr class in scalarize.cpp to return a bool indicating whether the IR node was modified ([link](https://github.com/taichi-dev/taichi/pull/8096/files?diff=unified&w=0#diff-97b0d9ab204b703802b3b5d04d036d30f66b34b726128216faf0d8a2a8564528L1175-R1189))
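
As a reading aid for the demote_operations.cpp changes below: the new transform_pow_op_impl() emits IR for integer-exponent power via binary exponentiation, following the `def pow(lhs, rhs)` comment kept in the pass. A scalar C++ sketch of the same algorithm (illustrative only; the pass builds IR statements, and the function name here is made up):

```cpp
#include <cstdlib>  // std::abs

// Square-and-multiply pow with an integer exponent, as described by the
// in-pass comment; the reciprocal at the end only applies to a real lhs.
double pow_by_squaring(double lhs, int rhs) {
  double a = lhs;
  int b = std::abs(rhs);
  double result = 1.0;
  while (b > 0) {
    if (b & 1) {
      result *= a;  // fold in the current exponent bit
    }
    a *= a;   // square the base
    b >>= 1;  // move to the next exponent bit
  }
  if (rhs < 0) {
    result = 1.0 / result;  // negative exponent => reciprocal (real lhs only)
  }
  return result;
}
```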
---
 taichi/codegen/llvm/codegen_llvm.cpp      |   4 +-
 taichi/ir/ir_builder.h                    |   3 +-
 taichi/transforms/compile_to_offloads.cpp |   9 +
 taichi/transforms/demote_operations.cpp   | 247 +++++++++++++---------
 4 files changed, 161 insertions(+), 102 deletions(-)

diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp
index 4e717e24ba369..03fa92e111edd 100644
--- a/taichi/codegen/llvm/codegen_llvm.cpp
+++ b/taichi/codegen/llvm/codegen_llvm.cpp
@@ -611,7 +611,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
     if (is_real(stmt->ret_type.get_element_type())) {
       llvm_val[stmt] =
           builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
-    } else if (is_signed(stmt->ret_type)) {
+    } else if (is_signed(stmt->ret_type.get_element_type())) {
       llvm_val[stmt] =
           builder->CreateSDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
     } else {
@@ -658,7 +658,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
           builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
 #endif
     } else if (op == BinaryOpType::bit_sar) {
-      if (is_signed(stmt->lhs->element_type())) {
+      if (is_signed(stmt->lhs->ret_type.get_element_type())) {
         llvm_val[stmt] =
             builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
       } else {
diff --git a/taichi/ir/ir_builder.h b/taichi/ir/ir_builder.h
index 5646d4665393e..f78cd9f274034 100644
--- a/taichi/ir/ir_builder.h
+++ b/taichi/ir/ir_builder.h
@@ -2,6 +2,7 @@
 
 #include "taichi/ir/ir.h"
 #include "taichi/ir/mesh.h"
+#include "taichi/ir/statements.h"
 
 namespace taichi::lang {
 
@@ -137,7 +138,7 @@ class IRBuilder {
   ConstStmt *get_float64(float64 value);
 
   template <typename T>
-  ConstStmt *get_constant(DataType dt, const T &value) {
+  Stmt *get_constant(DataType dt, const T &value) {
     return insert(Stmt::make_typed<ConstStmt>(TypedConstant(dt, value)));
   }
 
diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp
index a5e686cad9386..bc1bd85a160d8 100644
--- a/taichi/transforms/compile_to_offloads.cpp
+++ b/taichi/transforms/compile_to_offloads.cpp
@@ -280,6 +280,15 @@ void offload_to_executable(IRNode *ir,
   irpass::demote_operations(ir, config);
   print("Operations demoted");
 
+  if (config.real_matrix_scalarize) {
+    if (irpass::scalarize(ir)) {
+      // Remove redundant MatrixInitStmt inserted during scalarization
+      irpass::full_simplify(ir, config,
+                            {lower_global_access, /*autodiff_enabled*/ false});
+      print("Scalarized");
+    }
+  }
+
   irpass::full_simplify(ir, config,
                         {lower_global_access, /*autodiff_enabled*/ false});
   print("Simplified IV");
diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp
index 5846d6e24b7c7..27471de5ef257 100644
--- a/taichi/transforms/demote_operations.cpp
+++ b/taichi/transforms/demote_operations.cpp
@@ -16,11 +16,134 @@ class DemoteOperations : public BasicStmtVisitor {
   DemoteOperations() {
   }
 
+  Stmt *transform_pow_op_impl(IRBuilder &builder, Stmt *lhs, Stmt *rhs) {
+    auto lhs_type = lhs->ret_type.get_element_type();
+    auto rhs_type = rhs->ret_type.get_element_type();
+
+    auto one_lhs = builder.get_constant(lhs_type, 1);
+    auto one_rhs = builder.get_constant(rhs_type, 1);
+    auto zero_rhs = builder.get_constant(rhs_type, 0);
+    auto a = builder.create_local_var(lhs_type);
+    builder.create_local_store(a, lhs);
+    auto b = builder.create_local_var(rhs_type);
+    builder.create_local_store(b, builder.create_abs(rhs));
+    auto result = builder.create_local_var(lhs_type);
+    builder.create_local_store(result, one_lhs);
+    auto loop = builder.create_while_true();
+    {
+      auto loop_guard = builder.get_loop_guard(loop);
+      auto current_a = builder.create_local_load(a);
+      auto current_b = builder.create_local_load(b);
+      auto if_stmt =
+          builder.create_if(builder.create_cmp_le(current_b, zero_rhs));
+      {
+        auto _ = builder.get_if_guard(if_stmt, true);
+        builder.create_break();
+      }
+      auto bit_and = builder.create_and(current_b, one_rhs);
+      if_stmt = builder.create_if(builder.create_cmp_ne(bit_and, zero_rhs));
+      {
+        auto _ = builder.get_if_guard(if_stmt, true);
+        auto current_result = builder.create_local_load(result);
+        auto new_result = builder.create_mul(current_result, current_a);
+        builder.create_local_store(result, new_result);
+      }
+      auto new_a = builder.create_mul(current_a, current_a);
+      builder.create_local_store(a, new_a);
+      auto new_b = builder.create_sar(current_b, one_rhs);
+      builder.create_local_store(b, new_b);
+    }
+    if (is_real(lhs_type)) {
+      auto if_stmt = builder.create_if(builder.create_cmp_le(rhs, zero_rhs));
+      {
+        auto _ = builder.get_if_guard(if_stmt, true);
+        auto current_result = builder.create_local_load(result);
+        auto new_result = builder.create_div(one_lhs, current_result);
+        builder.create_local_store(result, new_result);
+      }
+    }
+    auto final_result = builder.create_local_load(result);
+    return final_result;
+  }
+
+  void transform_pow_op_scalar(BinaryOpStmt *stmt, Stmt *lhs, Stmt *rhs) {
+    IRBuilder builder;
+
+    auto final_result = transform_pow_op_impl(builder, lhs, rhs);
+
+    stmt->replace_usages_with(final_result);
+    modifier.insert_before(
+        stmt, VecStatement(std::move(builder.extract_ir()->statements)));
+    modifier.erase(stmt);
+  }
+
+  void transform_pow_op_tensor(BinaryOpStmt *stmt, Stmt *lhs, Stmt *rhs) {
+    std::vector<Stmt *> ret_stmts;
+    auto lhs_tensor_ty = lhs->ret_type->cast<TensorType>();
+    auto rhs_tensor_ty = rhs->ret_type->cast<TensorType>();
+
+    auto lhs_prim_type = lhs_tensor_ty->get_element_type();
+    auto rhs_prim_type = rhs_tensor_ty->get_element_type();
+
+    auto lhs_alloca = Stmt::make<AllocaStmt>(lhs_tensor_ty);
+    auto rhs_alloca = Stmt::make<AllocaStmt>(rhs_tensor_ty);
+    auto lhs_store = Stmt::make<LocalStoreStmt>(lhs_alloca.get(), stmt->lhs);
+    auto rhs_store = Stmt::make<LocalStoreStmt>(rhs_alloca.get(), stmt->rhs);
+    auto lhs_ptr = lhs_alloca.get();
+    auto rhs_ptr = rhs_alloca.get();
+    modifier.insert_before(stmt, std::move(lhs_alloca));
+    modifier.insert_before(stmt, std::move(rhs_alloca));
+    modifier.insert_before(stmt, std::move(lhs_store));
+    modifier.insert_before(stmt, std::move(rhs_store));
+    for (int i = 0; i < lhs_tensor_ty->get_num_elements(); i++) {
+      auto idx = Stmt::make<ConstStmt>(TypedConstant(i));
+      auto lhs_i = Stmt::make<MatrixPtrStmt>(lhs_ptr, idx.get());
+      auto rhs_i = Stmt::make<MatrixPtrStmt>(rhs_ptr, idx.get());
+      auto lhs_load = Stmt::make<LocalLoadStmt>(lhs_i.get());
+      lhs_load->ret_type = lhs_prim_type;
+
+      auto rhs_load = Stmt::make<LocalLoadStmt>(rhs_i.get());
+      rhs_load->ret_type = rhs_prim_type;
+
+      auto cur_lhs = lhs_load.get();
+      auto cur_rhs = rhs_load.get();
+      modifier.insert_before(stmt, std::move(idx));
+      modifier.insert_before(stmt, std::move(lhs_i));
+      modifier.insert_before(stmt, std::move(rhs_i));
+      modifier.insert_before(stmt, std::move(lhs_load));
+      modifier.insert_before(stmt, std::move(rhs_load));
+
+      IRBuilder builder;
+      auto cur_result = transform_pow_op_impl(builder, cur_lhs, cur_rhs);
+
+      modifier.insert_before(
+          stmt, VecStatement(std::move(builder.extract_ir()->statements)));
+      ret_stmts.push_back(cur_result);
+    }
+    auto new_matrix = Stmt::make<MatrixInitStmt>(ret_stmts);
+    new_matrix->ret_type = stmt->ret_type;
+    stmt->replace_usages_with(new_matrix.get());
+    modifier.insert_before(stmt, std::move(new_matrix));
+    modifier.erase(stmt);
+  }
+
   std::unique_ptr<Stmt> demote_ifloordiv(BinaryOpStmt *stmt,
                                          Stmt *lhs,
                                          Stmt *rhs) {
     auto ret = Stmt::make<BinaryOpStmt>(BinaryOpType::div, lhs, rhs);
     auto zero = Stmt::make<ConstStmt>(TypedConstant(0));
+
+    if (lhs->ret_type->is<TensorType>()) {
+      int num_elements = lhs->ret_type->cast<TensorType>()->get_num_elements();
+      std::vector<Stmt *> values(num_elements, zero.get());
+
+      auto matrix_zero = Stmt::make<MatrixInitStmt>(values);
+      matrix_zero->ret_type = lhs->ret_type;
+
+      modifier.insert_before(stmt, std::move(zero));
+      zero = std::move(matrix_zero);
+    }
+
     auto lhs_ltz =
         Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_lt, lhs, zero.get());
     auto rhs_ltz =
@@ -39,6 +162,7 @@ class DemoteOperations : public BasicStmtVisitor {
                                             cond12.get(), cond3.get());
     auto real_ret =
         Stmt::make<BinaryOpStmt>(BinaryOpType::sub, ret.get(), cond.get());
+
     modifier.insert_before(stmt, std::move(ret));
     modifier.insert_before(stmt, std::move(zero));
     modifier.insert_before(stmt, std::move(lhs_ltz));
@@ -64,9 +188,14 @@ class DemoteOperations : public BasicStmtVisitor {
   void visit(BinaryOpStmt *stmt) override {
     auto lhs = stmt->lhs;
     auto rhs = stmt->rhs;
+
+    auto lhs_type = lhs->ret_type;
+    auto rhs_type = rhs->ret_type;
+
+    auto lhs_prim_type = lhs_type.get_element_type();
+    auto rhs_prim_type = rhs_type.get_element_type();
     if (stmt->op_type == BinaryOpType::floordiv) {
-      if (is_integral(rhs->element_type()) &&
-          is_integral(lhs->element_type())) {
+      if (is_integral(rhs_prim_type) && is_integral(lhs_prim_type)) {
         // @ti.func
         // def ifloordiv(a, b):
         //   r = ti.raw_div(a, b)
@@ -96,7 +225,7 @@ class DemoteOperations : public BasicStmtVisitor {
 
       modifier.insert_before(stmt, std::move(real_ret));
       modifier.erase(stmt);
-    } else if (is_real(rhs->element_type()) || is_real(lhs->element_type())) {
+    } else if (is_real(rhs_prim_type) || is_real(lhs_prim_type)) {
       // @ti.func
       // def ffloordiv(a, b):
       //   r = ti.raw_div(a, b)
@@ -106,48 +235,6 @@ class DemoteOperations : public BasicStmtVisitor {
       stmt->replace_usages_with(floor.get());
       modifier.insert_before(stmt, std::move(floor));
       modifier.erase(stmt);
-    } else if (lhs->ret_type->is<TensorType>() &&
-               rhs->ret_type->is<TensorType>()) {
-      bool use_integral = is_integral(lhs->ret_type.get_element_type()) &&
-                          is_integral(rhs->ret_type.get_element_type());
-      std::vector<Stmt *> ret_stmts;
-      auto lhs_tensor_ty = lhs->ret_type->cast<TensorType>();
-      auto rhs_tensor_ty = rhs->ret_type->cast<TensorType>();
-      auto lhs_alloca = Stmt::make<AllocaStmt>(lhs_tensor_ty);
-      auto rhs_alloca = Stmt::make<AllocaStmt>(rhs_tensor_ty);
-      auto lhs_store =
-          Stmt::make<LocalStoreStmt>(lhs_alloca.get(), stmt->lhs);
-      auto rhs_store =
-          Stmt::make<LocalStoreStmt>(rhs_alloca.get(), stmt->rhs);
-      auto lhs_ptr = lhs_alloca.get();
-      auto rhs_ptr = rhs_alloca.get();
-      modifier.insert_before(stmt, std::move(lhs_alloca));
-      modifier.insert_before(stmt, std::move(rhs_alloca));
-      modifier.insert_before(stmt, std::move(lhs_store));
-      modifier.insert_before(stmt, std::move(rhs_store));
-      for (int i = 0; i < lhs_tensor_ty->get_num_elements(); i++) {
-        auto idx = Stmt::make<ConstStmt>(TypedConstant(i));
-        auto lhs_i = Stmt::make<MatrixPtrStmt>(lhs_ptr, idx.get());
-        auto rhs_i = Stmt::make<MatrixPtrStmt>(rhs_ptr, idx.get());
-        auto lhs_load = Stmt::make<LocalLoadStmt>(lhs_i.get());
-        auto rhs_load = Stmt::make<LocalLoadStmt>(rhs_i.get());
-        auto cur_lhs = lhs_load.get();
-        auto cur_rhs = rhs_load.get();
-        modifier.insert_before(stmt, std::move(idx));
-        modifier.insert_before(stmt, std::move(lhs_i));
-        modifier.insert_before(stmt, std::move(rhs_i));
-        modifier.insert_before(stmt, std::move(lhs_load));
-        modifier.insert_before(stmt, std::move(rhs_load));
-        auto ret_i = use_integral ? demote_ifloordiv(stmt, cur_lhs, cur_rhs)
-                                  : demote_ffloor(stmt, cur_lhs, cur_rhs);
-        ret_stmts.push_back(ret_i.get());
-        modifier.insert_before(stmt, std::move(ret_i));
-      }
-      auto new_matrix = Stmt::make<MatrixInitStmt>(ret_stmts);
-      new_matrix->ret_type = stmt->ret_type;
-      stmt->replace_usages_with(new_matrix.get());
-      modifier.insert_before(stmt, std::move(new_matrix));
-      modifier.erase(stmt);
       }
     } else if (stmt->op_type == BinaryOpType::bit_shr) {
       // @ti.func
@@ -156,25 +243,26 @@ class DemoteOperations : public BasicStmtVisitor {
       //   shifted = ti.bit_sar(unsigned_a, b)
       //   ret = ti.cast(shifted, ti.iXX)
       //   return ret
-      TI_ASSERT(is_integral(lhs->element_type()) &&
-                is_integral(rhs->element_type()));
+      TI_ASSERT(is_integral(lhs_prim_type) && is_integral(rhs_prim_type));
       auto unsigned_cast =
           Stmt::make<UnaryOpStmt>(UnaryOpType::cast_bits, lhs);
-      auto lhs_type = lhs->element_type();
       unsigned_cast->as<UnaryOpStmt>()->cast_type =
-          is_signed(lhs_type) ? to_unsigned(lhs_type) : lhs_type;
+          is_signed(lhs_prim_type) ? to_unsigned(lhs_prim_type) : lhs_prim_type;
       auto shift = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_sar,
                                             unsigned_cast.get(), rhs);
       auto signed_cast =
           Stmt::make<UnaryOpStmt>(UnaryOpType::cast_bits, shift.get());
-      signed_cast->as<UnaryOpStmt>()->cast_type = lhs->element_type();
+      signed_cast->as<UnaryOpStmt>()->cast_type = lhs_prim_type;
       signed_cast->ret_type = stmt->ret_type;
       stmt->replace_usages_with(signed_cast.get());
       modifier.insert_before(stmt, std::move(unsigned_cast));
       modifier.insert_before(stmt, std::move(shift));
       modifier.insert_before(stmt, std::move(signed_cast));
       modifier.erase(stmt);
-    } else if (stmt->op_type == BinaryOpType::pow &&
-               is_integral(rhs->element_type())) {
+    } else if (stmt->op_type == BinaryOpType::pow) {
+      // There's no direct support for Power operation in LLVM / SpirV IR.
+      // We need to manually transform it to make it work.
+
+      // [Transform]
       // @ti.func
       // def pow(lhs, rhs):
       //   a = lhs
@@ -188,54 +276,15 @@ class DemoteOperations : public BasicStmtVisitor {
       //   if rhs < 0:             # for real lhs
       //     result = 1 / result   # for real lhs
       //   return result
-      IRBuilder builder;
-      auto one_lhs = builder.get_constant(lhs->element_type(), 1);
-      auto one_rhs = builder.get_constant(rhs->element_type(), 1);
-      auto zero_rhs = builder.get_constant(rhs->element_type(), 0);
-      auto a = builder.create_local_var(lhs->element_type());
-      builder.create_local_store(a, lhs);
-      auto b = builder.create_local_var(rhs->element_type());
-      builder.create_local_store(b, builder.create_abs(rhs));
-      auto result = builder.create_local_var(lhs->element_type());
-      builder.create_local_store(result, one_lhs);
-      auto loop = builder.create_while_true();
-      {
-        auto loop_guard = builder.get_loop_guard(loop);
-        auto current_a = builder.create_local_load(a);
-        auto current_b = builder.create_local_load(b);
-        auto if_stmt =
-            builder.create_if(builder.create_cmp_le(current_b, zero_rhs));
-        {
-          auto _ = builder.get_if_guard(if_stmt, true);
-          builder.create_break();
-        }
-        auto bit_and = builder.create_and(current_b, one_rhs);
-        if_stmt = builder.create_if(builder.create_cmp_ne(bit_and, zero_rhs));
-        {
-          auto _ = builder.get_if_guard(if_stmt, true);
-          auto current_result = builder.create_local_load(result);
-          auto new_result = builder.create_mul(current_result, current_a);
-          builder.create_local_store(result, new_result);
-        }
-        auto new_a = builder.create_mul(current_a, current_a);
-        builder.create_local_store(a, new_a);
-        auto new_b = builder.create_sar(current_b, one_rhs);
-        builder.create_local_store(b, new_b);
-      }
-      if (is_real(lhs->element_type())) {
-        auto if_stmt = builder.create_if(builder.create_cmp_le(rhs, zero_rhs));
-        {
-          auto _ = builder.get_if_guard(if_stmt, true);
-          auto current_result = builder.create_local_load(result);
-          auto new_result = builder.create_div(one_lhs, current_result);
-          builder.create_local_store(result, new_result);
-        }
+      if (is_integral(rhs_type)) {
+        transform_pow_op_scalar(stmt, lhs, rhs);
+      } else if (rhs_type->is<TensorType>() && lhs_type->is<TensorType>() &&
+                 is_integral(rhs_type.get_element_type())) {
+        // For Power with TensorType'd operands, since IfStmt and WhileStmt
+        // isn't compatible with TensorType'd condition statement,
+        // we have to perform immediate scalarization with help from AllocaStmt.
+        transform_pow_op_tensor(stmt, lhs, rhs);
       }
-      auto final_result = builder.create_local_load(result);
-      stmt->replace_usages_with(final_result);
-      modifier.insert_before(
-          stmt, VecStatement(std::move(builder.extract_ir()->statements)));
-      modifier.erase(stmt);
     }
   }
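
For reviewers cross-checking the integer floor-division path that demote_ifloordiv() keeps emitting (see the `def ifloordiv(a, b)` comment in the hunks above): the generated IR performs a truncating division and then subtracts one when the operands have opposite signs and the division is inexact. A scalar C++ sketch of that semantics, illustrative only and not code from this patch:

```cpp
// Floor division on integers via truncating division plus a sign correction.
int ifloordiv(int a, int b) {
  int r = a / b;  // truncating division (rounds toward zero)
  // If the signs differ and the division was inexact, round toward -inf.
  if (((a < 0) != (b < 0)) && r * b != a) {
    r -= 1;
  }
  return r;
}
```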