diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 4e717e24ba369..03fa92e111edd 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -611,7 +611,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { if (is_real(stmt->ret_type.get_element_type())) { llvm_val[stmt] = builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); - } else if (is_signed(stmt->ret_type)) { + } else if (is_signed(stmt->ret_type.get_element_type())) { llvm_val[stmt] = builder->CreateSDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { @@ -658,7 +658,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); #endif } else if (op == BinaryOpType::bit_sar) { - if (is_signed(stmt->lhs->element_type())) { + if (is_signed(stmt->lhs->ret_type.get_element_type())) { llvm_val[stmt] = builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else { diff --git a/taichi/transforms/compile_to_offloads.cpp b/taichi/transforms/compile_to_offloads.cpp index a5e686cad9386..f1f8742bde5aa 100644 --- a/taichi/transforms/compile_to_offloads.cpp +++ b/taichi/transforms/compile_to_offloads.cpp @@ -268,6 +268,9 @@ void offload_to_executable(IRNode *ir, irpass::analysis::verify(ir); } + irpass::demote_operations(ir, config); + print("Operations demoted"); + if (config.real_matrix_scalarize) { if (irpass::scalarize(ir)) { // Remove redundant MatrixInitStmt inserted during scalarization @@ -277,9 +280,6 @@ void offload_to_executable(IRNode *ir, } } - irpass::demote_operations(ir, config); - print("Operations demoted"); - irpass::full_simplify(ir, config, {lower_global_access, /*autodiff_enabled*/ false}); print("Simplified IV"); diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp index 5846d6e24b7c7..6a32d3bf3883d 100644 --- a/taichi/transforms/demote_operations.cpp +++ b/taichi/transforms/demote_operations.cpp @@ -21,6 +21,18 @@ class DemoteOperations : public BasicStmtVisitor { Stmt *rhs) { auto ret = Stmt::make(BinaryOpType::div, lhs, rhs); auto zero = Stmt::make(TypedConstant(0)); + + if (lhs->ret_type->is()) { + int num_elements = lhs->ret_type->cast()->get_num_elements(); + std::vector values(num_elements, zero.get()); + + auto matrix_zero = Stmt::make(values); + matrix_zero->ret_type = lhs->ret_type; + + modifier.insert_before(stmt, std::move(zero)); + zero = std::move(matrix_zero); + } + auto lhs_ltz = Stmt::make(BinaryOpType::cmp_lt, lhs, zero.get()); auto rhs_ltz = @@ -39,6 +51,7 @@ class DemoteOperations : public BasicStmtVisitor { cond12.get(), cond3.get()); auto real_ret = Stmt::make(BinaryOpType::sub, ret.get(), cond.get()); + modifier.insert_before(stmt, std::move(ret)); modifier.insert_before(stmt, std::move(zero)); modifier.insert_before(stmt, std::move(lhs_ltz)); @@ -64,9 +77,11 @@ class DemoteOperations : public BasicStmtVisitor { void visit(BinaryOpStmt *stmt) override { auto lhs = stmt->lhs; auto rhs = stmt->rhs; + + auto lhs_prim_type = lhs->ret_type.get_element_type(); + auto rhs_prim_type = rhs->ret_type.get_element_type(); if (stmt->op_type == BinaryOpType::floordiv) { - if (is_integral(rhs->element_type()) && - is_integral(lhs->element_type())) { + if (is_integral(rhs_prim_type) && is_integral(lhs_prim_type)) { // @ti.func // def ifloordiv(a, b): // r = ti.raw_div(a, b) @@ -96,7 +111,7 @@ class DemoteOperations : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(real_ret)); modifier.erase(stmt); - } else if (is_real(rhs->element_type()) || is_real(lhs->element_type())) { + } else if (is_real(rhs_prim_type) || is_real(lhs_prim_type)) { // @ti.func // def ffloordiv(a, b): // r = ti.raw_div(a, b) @@ -106,48 +121,6 @@ class DemoteOperations : public BasicStmtVisitor { stmt->replace_usages_with(floor.get()); modifier.insert_before(stmt, std::move(floor)); modifier.erase(stmt); - } else if (lhs->ret_type->is() && - rhs->ret_type->is()) { - bool use_integral = is_integral(lhs->ret_type.get_element_type()) && - is_integral(rhs->ret_type.get_element_type()); - std::vector ret_stmts; - auto lhs_tensor_ty = lhs->ret_type->cast(); - auto rhs_tensor_ty = rhs->ret_type->cast(); - auto lhs_alloca = Stmt::make(lhs_tensor_ty); - auto rhs_alloca = Stmt::make(rhs_tensor_ty); - auto lhs_store = - Stmt::make(lhs_alloca.get(), stmt->lhs); - auto rhs_store = - Stmt::make(rhs_alloca.get(), stmt->rhs); - auto lhs_ptr = lhs_alloca.get(); - auto rhs_ptr = rhs_alloca.get(); - modifier.insert_before(stmt, std::move(lhs_alloca)); - modifier.insert_before(stmt, std::move(rhs_alloca)); - modifier.insert_before(stmt, std::move(lhs_store)); - modifier.insert_before(stmt, std::move(rhs_store)); - for (int i = 0; i < lhs_tensor_ty->get_num_elements(); i++) { - auto idx = Stmt::make(TypedConstant(i)); - auto lhs_i = Stmt::make(lhs_ptr, idx.get()); - auto rhs_i = Stmt::make(rhs_ptr, idx.get()); - auto lhs_load = Stmt::make(lhs_i.get()); - auto rhs_load = Stmt::make(rhs_i.get()); - auto cur_lhs = lhs_load.get(); - auto cur_rhs = rhs_load.get(); - modifier.insert_before(stmt, std::move(idx)); - modifier.insert_before(stmt, std::move(lhs_i)); - modifier.insert_before(stmt, std::move(rhs_i)); - modifier.insert_before(stmt, std::move(lhs_load)); - modifier.insert_before(stmt, std::move(rhs_load)); - auto ret_i = use_integral ? demote_ifloordiv(stmt, cur_lhs, cur_rhs) - : demote_ffloor(stmt, cur_lhs, cur_rhs); - ret_stmts.push_back(ret_i.get()); - modifier.insert_before(stmt, std::move(ret_i)); - } - auto new_matrix = Stmt::make(ret_stmts); - new_matrix->ret_type = stmt->ret_type; - stmt->replace_usages_with(new_matrix.get()); - modifier.insert_before(stmt, std::move(new_matrix)); - modifier.erase(stmt); } } else if (stmt->op_type == BinaryOpType::bit_shr) { // @ti.func @@ -156,17 +129,15 @@ class DemoteOperations : public BasicStmtVisitor { // shifted = ti.bit_sar(unsigned_a, b) // ret = ti.cast(shifted, ti.iXX) // return ret - TI_ASSERT(is_integral(lhs->element_type()) && - is_integral(rhs->element_type())); + TI_ASSERT(is_integral(lhs_prim_type) && is_integral(rhs_prim_type)); auto unsigned_cast = Stmt::make(UnaryOpType::cast_bits, lhs); - auto lhs_type = lhs->element_type(); unsigned_cast->as()->cast_type = - is_signed(lhs_type) ? to_unsigned(lhs_type) : lhs_type; + is_signed(lhs_prim_type) ? to_unsigned(lhs_prim_type) : lhs_prim_type; auto shift = Stmt::make(BinaryOpType::bit_sar, unsigned_cast.get(), rhs); auto signed_cast = Stmt::make(UnaryOpType::cast_bits, shift.get()); - signed_cast->as()->cast_type = lhs->element_type(); + signed_cast->as()->cast_type = lhs_prim_type; signed_cast->ret_type = stmt->ret_type; stmt->replace_usages_with(signed_cast.get()); modifier.insert_before(stmt, std::move(unsigned_cast)); @@ -174,7 +145,7 @@ class DemoteOperations : public BasicStmtVisitor { modifier.insert_before(stmt, std::move(signed_cast)); modifier.erase(stmt); } else if (stmt->op_type == BinaryOpType::pow && - is_integral(rhs->element_type())) { + is_integral(rhs_prim_type)) { // @ti.func // def pow(lhs, rhs): // a = lhs @@ -189,14 +160,14 @@ class DemoteOperations : public BasicStmtVisitor { // result = 1 / result # for real lhs // return result IRBuilder builder; - auto one_lhs = builder.get_constant(lhs->element_type(), 1); - auto one_rhs = builder.get_constant(rhs->element_type(), 1); - auto zero_rhs = builder.get_constant(rhs->element_type(), 0); - auto a = builder.create_local_var(lhs->element_type()); + auto one_lhs = builder.get_constant(lhs_prim_type, 1); + auto one_rhs = builder.get_constant(rhs_prim_type, 1); + auto zero_rhs = builder.get_constant(rhs_prim_type, 0); + auto a = builder.create_local_var(lhs_prim_type); builder.create_local_store(a, lhs); - auto b = builder.create_local_var(rhs->element_type()); + auto b = builder.create_local_var(rhs_prim_type); builder.create_local_store(b, builder.create_abs(rhs)); - auto result = builder.create_local_var(lhs->element_type()); + auto result = builder.create_local_var(lhs_prim_type); builder.create_local_store(result, one_lhs); auto loop = builder.create_while_true(); { @@ -222,7 +193,7 @@ class DemoteOperations : public BasicStmtVisitor { auto new_b = builder.create_sar(current_b, one_rhs); builder.create_local_store(b, new_b); } - if (is_real(lhs->element_type())) { + if (is_real(lhs_prim_type)) { auto if_stmt = builder.create_if(builder.create_cmp_le(rhs, zero_rhs)); { auto _ = builder.get_if_guard(if_stmt, true);