diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index 829b93399e7e1a..71e22bed5f4094 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -622,6 +622,12 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) { } else if (op == BinaryOpType::mod) { llvm_val[stmt] = builder->CreateSRem(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); + } else if (op == BinaryOpType::logical_and) { + llvm_val[stmt] = + builder->CreateAnd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); + } else if (op == BinaryOpType::logical_or) { + llvm_val[stmt] = + builder->CreateOr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); } else if (op == BinaryOpType::bit_and) { llvm_val[stmt] = builder->CreateAnd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]); diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index f03e0fe0feddbb..19f03bfb1846f6 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -1147,6 +1147,8 @@ class TaskCodegen : public IRVisitor { BINARY_OP_TO_SPIRV_LOGICAL(cmp_ge, ge) BINARY_OP_TO_SPIRV_LOGICAL(cmp_eq, eq) BINARY_OP_TO_SPIRV_LOGICAL(cmp_ne, ne) + BINARY_OP_TO_SPIRV_LOGICAL(logical_and, logical_and) + BINARY_OP_TO_SPIRV_LOGICAL(logical_or, logical_or) #undef BINARY_OP_TO_SPIRV_LOGICAL #define FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(op, instruction, instruction_id, \ diff --git a/taichi/codegen/spirv/spirv_ir_builder.cpp b/taichi/codegen/spirv/spirv_ir_builder.cpp index 65f5cf159ffc12..1ea8fef2f75814 100644 --- a/taichi/codegen/spirv/spirv_ir_builder.cpp +++ b/taichi/codegen/spirv/spirv_ir_builder.cpp @@ -1107,6 +1107,16 @@ DEFINE_BUILDER_CMP_OP(ge, GreaterThanEqual); DEFINE_BUILDER_CMP_UOP(eq, Equal); DEFINE_BUILDER_CMP_UOP(ne, NotEqual); +#define DEFINE_BUILDER_LOGICAL_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + TI_ASSERT(a.stype.id == b.stype.id); \ + TI_ASSERT(a.stype.dt->is_primitive(PrimitiveTypeID::u1)); \ + return make_value(spv::OpLogical##_Op, t_bool_, a, b); \ + } + +DEFINE_BUILDER_LOGICAL_OP(logical_and, And); +DEFINE_BUILDER_LOGICAL_OP(logical_or, Or); + Value IRBuilder::bit_field_extract(Value base, Value offset, Value count) { TI_ASSERT(is_integral(base.stype.dt)); TI_ASSERT(is_integral(offset.stype.dt)); diff --git a/taichi/codegen/spirv/spirv_ir_builder.h b/taichi/codegen/spirv/spirv_ir_builder.h index 0fec88bc38d28c..8b1d50f0d139fc 100644 --- a/taichi/codegen/spirv/spirv_ir_builder.h +++ b/taichi/codegen/spirv/spirv_ir_builder.h @@ -472,6 +472,8 @@ class IRBuilder { Value le(Value a, Value b); Value gt(Value a, Value b); Value ge(Value a, Value b); + Value logical_and(Value a, Value b); + Value logical_or(Value a, Value b); Value bit_field_extract(Value base, Value offset, Value count); Value select(Value cond, Value a, Value b); Value popcnt(Value x); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index ae597e7199ebd6..ee54fcd77d3b55 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -357,8 +357,8 @@ void BinaryOpExpression::type_check(const CompileConfig *config) { !is_integral(rhs_type.get_element_type()))) error(); if (binary_is_logical(type) && - (is_tensor_op || lhs_type != PrimitiveType::i32 || - rhs_type != PrimitiveType::i32)) + !(is_integral(lhs_type.get_element_type()) && + is_integral(rhs_type.get_element_type()))) error(); if (is_comparison(type) || binary_is_logical(type)) { ret_type = make_dt(PrimitiveType::i32); @@ -398,7 +398,8 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) { // return; auto lhs_stmt = flatten_rvalue(lhs, ctx); - if (binary_is_logical(type)) { + if (binary_is_logical(type) && !is_tensor(lhs->ret_type) && + !is_tensor(rhs->ret_type)) { auto result = ctx->push_back(ret_type); ctx->push_back(result, lhs_stmt); auto cond = ctx->push_back(result); @@ -537,7 +538,7 @@ void TernaryOpExpression::type_check(const CompileConfig *config) { is_valid = false; } - if (op1_type != PrimitiveType::i32) { + if (!is_integral(op1_type)) { is_valid = false; } if (!op2_type->is() || !op3_type->is()) { diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp index 30264bbadcbfd1..5846d6e24b7c7c 100644 --- a/taichi/transforms/demote_operations.cpp +++ b/taichi/transforms/demote_operations.cpp @@ -33,10 +33,10 @@ class DemoteOperations : public BasicStmtVisitor { Stmt::make(BinaryOpType::cmp_ne, lhs, zero.get()); auto cond3 = Stmt::make(BinaryOpType::cmp_ne, rhs_mul_ret.get(), lhs); - auto cond12 = Stmt::make(BinaryOpType::bit_and, cond1.get(), - cond2.get()); - auto cond = Stmt::make(BinaryOpType::bit_and, cond12.get(), - cond3.get()); + auto cond12 = Stmt::make(BinaryOpType::logical_and, + cond1.get(), cond2.get()); + auto cond = Stmt::make(BinaryOpType::logical_and, + cond12.get(), cond3.get()); auto real_ret = Stmt::make(BinaryOpType::sub, ret.get(), cond.get()); modifier.insert_before(stmt, std::move(ret));