diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index bacad80ee53d8..46bbf331e4c5d 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -94,6 +94,10 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { value_type = tlctx->get_data_type(PrimitiveType::u16); value = builder->CreateZExt(value, value_type); } + if (dt->is_primitive(PrimitiveTypeID::u1)) { + value_type = tlctx->get_data_type(PrimitiveType::i32); + value = builder->CreateZExt(value, value_type); + } return std::make_tuple(value, value_type); } diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 2fefbeb247892..bb249983483c4 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -197,7 +197,6 @@ void TaskCodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) { UNARY_STD(tan) UNARY_STD(tanh) UNARY_STD(sgn) - UNARY_STD(logic_not) UNARY_STD(acos) UNARY_STD(asin) UNARY_STD(cos) @@ -524,6 +523,11 @@ void TaskCodeGenLLVM::visit(UnaryOpStmt *stmt) { } else { llvm_val[stmt] = builder->CreateNeg(input, "neg"); } + } else if (op == UnaryOpType::logic_not) { + llvm_val[stmt] = builder->CreateIsNull(input); + // TODO: (zhantong) remove this zero ext + llvm_val[stmt] = builder->CreateZExt( + llvm_val[stmt], tlctx->get_data_type(PrimitiveType::i32)); } UNARY_INTRINSIC(round) UNARY_INTRINSIC(floor) @@ -618,6 +622,12 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { } else if (op == BinaryOpType::mod) { llvm_val[stmt] = builder->CreateSRem(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); + } else if (op == BinaryOpType::logical_and) { + llvm_val[stmt] = + builder->CreateAnd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); + } else if (op == BinaryOpType::logical_or) { + llvm_val[stmt] = + builder->CreateOr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else if (op == BinaryOpType::bit_and) { llvm_val[stmt] = builder->CreateAnd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); @@ -851,10 +861,9 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { void TaskCodeGenLLVM::visit(TernaryOpStmt *stmt) { TI_ASSERT(stmt->op_type == TernaryOpType::select); - llvm_val[stmt] = builder->CreateSelect( - builder->CreateTrunc(llvm_val[stmt->op1], - tlctx->get_data_type(PrimitiveType::u1)), - llvm_val[stmt->op2], llvm_val[stmt->op3]); + llvm_val[stmt] = + builder->CreateSelect(builder->CreateIsNotNull(llvm_val[stmt->op1]), + llvm_val[stmt->op2], llvm_val[stmt->op3]); } void TaskCodeGenLLVM::visit(IfStmt *if_stmt) { @@ -865,9 +874,8 @@ void TaskCodeGenLLVM::visit(IfStmt *if_stmt) { llvm::BasicBlock::Create(*llvm_context, "false_block", func); llvm::BasicBlock *after_if = llvm::BasicBlock::Create(*llvm_context, "after_if", func); - builder->CreateCondBr( - builder->CreateICmpNE(llvm_val[if_stmt->cond], tlctx->get_constant(0)), - true_block, false_block); + llvm::Value *cond = builder->CreateIsNotNull(llvm_val[if_stmt->cond]); + builder->CreateCondBr(cond, true_block, false_block); builder->SetInsertPoint(true_block); if (if_stmt->true_statements) { if_stmt->true_statements->accept(this); @@ -959,6 +967,9 @@ void TaskCodeGenLLVM::visit(PrintStmt *stmt) { if (dtype->is_primitive(PrimitiveTypeID::u8)) return builder->CreateZExt(to_print, tlctx->get_data_type(PrimitiveType::u16)); + if (dtype->is_primitive(PrimitiveTypeID::u1)) + return builder->CreateZExt(to_print, + tlctx->get_data_type(PrimitiveType::i32)); return to_print; }; for (auto i = 0; i < stmt->contents.size(); ++i) { @@ -1054,8 +1065,7 @@ void TaskCodeGenLLVM::visit(WhileControlStmt *stmt) { BasicBlock *after_break = BasicBlock::Create(*llvm_context, "after_break", func); TI_ASSERT(current_while_after_loop); - auto cond = - builder->CreateICmpEQ(llvm_val[stmt->cond], tlctx->get_constant(0)); + auto *cond = builder->CreateIsNull(llvm_val[stmt->cond]); builder->CreateCondBr(cond, current_while_after_loop, after_break); builder->SetInsertPoint(after_break); } @@ -1309,7 +1319,7 @@ void TaskCodeGenLLVM::visit(AssertStmt *stmt) { std::vector args; args.emplace_back(get_runtime()); - args.emplace_back(llvm_val[stmt->cond]); + args.emplace_back(builder->CreateIsNotNull(llvm_val[stmt->cond])); args.emplace_back(builder->CreateGlobalStringPtr(stmt->text)); for (int i = 0; i < stmt->args.size(); i++) { @@ -2220,8 +2230,7 @@ void TaskCodeGenLLVM::create_offload_struct_for(OffloadedStmt *stmt) { // test whether the current voxel is active or not auto is_active = call(leaf_block, element.get("element"), "is_active", {builder->CreateLoad(loop_index_ty, loop_index)}); - is_active = - builder->CreateTrunc(is_active, llvm::Type::getInt1Ty(*llvm_context)); + is_active = builder->CreateIsNotNull(is_active); exec_cond = builder->CreateAnd(exec_cond, is_active); } diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 38348e4c1d727..f03e0fe0feddb 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -1652,9 +1652,10 @@ class TaskCodegen : public IRVisitor { } void visit(IfStmt *if_stmt) override { - spirv::Value cond_v = ir_->query_value(if_stmt->cond->raw_name()); + spirv::Value cond_v = ir_->cast( + ir_->bool_type(), ir_->query_value(if_stmt->cond->raw_name())); spirv::Value cond = - ir_->ne(cond_v, ir_->cast(cond_v.stype, ir_->const_i32_zero_)); + ir_->ne(cond_v, ir_->cast(ir_->bool_type(), ir_->const_i32_zero_)); spirv::Label then_label = ir_->new_label(); spirv::Label merge_label = ir_->new_label(); spirv::Label else_label = ir_->new_label(); @@ -1776,9 +1777,10 @@ class TaskCodegen : public IRVisitor { } void visit(WhileControlStmt *stmt) override { - spirv::Value cond_v = ir_->query_value(stmt->cond->raw_name()); + spirv::Value cond_v = + ir_->cast(ir_->bool_type(), ir_->query_value(stmt->cond->raw_name())); spirv::Value cond = - ir_->eq(cond_v, ir_->cast(cond_v.stype, ir_->const_i32_zero_)); + ir_->eq(cond_v, ir_->cast(ir_->bool_type(), ir_->const_i32_zero_)); spirv::Label then_label = ir_->new_label(); spirv::Label merge_label = ir_->new_label(); diff --git a/taichi/runtime/llvm/llvm_context.cpp b/taichi/runtime/llvm/llvm_context.cpp index cef39f793b7e4..3156e0c72ccb6 100644 --- a/taichi/runtime/llvm/llvm_context.cpp +++ b/taichi/runtime/llvm/llvm_context.cpp @@ -692,6 +692,9 @@ llvm::Value *TaichiLLVMContext::get_constant(DataType dt, T t) { return llvm::ConstantFP::get(llvm::Type::getHalfTy(*ctx), (float32)t); } else if (dt->is_primitive(PrimitiveTypeID::f64)) { return llvm::ConstantFP::get(*ctx, llvm::APFloat((float64)t)); + } else if (dt->is_primitive(PrimitiveTypeID::u1)) { + return t ? llvm::ConstantInt::getTrue(*ctx) + : llvm::ConstantInt::getFalse(*ctx); } else if (is_integral(dt)) { if (is_signed(dt)) { return llvm::ConstantInt::get( @@ -721,7 +724,8 @@ llvm::Value *TaichiLLVMContext::get_constant(T t) { std::is_same_v) { return llvm::ConstantFP::get(*ctx, llvm::APFloat(t)); } else if (std::is_same_v) { - return llvm::ConstantInt::get(*ctx, llvm::APInt(1, (uint64)t, true)); + return t ? llvm::ConstantInt::getTrue(*ctx) + : llvm::ConstantInt::getFalse(*ctx); } else if (std::is_same_v || std::is_same_v) { return llvm::ConstantInt::get(*ctx, llvm::APInt(32, (uint64)t, true)); diff --git a/taichi/runtime/llvm/runtime_module/runtime.cpp b/taichi/runtime/llvm/runtime_module/runtime.cpp index 3a71575ae4243..460beb145e7e7 100644 --- a/taichi/runtime/llvm/runtime_module/runtime.cpp +++ b/taichi/runtime/llvm/runtime_module/runtime.cpp @@ -332,9 +332,9 @@ struct LLVMRuntime; constexpr bool enable_assert = true; -void taichi_assert(RuntimeContext *context, i32 test, const char *msg); -void taichi_assert_runtime(LLVMRuntime *runtime, i32 test, const char *msg); -#define TI_ASSERT_INFO(x, msg) taichi_assert(context, (int)(x), msg) +void taichi_assert(RuntimeContext *context, u1 test, const char *msg); +void taichi_assert_runtime(LLVMRuntime *runtime, u1 test, const char *msg); +#define TI_ASSERT_INFO(x, msg) taichi_assert(context, (u1)(x), msg) #define TI_ASSERT(x) TI_ASSERT_INFO(x, #x) void ___stubs___() { @@ -753,12 +753,12 @@ RUNTIME_STRUCT_FIELD(ListManager, num_elements); RUNTIME_STRUCT_FIELD(ListManager, max_num_elements_per_chunk); RUNTIME_STRUCT_FIELD(ListManager, element_size); -void taichi_assert(RuntimeContext *context, i32 test, const char *msg) { +void taichi_assert(RuntimeContext *context, u1 test, const char *msg) { taichi_assert_runtime(context->runtime, test, msg); } void taichi_assert_format(LLVMRuntime *runtime, - i32 test, + u1 test, const char *format, int num_arguments, uint64 *arguments) { @@ -808,7 +808,7 @@ void taichi_assert_format(LLVMRuntime *runtime, #endif } -void taichi_assert_runtime(LLVMRuntime *runtime, i32 test, const char *msg) { +void taichi_assert_runtime(LLVMRuntime *runtime, u1 test, const char *msg) { taichi_assert_format(runtime, test, msg, 0, nullptr); } diff --git a/tests/python/test_pow.py b/tests/python/test_pow.py index 86fd59575827c..2679520134310 100644 --- a/tests/python/test_pow.py +++ b/tests/python/test_pow.py @@ -62,25 +62,23 @@ def foo(x: dt, y: ti.template()): foo(10, -10) -# FIXME(zhantong): Uncomment this test after bool assertion is finished. -# @test_utils.test( -# debug=True, -# advanced_optimization=False, -# exclude=[ti.vulkan, ti.metal, ti.opengl, ti.gles], -# ) -# def test_ipow_negative_exp_i32(): -# _ipow_negative_exp(ti.i32) - - -# FIXME(zhantong): Uncomment this test after bool assertion is finished. -# @test_utils.test( -# debug=True, -# advanced_optimization=False, -# require=ti.extension.data64, -# exclude=[ti.vulkan, ti.metal, ti.opengl, ti.gles], -# ) -# def test_ipow_negative_exp_i64(): -# _ipow_negative_exp(ti.i64) +@test_utils.test( + debug=True, + advanced_optimization=False, + exclude=[ti.vulkan, ti.metal, ti.opengl, ti.gles], +) +def test_ipow_negative_exp_i32(): + _ipow_negative_exp(ti.i32) + + +@test_utils.test( + debug=True, + advanced_optimization=False, + require=ti.extension.data64, + exclude=[ti.vulkan, ti.metal, ti.opengl, ti.gles], +) +def test_ipow_negative_exp_i64(): + _ipow_negative_exp(ti.i64) def _test_pow_int_base_int_exp(dt_base, dt_exp):