Skip to content

Commit

Permalink
[lang] [ir] Add logical and, logical or in ir
Browse files Browse the repository at this point in the history
ghstack-source-id: d50fedb188af7c6f9548a542aba6200e70d7a8f7
Pull Request resolved: #8008
  • Loading branch information
listerily committed May 16, 2023
1 parent da39780 commit d622f75
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 20 deletions.
2 changes: 2 additions & 0 deletions taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1147,6 +1147,8 @@ class TaskCodegen : public IRVisitor {
BINARY_OP_TO_SPIRV_LOGICAL(cmp_ge, ge)
BINARY_OP_TO_SPIRV_LOGICAL(cmp_eq, eq)
BINARY_OP_TO_SPIRV_LOGICAL(cmp_ne, ne)
BINARY_OP_TO_SPIRV_LOGICAL(logical_and, logical_and)
BINARY_OP_TO_SPIRV_LOGICAL(logical_or, logical_or)
#undef BINARY_OP_TO_SPIRV_LOGICAL

#define FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(op, instruction, instruction_id, \
Expand Down
10 changes: 10 additions & 0 deletions taichi/codegen/spirv/spirv_ir_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1107,6 +1107,16 @@ DEFINE_BUILDER_CMP_OP(ge, GreaterThanEqual);
DEFINE_BUILDER_CMP_UOP(eq, Equal);
DEFINE_BUILDER_CMP_UOP(ne, NotEqual);

#define DEFINE_BUILDER_LOGICAL_OP(_OpName, _Op) \
Value IRBuilder::_OpName(Value a, Value b) { \
TI_ASSERT(a.stype.id == b.stype.id); \
TI_ASSERT(is_integral(a.stype.dt)); \
return make_value(spv::OpLogical##_Op, t_bool_, a, b); \
}

DEFINE_BUILDER_LOGICAL_OP(logical_and, And);
DEFINE_BUILDER_LOGICAL_OP(logical_or, Or);

Value IRBuilder::bit_field_extract(Value base, Value offset, Value count) {
TI_ASSERT(is_integral(base.stype.dt));
TI_ASSERT(is_integral(offset.stype.dt));
Expand Down
2 changes: 2 additions & 0 deletions taichi/codegen/spirv/spirv_ir_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,8 @@ class IRBuilder {
Value le(Value a, Value b);
Value gt(Value a, Value b);
Value ge(Value a, Value b);
Value logical_and(Value a, Value b);
Value logical_or(Value a, Value b);
Value bit_field_extract(Value base, Value offset, Value count);
Value select(Value cond, Value a, Value b);
Value popcnt(Value x);
Expand Down
10 changes: 5 additions & 5 deletions taichi/ir/frontend_ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,8 @@ void BinaryOpExpression::type_check(const CompileConfig *config) {
if (binary_is_bitwise(type) && (!is_integral(lhs_type.get_element_type()) ||
!is_integral(rhs_type.get_element_type())))
error();
if (binary_is_logical(type) &&
(is_tensor_op || lhs_type != PrimitiveType::i32 ||
rhs_type != PrimitiveType::i32))
if (binary_is_logical(type) && !(is_integral(lhs_type.get_element_type()) &&
is_integral(rhs_type.get_element_type())))
error();
if (is_comparison(type) || binary_is_logical(type)) {
ret_type = make_dt(PrimitiveType::i32);
Expand Down Expand Up @@ -401,7 +400,8 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) {
// return;
auto lhs_stmt = flatten_rvalue(lhs, ctx);

if (binary_is_logical(type)) {
if (binary_is_logical(type) && !is_tensor(lhs->ret_type) &&
!is_tensor(rhs->ret_type)) {
auto result = ctx->push_back<AllocaStmt>(ret_type);
ctx->push_back<LocalStoreStmt>(result, lhs_stmt);
auto cond = ctx->push_back<LocalLoadStmt>(result);
Expand Down Expand Up @@ -540,7 +540,7 @@ void TernaryOpExpression::type_check(const CompileConfig *config) {
is_valid = false;
}

if (op1_type != PrimitiveType::i32) {
if (!is_integral(op1_type)) {
is_valid = false;
}
if (!op2_type->is<PrimitiveType>() || !op3_type->is<PrimitiveType>()) {
Expand Down
2 changes: 0 additions & 2 deletions taichi/transforms/constant_fold.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,6 @@ class ConstantFold : public BasicStmtVisitor {
#undef COMMA

case BinaryOpType::truediv:
case BinaryOpType::logical_or:
case BinaryOpType::logical_and:
TI_ERROR("{} should have been lowered.",
binary_op_type_name(stmt->op_type));
break;
Expand Down
13 changes: 0 additions & 13 deletions tests/python/test_matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -1080,19 +1080,6 @@ def verify(x):
_test_field_and_ndarray(field, ndarray, func, verify)


@test_utils.test()
def test_unsupported_logical_operations():
@ti.kernel
def test():
x = ti.Vector([1, 0])
y = ti.Vector([1, 1])

z = x and y

with pytest.raises(TaichiTypeError, match=r"unsupported operand type\(s\) for "):
test()


@test_utils.test()
def test_vector_transpose():
@ti.kernel
Expand Down

0 comments on commit d622f75

Please sign in to comment.