Skip to content

Commit

Permalink
[lang] [ir] Add logical and, logical or in ir
Browse files Browse the repository at this point in the history
ghstack-source-id: dfb80971eaab80a4f00eb4fe0a71a4b2848c3daa
Pull Request resolved: #8008
  • Loading branch information
listerily committed May 15, 2023
1 parent f4e0d94 commit 2f0a2fa
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 8 deletions.
6 changes: 6 additions & 0 deletions taichi/codegen/llvm/codegen_llvm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,12 @@ void TaskCodeGenLLVM::visit(BinaryOpStmt *stmt) {
} else if (op == BinaryOpType::mod) {
llvm_val[stmt] =
builder->CreateSRem(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else if (op == BinaryOpType::logical_and) {
llvm_val[stmt] =
builder->CreateAnd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else if (op == BinaryOpType::logical_or) {
llvm_val[stmt] =
builder->CreateOr(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
} else if (op == BinaryOpType::bit_and) {
llvm_val[stmt] =
builder->CreateAnd(llvm_val[stmt->lhs], llvm_val[stmt->rhs]);
Expand Down
2 changes: 2 additions & 0 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,8 @@ class TaskCodegen : public IRVisitor {
BINARY_OP_TO_SPIRV_LOGICAL(cmp_ge, ge)
BINARY_OP_TO_SPIRV_LOGICAL(cmp_eq, eq)
BINARY_OP_TO_SPIRV_LOGICAL(cmp_ne, ne)
BINARY_OP_TO_SPIRV_LOGICAL(logical_and, logical_and)
BINARY_OP_TO_SPIRV_LOGICAL(logical_or, logical_or)
#undef BINARY_OP_TO_SPIRV_LOGICAL

#define FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(op, instruction, instruction_id, \
Expand Down
10 changes: 10 additions & 0 deletions taichi/codegen/spirv/spirv_ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,16 @@ DEFINE_BUILDER_CMP_OP(ge, GreaterThanEqual);
DEFINE_BUILDER_CMP_UOP(eq, Equal);
DEFINE_BUILDER_CMP_UOP(ne, NotEqual);

#define DEFINE_BUILDER_LOGICAL_OP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
TI_ASSERT(a.stype.id == b.stype.id); \
TI_ASSERT(a.stype.dt->is_primitive(PrimitiveTypeID::u1)); \
return make_value(spv::OpLogical##_Op, t_bool_, a, b); \
}

DEFINE_BUILDER_LOGICAL_OP(logical_and, And);
DEFINE_BUILDER_LOGICAL_OP(logical_or, Or);

Value IRBuilder::bit_field_extract(Value base, Value offset, Value count) {
TI_ASSERT(is_integral(base.stype.dt));
TI_ASSERT(is_integral(offset.stype.dt));
Expand Down
2 changes: 2 additions & 0 deletions taichi/codegen/spirv/spirv_ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,8 @@ class IRBuilder {
Value le(Value a, Value b);
Value gt(Value a, Value b);
Value ge(Value a, Value b);
Value logical_and(Value a, Value b);
Value logical_or(Value a, Value b);
Value bit_field_extract(Value base, Value offset, Value count);
Value select(Value cond, Value a, Value b);
Value popcnt(Value x);
Expand Down
9 changes: 5 additions & 4 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -357,8 +357,8 @@ void BinaryOpExpression::type_check(const CompileConfig *config) {
!is_integral(rhs_type.get_element_type())))
error();
if (binary_is_logical(type) &&
(is_tensor_op || lhs_type != PrimitiveType::i32 ||
rhs_type != PrimitiveType::i32))
!(is_integral(lhs_type.get_element_type()) &&
is_integral(rhs_type.get_element_type())))
error();
if (is_comparison(type) || binary_is_logical(type)) {
ret_type = make_dt(PrimitiveType::i32);
Expand Down Expand Up @@ -398,7 +398,8 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) {
// return;
auto lhs_stmt = flatten_rvalue(lhs, ctx);

if (binary_is_logical(type)) {
if (binary_is_logical(type) && !is_tensor(lhs->ret_type) &&
!is_tensor(rhs->ret_type)) {
auto result = ctx->push_back<AllocaStmt>(ret_type);
ctx->push_back<LocalStoreStmt>(result, lhs_stmt);
auto cond = ctx->push_back<LocalLoadStmt>(result);
Expand Down Expand Up @@ -537,7 +538,7 @@ void TernaryOpExpression::type_check(const CompileConfig *config) {
is_valid = false;
}

if (op1_type != PrimitiveType::i32) {
if (!is_integral(op1_type)) {
is_valid = false;
}
if (!op2_type->is<PrimitiveType>() || !op3_type->is<PrimitiveType>()) {
Expand Down
8 changes: 4 additions & 4 deletions taichi/transforms/demote_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ class DemoteOperations : public BasicStmtVisitor {
Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne, lhs, zero.get());
auto cond3 =
Stmt::make<BinaryOpStmt>(BinaryOpType::cmp_ne, rhs_mul_ret.get(), lhs);
auto cond12 = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_and, cond1.get(),
cond2.get());
auto cond = Stmt::make<BinaryOpStmt>(BinaryOpType::bit_and, cond12.get(),
cond3.get());
auto cond12 = Stmt::make<BinaryOpStmt>(BinaryOpType::logical_and,
cond1.get(), cond2.get());
auto cond = Stmt::make<BinaryOpStmt>(BinaryOpType::logical_and,
cond12.get(), cond3.get());
auto real_ret =
Stmt::make<BinaryOpStmt>(BinaryOpType::sub, ret.get(), cond.get());
modifier.insert_before(stmt, std::move(ret));
Expand Down

0 comments on commit 2f0a2fa

Please sign in to comment.