Skip to content

Commit

Permalink
[Lang] Move irpass::scalarize() to run after irpass::demote_operations()
Browse files: browse the repository at this point in the history
  • Loading branch information
jim19930609 committed May 29, 2023
1 parent b211db5 commit e852b32
Show file tree
Hide file tree
Showing 3 changed files with 34 additions and 63 deletions.
4 changes: 2 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
if (is_real(stmt->ret_type.get_element_type())) {
llvm_val[stmt] =
builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else if (is_signed(stmt->ret_type)) {
} else if (is_signed(stmt->ret_type.get_element_type())) {
llvm_val[stmt] =
builder->CreateSDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else {
Expand Down Expand Up @@ -658,7 +658,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
#endif
} else if (op == BinaryOpType::bit_sar) {
if (is_signed(stmt->lhs->element_type())) {
if (is_signed(stmt->lhs->ret_type.get_element_type())) {
llvm_val[stmt] =
builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else {
Expand Down
6 changes: 3 additions & 3 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,9 @@ void offload_to_executable(IRNode *ir,
irpass::analysis::verify(ir);
}

irpass::demote_operations(ir, config);
print("Operations demoted");

if (config.real_matrix_scalarize) {
if (irpass::scalarize(ir)) {
// Remove redundant MatrixInitStmt inserted during scalarization
Expand All @@ -277,9 +280,6 @@ void offload_to_executable(IRNode *ir,
}
}

irpass::demote_operations(ir, config);
print("Operations demoted");

irpass::full_simplify(ir, config,
{lower_global_access, /*autodiff_enabled*/ false});
print("Simplified IV");
Expand Down
87 changes: 29 additions & 58 deletions taichi/transforms/demote_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@ class DemoteOperations : public BasicStmtVisitor {
Stmt *rhs) {
auto ret = Stmt::make<BinaryOpStmt>(BinaryOpType::div, lhs, rhs);
auto zero = Stmt::make<ConstStmt>(TypedConstant(0));

if (lhs->ret_type->is<TensorType>()) {
int num_elements = lhs->ret_type->cast<TensorType>()->get_num_elements();
std::vector<Stmt *> values(num_elements, zero.get());

auto matrix_zero = Stmt::make<MatrixInitStmt>(values);
matrix_zero->ret_type = lhs->ret_type;

modifier.insert_before(stmt, std::move(zero));
zero = std::move(matrix_zero);
}

auto lhs_ltz =
Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_lt, lhs, zero.get());
auto rhs_ltz =
Expand All @@ -39,6 +51,7 @@ class DemoteOperations : public BasicStmtVisitor {
cond12.get(), cond3.get());
auto real_ret =
Stmt::make<BinaryOpStmt>(BinaryOpType::sub, ret.get(), cond.get());

modifier.insert_before(stmt, std::move(ret));
modifier.insert_before(stmt, std::move(zero));
modifier.insert_before(stmt, std::move(lhs_ltz));
Expand All @@ -64,9 +77,11 @@ class DemoteOperations : public BasicStmtVisitor {
void visit(BinaryOpStmt *stmt) override {
auto lhs = stmt->lhs;
auto rhs = stmt->rhs;

auto lhs_prim_type = lhs->ret_type.get_element_type();
auto rhs_prim_type = rhs->ret_type.get_element_type();
if (stmt->op_type == BinaryOpType::floordiv) {
if (is_integral(rhs->element_type()) &&
is_integral(lhs->element_type())) {
if (is_integral(rhs_prim_type) && is_integral(lhs_prim_type)) {
// @ti.func
// def ifloordiv(a, b):
// r = ti.raw_div(a, b)
Expand Down Expand Up @@ -96,7 +111,7 @@ class DemoteOperations : public BasicStmtVisitor {
modifier.insert_before(stmt, std::move(real_ret));
modifier.erase(stmt);

} else if (is_real(rhs->element_type()) || is_real(lhs->element_type())) {
} else if (is_real(rhs_prim_type) || is_real(lhs_prim_type)) {
// @ti.func
// def ffloordiv(a, b):
// r = ti.raw_div(a, b)
Expand All @@ -106,48 +121,6 @@ class DemoteOperations : public BasicStmtVisitor {
stmt->replace_usages_with(floor.get());
modifier.insert_before(stmt, std::move(floor));
modifier.erase(stmt);
} else if (lhs->ret_type->is<TensorType>() &&
rhs->ret_type->is<TensorType>()) {
bool use_integral = is_integral(lhs->ret_type.get_element_type()) &&
is_integral(rhs->ret_type.get_element_type());
std::vector<Stmt *> ret_stmts;
auto lhs_tensor_ty = lhs->ret_type->cast<TensorType>();
auto rhs_tensor_ty = rhs->ret_type->cast<TensorType>();
auto lhs_alloca = Stmt::make<AllocaStmt>(lhs_tensor_ty);
auto rhs_alloca = Stmt::make<AllocaStmt>(rhs_tensor_ty);
auto lhs_store =
Stmt::make<LocalStoreStmt>(lhs_alloca.get(), stmt->lhs);
auto rhs_store =
Stmt::make<LocalStoreStmt>(rhs_alloca.get(), stmt->rhs);
auto lhs_ptr = lhs_alloca.get();
auto rhs_ptr = rhs_alloca.get();
modifier.insert_before(stmt, std::move(lhs_alloca));
modifier.insert_before(stmt, std::move(rhs_alloca));
modifier.insert_before(stmt, std::move(lhs_store));
modifier.insert_before(stmt, std::move(rhs_store));
for (int i = 0; i < lhs_tensor_ty->get_num_elements(); i++) {
auto idx = Stmt::make<ConstStmt>(TypedConstant(i));
auto lhs_i = Stmt::make<MatrixPtrStmt>(lhs_ptr, idx.get());
auto rhs_i = Stmt::make<MatrixPtrStmt>(rhs_ptr, idx.get());
auto lhs_load = Stmt::make<LocalLoadStmt>(lhs_i.get());
auto rhs_load = Stmt::make<LocalLoadStmt>(rhs_i.get());
auto cur_lhs = lhs_load.get();
auto cur_rhs = rhs_load.get();
modifier.insert_before(stmt, std::move(idx));
modifier.insert_before(stmt, std::move(lhs_i));
modifier.insert_before(stmt, std::move(rhs_i));
modifier.insert_before(stmt, std::move(lhs_load));
modifier.insert_before(stmt, std::move(rhs_load));
auto ret_i = use_integral ? demote_ifloordiv(stmt, cur_lhs, cur_rhs)
: demote_ffloor(stmt, cur_lhs, cur_rhs);
ret_stmts.push_back(ret_i.get());
modifier.insert_before(stmt, std::move(ret_i));
}
auto new_matrix = Stmt::make<MatrixInitStmt>(ret_stmts);
new_matrix->ret_type = stmt->ret_type;
stmt->replace_usages_with(new_matrix.get());
modifier.insert_before(stmt, std::move(new_matrix));
modifier.erase(stmt);
}
} else if (stmt->op_type == BinaryOpType::bit_shr) {
// @ti.func
Expand All @@ -156,25 +129,23 @@ class DemoteOperations : public BasicStmtVisitor {
// shifted = ti.bit_sar(unsigned_a, b)
// ret = ti.cast(shifted, ti.iXX)
// return ret
TI_ASSERT(is_integral(lhs->element_type()) &&
is_integral(rhs->element_type()));
TI_ASSERT(is_integral(lhs_prim_type) && is_integral(rhs_prim_type));
auto unsigned_cast = Stmt::make<UnaryOpStmt>(UnaryOpType::cast_bits, lhs);
auto lhs_type = lhs->element_type();
unsigned_cast->as<UnaryOpStmt>()->cast_type =
is_signed(lhs_type) ? to_unsigned(lhs_type) : lhs_type;
is_signed(lhs_prim_type) ? to_unsigned(lhs_prim_type) : lhs_prim_type;
auto shift = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_sar,
unsigned_cast.get(), rhs);
auto signed_cast =
Stmt::make<UnaryOpStmt>(UnaryOpType::cast_bits, shift.get());
signed_cast->as<UnaryOpStmt>()->cast_type = lhs->element_type();
signed_cast->as<UnaryOpStmt>()->cast_type = lhs_prim_type;
signed_cast->ret_type = stmt->ret_type;
stmt->replace_usages_with(signed_cast.get());
modifier.insert_before(stmt, std::move(unsigned_cast));
modifier.insert_before(stmt, std::move(shift));
modifier.insert_before(stmt, std::move(signed_cast));
modifier.erase(stmt);
} else if (stmt->op_type == BinaryOpType::pow &&
is_integral(rhs->element_type())) {
is_integral(rhs_prim_type)) {
// @ti.func
// def pow(lhs, rhs):
// a = lhs
Expand All @@ -189,14 +160,14 @@ class DemoteOperations : public BasicStmtVisitor {
// result = 1 / result # for real lhs
// return result
IRBuilder builder;
auto one_lhs = builder.get_constant(lhs->element_type(), 1);
auto one_rhs = builder.get_constant(rhs->element_type(), 1);
auto zero_rhs = builder.get_constant(rhs->element_type(), 0);
auto a = builder.create_local_var(lhs->element_type());
auto one_lhs = builder.get_constant(lhs_prim_type, 1);
auto one_rhs = builder.get_constant(rhs_prim_type, 1);
auto zero_rhs = builder.get_constant(rhs_prim_type, 0);
auto a = builder.create_local_var(lhs_prim_type);
builder.create_local_store(a, lhs);
auto b = builder.create_local_var(rhs->element_type());
auto b = builder.create_local_var(rhs_prim_type);
builder.create_local_store(b, builder.create_abs(rhs));
auto result = builder.create_local_var(lhs->element_type());
auto result = builder.create_local_var(lhs_prim_type);
builder.create_local_store(result, one_lhs);
auto loop = builder.create_while_true();
{
Expand All @@ -222,7 +193,7 @@ class DemoteOperations : public BasicStmtVisitor {
auto new_b = builder.create_sar(current_b, one_rhs);
builder.create_local_store(b, new_b);
}
if (is_real(lhs->element_type())) {
if (is_real(lhs_prim_type)) {
auto if_stmt = builder.create_if(builder.create_cmp_le(rhs, zero_rhs));
{
auto _ = builder.get_if_guard(if_stmt, true);
Expand Down

0 comments on commit e852b32

Please sign in to comment.