Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ir] Update codegen for if while and assert to support type u1 #8003

Merged
merged 7 commits into from
May 16, 2023
4 changes: 4 additions & 0 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,10 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
value_type = tlctx->get_data_type(PrimitiveType::u16);
value = builder->CreateZExt(value, value_type);
}
if (dt->is_primitive(PrimitiveTypeID::u1)) {
value_type = tlctx->get_data_type(PrimitiveType::i32);
value = builder->CreateZExt(value, value_type);
}
return std::make_tuple(value, value_type);
}

Expand Down
35 changes: 22 additions & 13 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,7 +197,6 @@ void TaskCodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) {
UNARY_STD(tan)
UNARY_STD(tanh)
UNARY_STD(sgn)
UNARY_STD(logic_not)
UNARY_STD(acos)
UNARY_STD(asin)
UNARY_STD(cos)
Expand Down Expand Up @@ -524,6 +523,11 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) {
} else {
llvm_val[stmt] = builder->CreateNeg(input, "neg");
}
} else if (op == UnaryOpType::logic_not) {
llvm_val[stmt] = builder->CreateIsNull(input);
// TODO: (zhantong) remove this zero ext
llvm_val[stmt] = builder->CreateZExt(
llvm_val[stmt], tlctx->get_data_type(PrimitiveType::i32));
}
UNARY_INTRINSIC(round)
UNARY_INTRINSIC(floor)
Expand Down Expand Up @@ -618,6 +622,12 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
} else if (op == BinaryOpType::mod) {
llvm_val[stmt] =
builder->CreateSRem(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else if (op == BinaryOpType::logical_and) {
llvm_val[stmt] =
builder->CreateAnd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else if (op == BinaryOpType::logical_or) {
llvm_val[stmt] =
builder->CreateOr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else if (op == BinaryOpType::bit_and) {
llvm_val[stmt] =
builder->CreateAnd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
Expand Down Expand Up @@ -851,10 +861,9 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {

void TaskCodeGenLLVM::visit(TernaryOpStmt *stmt) {
TI_ASSERT(stmt->op_type == TernaryOpType::select);
llvm_val[stmt] = builder->CreateSelect(
builder->CreateTrunc(llvm_val[stmt->op1],
tlctx->get_data_type(PrimitiveType::u1)),
llvm_val[stmt->op2], llvm_val[stmt->op3]);
llvm_val[stmt] =
builder->CreateSelect(builder->CreateIsNotNull(llvm_val[stmt->op1]),
llvm_val[stmt->op2], llvm_val[stmt->op3]);
}

void TaskCodeGenLLVM::visit(IfStmt *if_stmt) {
Expand All @@ -865,9 +874,8 @@ void TaskCodeGenLLVM::visit(IfStmt *if_stmt) {
llvm::BasicBlock::Create(*llvm_context, "false_block", func);
llvm::BasicBlock *after_if =
llvm::BasicBlock::Create(*llvm_context, "after_if", func);
builder->CreateCondBr(
builder->CreateICmpNE(llvm_val[if_stmt->cond], tlctx->get_constant(0)),
true_block, false_block);
llvm::Value *cond = builder->CreateIsNotNull(llvm_val[if_stmt->cond]);
builder->CreateCondBr(cond, true_block, false_block);
builder->SetInsertPoint(true_block);
if (if_stmt->true_statements) {
if_stmt->true_statements->accept(this);
Expand Down Expand Up @@ -959,6 +967,9 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) {
if (dtype->is_primitive(PrimitiveTypeID::u8))
return builder->CreateZExt(to_print,
tlctx->get_data_type(PrimitiveType::u16));
if (dtype->is_primitive(PrimitiveTypeID::u1))
return builder->CreateZExt(to_print,
tlctx->get_data_type(PrimitiveType::i32));
return to_print;
};
for (auto i = 0; i < stmt->contents.size(); ++i) {
Expand Down Expand Up @@ -1054,8 +1065,7 @@ void TaskCodeGenLLVM::visit(WhileControlStmt *stmt) {
BasicBlock *after_break =
BasicBlock::Create(*llvm_context, "after_break", func);
TI_ASSERT(current_while_after_loop);
auto cond =
builder->CreateICmpEQ(llvm_val[stmt->cond], tlctx->get_constant(0));
auto *cond = builder->CreateIsNull(llvm_val[stmt->cond]);
builder->CreateCondBr(cond, current_while_after_loop, after_break);
builder->SetInsertPoint(after_break);
}
Expand Down Expand Up @@ -1309,7 +1319,7 @@ void TaskCodeGenLLVM::visit(AssertStmt *stmt) {

std::vector<llvm::Value *> args;
args.emplace_back(get_runtime());
args.emplace_back(llvm_val[stmt->cond]);
args.emplace_back(builder->CreateIsNotNull(llvm_val[stmt->cond]));
args.emplace_back(builder->CreateGlobalStringPtr(stmt->text));

for (int i = 0; i < stmt->args.size(); i++) {
Expand Down Expand Up @@ -2220,8 +2230,7 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt) {
// test whether the current voxel is active or not
auto is_active = call(leaf_block, element.get("element"), "is_active",
{builder->CreateLoad(loop_index_ty, loop_index)});
is_active =
builder->CreateTrunc(is_active, llvm::Type::getInt1Ty(*llvm_context));
is_active = builder->CreateIsNotNull(is_active);
exec_cond = builder->CreateAnd(exec_cond, is_active);
}

Expand Down
10 changes: 6 additions & 4 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1652,9 +1652,10 @@ class TaskCodegen : public IRVisitor {
}

void visit(IfStmt *if_stmt) override {
spirv::Value cond_v = ir_->query_value(if_stmt->cond->raw_name());
spirv::Value cond_v = ir_->cast(
ir_->bool_type(), ir_->query_value(if_stmt->cond->raw_name()));
spirv::Value cond =
ir_->ne(cond_v, ir_->cast(cond_v.stype, ir_->const_i32_zero_));
ir_->ne(cond_v, ir_->cast(ir_->bool_type(), ir_->const_i32_zero_));
spirv::Label then_label = ir_->new_label();
spirv::Label merge_label = ir_->new_label();
spirv::Label else_label = ir_->new_label();
Expand Down Expand Up @@ -1776,9 +1777,10 @@ class TaskCodegen : public IRVisitor {
}

void visit(WhileControlStmt *stmt) override {
spirv::Value cond_v = ir_->query_value(stmt->cond->raw_name());
spirv::Value cond_v =
ir_->cast(ir_->bool_type(), ir_->query_value(stmt->cond->raw_name()));
spirv::Value cond =
ir_->eq(cond_v, ir_->cast(cond_v.stype, ir_->const_i32_zero_));
ir_->eq(cond_v, ir_->cast(ir_->bool_type(), ir_->const_i32_zero_));
spirv::Label then_label = ir_->new_label();
spirv::Label merge_label = ir_->new_label();

Expand Down
6 changes: 5 additions & 1 deletion taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,9 @@ llvm::Value *TaichiLLVMContext::get_constant(DataType dt, T t) {
return llvm::ConstantFP::get(llvm::Type::getHalfTy(*ctx), (float32)t);
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
return llvm::ConstantFP::get(*ctx, llvm::APFloat((float64)t));
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return t ? llvm::ConstantInt::getTrue(*ctx)
: llvm::ConstantInt::getFalse(*ctx);
} else if (is_integral(dt)) {
if (is_signed(dt)) {
return llvm::ConstantInt::get(
Expand Down Expand Up @@ -721,7 +724,8 @@ llvm::Value *TaichiLLVMContext::get_constant(T t) {
std::is_same_v<TargetType, float64>) {
return llvm::ConstantFP::get(*ctx, llvm::APFloat(t));
} else if (std::is_same_v<TargetType, bool>) {
return llvm::ConstantInt::get(*ctx, llvm::APInt(1, (uint64)t, true));
return t ? llvm::ConstantInt::getTrue(*ctx)
: llvm::ConstantInt::getFalse(*ctx);
} else if (std::is_same_v<TargetType, int32> ||
std::is_same_v<TargetType, uint32>) {
return llvm::ConstantInt::get(*ctx, llvm::APInt(32, (uint64)t, true));
Expand Down
12 changes: 6 additions & 6 deletions taichi/runtime/llvm/runtime_module/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,9 @@ struct LLVMRuntime;

constexpr bool enable_assert = true;

void taichi_assert(RuntimeContext *context, i32 test, const char *msg);
void taichi_assert_runtime(LLVMRuntime *runtime, i32 test, const char *msg);
#define TI_ASSERT_INFO(x, msg) taichi_assert(context, (int)(x), msg)
void taichi_assert(RuntimeContext *context, u1 test, const char *msg);
void taichi_assert_runtime(LLVMRuntime *runtime, u1 test, const char *msg);
#define TI_ASSERT_INFO(x, msg) taichi_assert(context, (u1)(x), msg)
#define TI_ASSERT(x) TI_ASSERT_INFO(x, #x)

void ___stubs___() {
Expand Down Expand Up @@ -753,12 +753,12 @@ RUNTIME_STRUCT_FIELD(ListManager, num_elements);
RUNTIME_STRUCT_FIELD(ListManager, max_num_elements_per_chunk);
RUNTIME_STRUCT_FIELD(ListManager, element_size);

void taichi_assert(RuntimeContext *context, i32 test, const char *msg) {
void taichi_assert(RuntimeContext *context, u1 test, const char *msg) {
taichi_assert_runtime(context->runtime, test, msg);
}

void taichi_assert_format(LLVMRuntime *runtime,
i32 test,
u1 test,
const char *format,
int num_arguments,
uint64 *arguments) {
Expand Down Expand Up @@ -808,7 +808,7 @@ void taichi_assert_format(LLVMRuntime *runtime,
#endif
}

void taichi_assert_runtime(LLVMRuntime *runtime, i32 test, const char *msg) {
void taichi_assert_runtime(LLVMRuntime *runtime, u1 test, const char *msg) {
taichi_assert_format(runtime, test, msg, 0, nullptr);
}

Expand Down
36 changes: 17 additions & 19 deletions tests/python/test_pow.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,25 +62,23 @@ def foo(x: dt, y: ti.template()):
foo(10, -10)


# FIXME(zhantong): Uncomment this test after bool assertion is finished.
# @test_utils.test(
# debug=True,
# advanced_optimization=False,
# exclude=[ti.vulkan, ti.metal, ti.opengl, ti.gles],
# )
# def test_ipow_negative_exp_i32():
# _ipow_negative_exp(ti.i32)


# FIXME(zhantong): Uncomment this test after bool assertion is finished.
# @test_utils.test(
# debug=True,
# advanced_optimization=False,
# require=ti.extension.data64,
# exclude=[ti.vulkan, ti.metal, ti.opengl, ti.gles],
# )
# def test_ipow_negative_exp_i64():
# _ipow_negative_exp(ti.i64)
@test_utils.test(
debug=True,
advanced_optimization=False,
exclude=[ti.vulkan, ti.metal, ti.opengl, ti.gles],
)
def test_ipow_negative_exp_i32():
_ipow_negative_exp(ti.i32)


@test_utils.test(
debug=True,
advanced_optimization=False,
require=ti.extension.data64,
exclude=[ti.vulkan, ti.metal, ti.opengl, ti.gles],
)
def test_ipow_negative_exp_i64():
_ipow_negative_exp(ti.i64)


def _test_pow_int_base_int_exp(dt_base, dt_exp):
Expand Down