Skip to content

Commit

Permalink
[ir] Update codegen for if while assert to support type u1.
Browse files Browse the repository at this point in the history
ghstack-source-id: 846dcff806df68b40d6959ed712c7677fb36a338
Pull Request resolved: #8003
  • Loading branch information
listerily committed May 15, 2023
1 parent 6456d4f commit a54c548
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 20 deletions.
7 changes: 7 additions & 0 deletions taichi/codegen/cuda/codegen_cuda.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,13 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM {
value_type = tlctx->get_data_type(PrimitiveType::u16);
value = builder->CreateZExt(value, value_type);
}
if (dt->is_primitive(PrimitiveTypeID::u1)) {
auto char_type = llvm::Type::getInt8Ty(*tlctx->get_this_thread_context());
value_type = llvm::PointerType::get(char_type, 0);
value = builder->CreateSelect(
value, builder->CreateGlobalStringPtr("True", "u1_true_value"),
builder->CreateGlobalStringPtr("False", "u1_false_value"));
}
return std::make_tuple(value, value_type);
}

Expand Down
24 changes: 15 additions & 9 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -852,7 +852,7 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
void TaskCodeGenLLVM::visit(TernaryOpStmt *stmt) {
TI_ASSERT(stmt->op_type == TernaryOpType::select);
llvm_val[stmt] = builder->CreateSelect(
builder->CreateTrunc(llvm_val[stmt->op1],
builder->CreateTrunc(builder->CreateIsNotNull(llvm_val[stmt->op1]),
tlctx->get_data_type(PrimitiveType::u1)),
llvm_val[stmt->op2], llvm_val[stmt->op3]);
}
Expand All @@ -865,9 +865,10 @@ void TaskCodeGenLLVM::visit(IfStmt *if_stmt) {
llvm::BasicBlock::Create(*llvm_context, "false_block", func);
llvm::BasicBlock *after_if =
llvm::BasicBlock::Create(*llvm_context, "after_if", func);
builder->CreateCondBr(
builder->CreateICmpNE(llvm_val[if_stmt->cond], tlctx->get_constant(0)),
true_block, false_block);
llvm::Value *cond =
builder->CreateTrunc(builder->CreateIsNotNull(llvm_val[if_stmt->cond]),
tlctx->get_data_type(PrimitiveType::u1));
builder->CreateCondBr(cond, true_block, false_block);
builder->SetInsertPoint(true_block);
if (if_stmt->true_statements) {
if_stmt->true_statements->accept(this);
Expand Down Expand Up @@ -959,6 +960,9 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) {
if (dtype->is_primitive(PrimitiveTypeID::u8))
return builder->CreateZExt(to_print,
tlctx->get_data_type(PrimitiveType::u16));
if (dtype->is_primitive(PrimitiveTypeID::u1))
return builder->CreateZExt(to_print,
tlctx->get_data_type(PrimitiveType::i32));
return to_print;
};
for (auto i = 0; i < stmt->contents.size(); ++i) {
Expand Down Expand Up @@ -1054,8 +1058,8 @@ void TaskCodeGenLLVM::visit(WhileControlStmt *stmt) {
BasicBlock *after_break =
BasicBlock::Create(*llvm_context, "after_break", func);
TI_ASSERT(current_while_after_loop);
auto cond =
builder->CreateICmpEQ(llvm_val[stmt->cond], tlctx->get_constant(0));
auto *cond = builder->CreateTrunc(builder->CreateIsNull(llvm_val[stmt->cond]),
tlctx->get_data_type(PrimitiveType::u1));
builder->CreateCondBr(cond, current_while_after_loop, after_break);
builder->SetInsertPoint(after_break);
}
Expand Down Expand Up @@ -1309,7 +1313,9 @@ void TaskCodeGenLLVM::visit(AssertStmt *stmt) {

std::vector<llvm::Value *> args;
args.emplace_back(get_runtime());
args.emplace_back(llvm_val[stmt->cond]);
args.emplace_back(
builder->CreateTrunc(builder->CreateIsNotNull(llvm_val[stmt->cond]),
tlctx->get_data_type(PrimitiveType::u1)));
args.emplace_back(builder->CreateGlobalStringPtr(stmt->text));

for (int i = 0; i < stmt->args.size(); i++) {
Expand Down Expand Up @@ -2220,8 +2226,8 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt) {
// test whether the current voxel is active or not
auto is_active = call(leaf_block, element.get("element"), "is_active",
{builder->CreateLoad(loop_index_ty, loop_index)});
is_active =
builder->CreateTrunc(is_active, llvm::Type::getInt1Ty(*llvm_context));
is_active = builder->CreateTrunc(builder->CreateIsNotNull(is_active),
llvm::Type::getInt1Ty(*llvm_context));
exec_cond = builder->CreateAnd(exec_cond, is_active);
}

Expand Down
10 changes: 6 additions & 4 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1652,9 +1652,10 @@ class TaskCodegen : public IRVisitor {
}

void visit(IfStmt *if_stmt) override {
spirv::Value cond_v = ir_->query_value(if_stmt->cond->raw_name());
spirv::Value cond_v = ir_->cast(
ir_->bool_type(), ir_->query_value(if_stmt->cond->raw_name()));
spirv::Value cond =
ir_->ne(cond_v, ir_->cast(cond_v.stype, ir_->const_i32_zero_));
ir_->ne(cond_v, ir_->cast(ir_->bool_type(), ir_->const_i32_zero_));
spirv::Label then_label = ir_->new_label();
spirv::Label merge_label = ir_->new_label();
spirv::Label else_label = ir_->new_label();
Expand Down Expand Up @@ -1776,9 +1777,10 @@ class TaskCodegen : public IRVisitor {
}

void visit(WhileControlStmt *stmt) override {
spirv::Value cond_v = ir_->query_value(stmt->cond->raw_name());
spirv::Value cond_v =
ir_->cast(ir_->bool_type(), ir_->query_value(stmt->cond->raw_name()));
spirv::Value cond =
ir_->eq(cond_v, ir_->cast(cond_v.stype, ir_->const_i32_zero_));
ir_->eq(cond_v, ir_->cast(ir_->bool_type(), ir_->const_i32_zero_));
spirv::Label then_label = ir_->new_label();
spirv::Label merge_label = ir_->new_label();

Expand Down
6 changes: 5 additions & 1 deletion taichi/runtime/llvm/llvm_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -692,6 +692,9 @@ llvm::Value *TaichiLLVMContext::get_constant(DataType dt, T t) {
return llvm::ConstantFP::get(llvm::Type::getHalfTy(*ctx), (float32)t);
} else if (dt->is_primitive(PrimitiveTypeID::f64)) {
return llvm::ConstantFP::get(*ctx, llvm::APFloat((float64)t));
} else if (dt->is_primitive(PrimitiveTypeID::u1)) {
return t ? llvm::ConstantInt::getTrue(*ctx)
: llvm::ConstantInt::getFalse(*ctx);
} else if (is_integral(dt)) {
if (is_signed(dt)) {
return llvm::ConstantInt::get(
Expand Down Expand Up @@ -721,7 +724,8 @@ llvm::Value *TaichiLLVMContext::get_constant(T t) {
std::is_same_v<TargetType, float64>) {
return llvm::ConstantFP::get(*ctx, llvm::APFloat(t));
} else if (std::is_same_v<TargetType, bool>) {
return llvm::ConstantInt::get(*ctx, llvm::APInt(1, (uint64)t, true));
return t ? llvm::ConstantInt::getTrue(*ctx)
: llvm::ConstantInt::getFalse(*ctx);
} else if (std::is_same_v<TargetType, int32> ||
std::is_same_v<TargetType, uint32>) {
return llvm::ConstantInt::get(*ctx, llvm::APInt(32, (uint64)t, true));
Expand Down
12 changes: 6 additions & 6 deletions taichi/runtime/llvm/runtime_module/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -332,9 +332,9 @@ struct LLVMRuntime;

constexpr bool enable_assert = true;

void taichi_assert(RuntimeContext *context, i32 test, const char *msg);
void taichi_assert_runtime(LLVMRuntime *runtime, i32 test, const char *msg);
#define TI_ASSERT_INFO(x, msg) taichi_assert(context, (int)(x), msg)
void taichi_assert(RuntimeContext *context, u1 test, const char *msg);
void taichi_assert_runtime(LLVMRuntime *runtime, u1 test, const char *msg);
#define TI_ASSERT_INFO(x, msg) taichi_assert(context, (u1)(x), msg)
#define TI_ASSERT(x) TI_ASSERT_INFO(x, #x)

void ___stubs___() {
Expand Down Expand Up @@ -753,12 +753,12 @@ RUNTIME_STRUCT_FIELD(ListManager, num_elements);
RUNTIME_STRUCT_FIELD(ListManager, max_num_elements_per_chunk);
RUNTIME_STRUCT_FIELD(ListManager, element_size);

void taichi_assert(RuntimeContext *context, i32 test, const char *msg) {
void taichi_assert(RuntimeContext *context, u1 test, const char *msg) {
taichi_assert_runtime(context->runtime, test, msg);
}

void taichi_assert_format(LLVMRuntime *runtime,
i32 test,
u1 test,
const char *format,
int num_arguments,
uint64 *arguments) {
Expand Down Expand Up @@ -808,7 +808,7 @@ void taichi_assert_format(LLVMRuntime *runtime,
#endif
}

void taichi_assert_runtime(LLVMRuntime *runtime, i32 test, const char *msg) {
void taichi_assert_runtime(LLVMRuntime *runtime, u1 test, const char *msg) {
taichi_assert_format(runtime, test, msg, 0, nullptr);
}

Expand Down

0 comments on commit a54c548

Please sign in to comment.