Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Lang] Migrate irpass::scalarize() after irpass::demote_operations() #8096

Merged
merged 13 commits into from
May 31, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
if (is_real(stmt->ret_type.get_element_type())) {
llvm_val[stmt] =
builder->CreateFDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else if (is_signed(stmt->ret_type)) {
} else if (is_signed(stmt->ret_type.get_element_type())) {
llvm_val[stmt] =
builder->CreateSDiv(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else {
Expand Down Expand Up @@ -658,7 +658,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
builder->CreateShl(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
#endif
} else if (op == BinaryOpType::bit_sar) {
if (is_signed(stmt->lhs->element_type())) {
if (is_signed(stmt->lhs->ret_type.get_element_type())) {
llvm_val[stmt] =
builder->CreateAShr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else {
Expand Down
3 changes: 2 additions & 1 deletion taichi/ir/ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "taichi/ir/ir.h"
#include "taichi/ir/mesh.h"
#include "taichi/ir/statements.h"

namespace taichi::lang {

Expand Down Expand Up @@ -137,7 +138,7 @@ class IRBuilder {
ConstStmt *get_float64(float64 value);

template <typename T>
ConstStmt *get_constant(DataType dt, const T &value) {
Stmt *get_constant(DataType dt, const T &value) {
return insert(Stmt::make_typed<ConstStmt>(TypedConstant(dt, value)));
}

Expand Down
9 changes: 9 additions & 0 deletions taichi/transforms/compile_to_offloads.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,15 @@ void offload_to_executable(IRNode *ir,
irpass::demote_operations(ir, config);
print("Operations demoted");

if (config.real_matrix_scalarize) {
if (irpass::scalarize(ir)) {
// Remove redundant MatrixInitStmt inserted during scalarization
irpass::full_simplify(ir, config,
{lower_global_access, /*autodiff_enabled*/ false});
print("Scalarized");
}
}

irpass::full_simplify(ir, config,
{lower_global_access, /*autodiff_enabled*/ false});
print("Simplified IV");
Expand Down
247 changes: 148 additions & 99 deletions taichi/transforms/demote_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,134 @@ class DemoteOperations : public BasicStmtVisitor {
DemoteOperations() {
}

// Emits IR computing lhs ** rhs via exponentiation by squaring
// (square-and-multiply) into `builder`, and returns the statement
// holding the final result.
//
// The generated pseudo-code is:
//   a = lhs; b = abs(rhs); result = 1
//   while True:
//     if b <= 0: break
//     if b & 1:  result *= a
//     a *= a
//     b >>= 1
//   if is_real(lhs) and rhs <= 0: result = 1 / result
//
// `lhs`/`rhs` are expected to be scalar here; element types are taken
// via get_element_type(). NOTE(review): for integral lhs with negative
// rhs there is no reciprocal step — the result stays as computed from
// abs(rhs); presumably matches the intended integer-pow semantics,
// confirm against callers.
Stmt *transform_pow_op_impl(IRBuilder &builder, Stmt *lhs, Stmt *rhs) {
auto lhs_type = lhs->ret_type.get_element_type();
auto rhs_type = rhs->ret_type.get_element_type();

// Constants used by the loop: 1 in each operand's type, 0 in rhs's type.
auto one_lhs = builder.get_constant(lhs_type, 1);
auto one_rhs = builder.get_constant(rhs_type, 1);
auto zero_rhs = builder.get_constant(rhs_type, 0);
// a: running base, initialized to lhs.
auto a = builder.create_local_var(lhs_type);
builder.create_local_store(a, lhs);
// b: remaining exponent; abs() so the loop below works for negative rhs.
auto b = builder.create_local_var(rhs_type);
builder.create_local_store(b, builder.create_abs(rhs));
// result: accumulator, starts at 1.
auto result = builder.create_local_var(lhs_type);
builder.create_local_store(result, one_lhs);
auto loop = builder.create_while_true();
{
auto loop_guard = builder.get_loop_guard(loop);
auto current_a = builder.create_local_load(a);
auto current_b = builder.create_local_load(b);
// Loop exit: all exponent bits consumed (b <= 0).
auto if_stmt =
builder.create_if(builder.create_cmp_le(current_b, zero_rhs));
{
auto _ = builder.get_if_guard(if_stmt, true);
builder.create_break();
}
// If the low bit of b is set, multiply the accumulator by the base.
auto bit_and = builder.create_and(current_b, one_rhs);
if_stmt = builder.create_if(builder.create_cmp_ne(bit_and, zero_rhs));
{
auto _ = builder.get_if_guard(if_stmt, true);
auto current_result = builder.create_local_load(result);
auto new_result = builder.create_mul(current_result, current_a);
builder.create_local_store(result, new_result);
}
// Square the base and shift the exponent right by one bit.
auto new_a = builder.create_mul(current_a, current_a);
builder.create_local_store(a, new_a);
auto new_b = builder.create_sar(current_b, one_rhs);
builder.create_local_store(b, new_b);
}
// Real-valued base: a non-positive exponent means the true result is the
// reciprocal of what the abs(rhs) loop computed.
if (is_real(lhs_type)) {
auto if_stmt = builder.create_if(builder.create_cmp_le(rhs, zero_rhs));
{
auto _ = builder.get_if_guard(if_stmt, true);
auto current_result = builder.create_local_load(result);
auto new_result = builder.create_div(one_lhs, current_result);
builder.create_local_store(result, new_result);
}
}
auto final_result = builder.create_local_load(result);
return final_result;
}

// Replaces a scalar pow BinaryOpStmt with the expanded
// square-and-multiply IR produced by transform_pow_op_impl().
// The generated statements are spliced in before `stmt`, all usages of
// `stmt` are redirected to the computed result, and `stmt` is erased.
void transform_pow_op_scalar(BinaryOpStmt *stmt, Stmt *lhs, Stmt *rhs) {
IRBuilder builder;

auto final_result = transform_pow_op_impl(builder, lhs, rhs);

// Redirect users first, then move the builder's statements into place
// and drop the original pow statement.
stmt->replace_usages_with(final_result);
modifier.insert_before(
stmt, VecStatement(std::move(builder.extract_ir()->statements)));
modifier.erase(stmt);
}

// Replaces a pow BinaryOpStmt whose operands are TensorType'd with
// element-wise scalar pow expansions.
//
// Because IfStmt/WhileStmt cannot take a TensorType'd condition, the
// operands are scalarized immediately: each operand is stored into an
// AllocaStmt, each element is addressed via MatrixPtrStmt + loaded, the
// scalar pow IR is generated per element, and the per-element results
// are reassembled into a MatrixInitStmt that replaces `stmt`.
//
// Precondition (from the caller): lhs and rhs ret_types are TensorType.
void transform_pow_op_tensor(BinaryOpStmt *stmt, Stmt *lhs, Stmt *rhs) {
std::vector<Stmt *> ret_stmts;
auto lhs_tensor_ty = lhs->ret_type->cast<TensorType>();
auto rhs_tensor_ty = rhs->ret_type->cast<TensorType>();

auto lhs_prim_type = lhs_tensor_ty->get_element_type();
auto rhs_prim_type = rhs_tensor_ty->get_element_type();

// Materialize both operands into stack slots so elements can be
// addressed individually with MatrixPtrStmt.
auto lhs_alloca = Stmt::make<AllocaStmt>(lhs_tensor_ty);
auto rhs_alloca = Stmt::make<AllocaStmt>(rhs_tensor_ty);
auto lhs_store = Stmt::make<LocalStoreStmt>(lhs_alloca.get(), stmt->lhs);
auto rhs_store = Stmt::make<LocalStoreStmt>(rhs_alloca.get(), stmt->rhs);
auto lhs_ptr = lhs_alloca.get();
auto rhs_ptr = rhs_alloca.get();
modifier.insert_before(stmt, std::move(lhs_alloca));
modifier.insert_before(stmt, std::move(rhs_alloca));
modifier.insert_before(stmt, std::move(lhs_store));
modifier.insert_before(stmt, std::move(rhs_store));
// NOTE(review): iterates over lhs's element count; assumes lhs and rhs
// have the same number of elements — confirm upstream shape checking.
for (int i = 0; i < lhs_tensor_ty->get_num_elements(); i++) {
auto idx = Stmt::make<ConstStmt>(TypedConstant(i));
auto lhs_i = Stmt::make<MatrixPtrStmt>(lhs_ptr, idx.get());
auto rhs_i = Stmt::make<MatrixPtrStmt>(rhs_ptr, idx.get());
// Per-element loads carry the scalar element type.
auto lhs_load = Stmt::make<LocalLoadStmt>(lhs_i.get());
lhs_load->ret_type = lhs_prim_type;

auto rhs_load = Stmt::make<LocalLoadStmt>(rhs_i.get());
rhs_load->ret_type = rhs_prim_type;

auto cur_lhs = lhs_load.get();
auto cur_rhs = rhs_load.get();
modifier.insert_before(stmt, std::move(idx));
modifier.insert_before(stmt, std::move(lhs_i));
modifier.insert_before(stmt, std::move(rhs_i));
modifier.insert_before(stmt, std::move(lhs_load));
modifier.insert_before(stmt, std::move(rhs_load));

// Generate the scalar square-and-multiply expansion for this element.
IRBuilder builder;
auto cur_result = transform_pow_op_impl(builder, cur_lhs, cur_rhs);

modifier.insert_before(
stmt, VecStatement(std::move(builder.extract_ir()->statements)));
ret_stmts.push_back(cur_result);
}
// Reassemble the element-wise results into a matrix value and swap it
// in for the original tensor pow statement.
auto new_matrix = Stmt::make<MatrixInitStmt>(ret_stmts);
new_matrix->ret_type = stmt->ret_type;
stmt->replace_usages_with(new_matrix.get());
modifier.insert_before(stmt, std::move(new_matrix));
modifier.erase(stmt);
}

std::unique_ptr<Stmt> demote_ifloordiv(BinaryOpStmt *stmt,
Stmt *lhs,
Stmt *rhs) {
auto ret = Stmt::make<BinaryOpStmt>(BinaryOpType::div, lhs, rhs);
auto zero = Stmt::make<ConstStmt>(TypedConstant(0));

if (lhs->ret_type->is<TensorType>()) {
int num_elements = lhs->ret_type->cast<TensorType>()->get_num_elements();
std::vector<Stmt *> values(num_elements, zero.get());

auto matrix_zero = Stmt::make<MatrixInitStmt>(values);
matrix_zero->ret_type = lhs->ret_type;

modifier.insert_before(stmt, std::move(zero));
zero = std::move(matrix_zero);
}

auto lhs_ltz =
Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_lt, lhs, zero.get());
auto rhs_ltz =
Expand All @@ -39,6 +162,7 @@ class DemoteOperations : public BasicStmtVisitor {
cond12.get(), cond3.get());
auto real_ret =
Stmt::make<BinaryOpStmt>(BinaryOpType::sub, ret.get(), cond.get());

modifier.insert_before(stmt, std::move(ret));
modifier.insert_before(stmt, std::move(zero));
modifier.insert_before(stmt, std::move(lhs_ltz));
Expand All @@ -64,9 +188,14 @@ class DemoteOperations : public BasicStmtVisitor {
void visit(BinaryOpStmt *stmt) override {
auto lhs = stmt->lhs;
auto rhs = stmt->rhs;

auto lhs_type = lhs->ret_type;
auto rhs_type = rhs->ret_type;

auto lhs_prim_type = lhs_type.get_element_type();
auto rhs_prim_type = rhs_type.get_element_type();
if (stmt->op_type == BinaryOpType::floordiv) {
if (is_integral(rhs->element_type()) &&
is_integral(lhs->element_type())) {
if (is_integral(rhs_prim_type) && is_integral(lhs_prim_type)) {
// @ti.func
// def ifloordiv(a, b):
// r = ti.raw_div(a, b)
Expand Down Expand Up @@ -96,7 +225,7 @@ class DemoteOperations : public BasicStmtVisitor {
modifier.insert_before(stmt, std::move(real_ret));
modifier.erase(stmt);

} else if (is_real(rhs->element_type()) || is_real(lhs->element_type())) {
} else if (is_real(rhs_prim_type) || is_real(lhs_prim_type)) {
// @ti.func
// def ffloordiv(a, b):
// r = ti.raw_div(a, b)
Expand All @@ -106,48 +235,6 @@ class DemoteOperations : public BasicStmtVisitor {
stmt->replace_usages_with(floor.get());
modifier.insert_before(stmt, std::move(floor));
modifier.erase(stmt);
} else if (lhs->ret_type->is<TensorType>() &&
rhs->ret_type->is<TensorType>()) {
bool use_integral = is_integral(lhs->ret_type.get_element_type()) &&
is_integral(rhs->ret_type.get_element_type());
std::vector<Stmt *> ret_stmts;
auto lhs_tensor_ty = lhs->ret_type->cast<TensorType>();
auto rhs_tensor_ty = rhs->ret_type->cast<TensorType>();
auto lhs_alloca = Stmt::make<AllocaStmt>(lhs_tensor_ty);
auto rhs_alloca = Stmt::make<AllocaStmt>(rhs_tensor_ty);
auto lhs_store =
Stmt::make<LocalStoreStmt>(lhs_alloca.get(), stmt->lhs);
auto rhs_store =
Stmt::make<LocalStoreStmt>(rhs_alloca.get(), stmt->rhs);
auto lhs_ptr = lhs_alloca.get();
auto rhs_ptr = rhs_alloca.get();
modifier.insert_before(stmt, std::move(lhs_alloca));
modifier.insert_before(stmt, std::move(rhs_alloca));
modifier.insert_before(stmt, std::move(lhs_store));
modifier.insert_before(stmt, std::move(rhs_store));
for (int i = 0; i < lhs_tensor_ty->get_num_elements(); i++) {
auto idx = Stmt::make<ConstStmt>(TypedConstant(i));
auto lhs_i = Stmt::make<MatrixPtrStmt>(lhs_ptr, idx.get());
auto rhs_i = Stmt::make<MatrixPtrStmt>(rhs_ptr, idx.get());
auto lhs_load = Stmt::make<LocalLoadStmt>(lhs_i.get());
auto rhs_load = Stmt::make<LocalLoadStmt>(rhs_i.get());
auto cur_lhs = lhs_load.get();
auto cur_rhs = rhs_load.get();
modifier.insert_before(stmt, std::move(idx));
modifier.insert_before(stmt, std::move(lhs_i));
modifier.insert_before(stmt, std::move(rhs_i));
modifier.insert_before(stmt, std::move(lhs_load));
modifier.insert_before(stmt, std::move(rhs_load));
auto ret_i = use_integral ? demote_ifloordiv(stmt, cur_lhs, cur_rhs)
: demote_ffloor(stmt, cur_lhs, cur_rhs);
ret_stmts.push_back(ret_i.get());
modifier.insert_before(stmt, std::move(ret_i));
}
auto new_matrix = Stmt::make<MatrixInitStmt>(ret_stmts);
new_matrix->ret_type = stmt->ret_type;
stmt->replace_usages_with(new_matrix.get());
modifier.insert_before(stmt, std::move(new_matrix));
modifier.erase(stmt);
}
} else if (stmt->op_type == BinaryOpType::bit_shr) {
// @ti.func
Expand All @@ -156,25 +243,26 @@ class DemoteOperations : public BasicStmtVisitor {
// shifted = ti.bit_sar(unsigned_a, b)
// ret = ti.cast(shifted, ti.iXX)
// return ret
TI_ASSERT(is_integral(lhs->element_type()) &&
is_integral(rhs->element_type()));
TI_ASSERT(is_integral(lhs_prim_type) && is_integral(rhs_prim_type));
auto unsigned_cast = Stmt::make<UnaryOpStmt>(UnaryOpType::cast_bits, lhs);
auto lhs_type = lhs->element_type();
unsigned_cast->as<UnaryOpStmt>()->cast_type =
is_signed(lhs_type) ? to_unsigned(lhs_type) : lhs_type;
is_signed(lhs_prim_type) ? to_unsigned(lhs_prim_type) : lhs_prim_type;
auto shift = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_sar,
unsigned_cast.get(), rhs);
auto signed_cast =
Stmt::make<UnaryOpStmt>(UnaryOpType::cast_bits, shift.get());
signed_cast->as<UnaryOpStmt>()->cast_type = lhs->element_type();
signed_cast->as<UnaryOpStmt>()->cast_type = lhs_prim_type;
signed_cast->ret_type = stmt->ret_type;
stmt->replace_usages_with(signed_cast.get());
modifier.insert_before(stmt, std::move(unsigned_cast));
modifier.insert_before(stmt, std::move(shift));
modifier.insert_before(stmt, std::move(signed_cast));
modifier.erase(stmt);
} else if (stmt->op_type == BinaryOpType::pow &&
is_integral(rhs->element_type())) {
} else if (stmt->op_type == BinaryOpType::pow) {
// There's no direct support for Power operation in LLVM / SpirV IR.
// We need to manually transform it to make it work.

// [Transform]
// @ti.func
// def pow(lhs, rhs):
// a = lhs
Expand All @@ -188,54 +276,15 @@ class DemoteOperations : public BasicStmtVisitor {
// if rhs < 0: # for real lhs
// result = 1 / result # for real lhs
// return result
IRBuilder builder;
auto one_lhs = builder.get_constant(lhs->element_type(), 1);
auto one_rhs = builder.get_constant(rhs->element_type(), 1);
auto zero_rhs = builder.get_constant(rhs->element_type(), 0);
auto a = builder.create_local_var(lhs->element_type());
builder.create_local_store(a, lhs);
auto b = builder.create_local_var(rhs->element_type());
builder.create_local_store(b, builder.create_abs(rhs));
auto result = builder.create_local_var(lhs->element_type());
builder.create_local_store(result, one_lhs);
auto loop = builder.create_while_true();
{
auto loop_guard = builder.get_loop_guard(loop);
auto current_a = builder.create_local_load(a);
auto current_b = builder.create_local_load(b);
auto if_stmt =
builder.create_if(builder.create_cmp_le(current_b, zero_rhs));
{
auto _ = builder.get_if_guard(if_stmt, true);
builder.create_break();
}
auto bit_and = builder.create_and(current_b, one_rhs);
if_stmt = builder.create_if(builder.create_cmp_ne(bit_and, zero_rhs));
{
auto _ = builder.get_if_guard(if_stmt, true);
auto current_result = builder.create_local_load(result);
auto new_result = builder.create_mul(current_result, current_a);
builder.create_local_store(result, new_result);
}
auto new_a = builder.create_mul(current_a, current_a);
builder.create_local_store(a, new_a);
auto new_b = builder.create_sar(current_b, one_rhs);
builder.create_local_store(b, new_b);
}
if (is_real(lhs->element_type())) {
auto if_stmt = builder.create_if(builder.create_cmp_le(rhs, zero_rhs));
{
auto _ = builder.get_if_guard(if_stmt, true);
auto current_result = builder.create_local_load(result);
auto new_result = builder.create_div(one_lhs, current_result);
builder.create_local_store(result, new_result);
}
if (is_integral(rhs_type)) {
transform_pow_op_scalar(stmt, lhs, rhs);
} else if (rhs_type->is<TensorType>() && lhs_type->is<TensorType>() &&
is_integral(rhs_type.get_element_type())) {
// For Power with TensorType'd operands, since IfStmt and WhileStmt
// isn't compatible with TensorType'd condition statement,
// we have to perform immediate scalarization with help from AllocaStmt.
transform_pow_op_tensor(stmt, lhs, rhs);
}
auto final_result = builder.create_local_load(result);
stmt->replace_usages_with(final_result);
modifier.insert_before(
stmt, VecStatement(std::move(builder.extract_ir()->statements)));
modifier.erase(stmt);
}
}

Expand Down