From d622f7584ba4d4d9e0191b204e130b0e6ed2ecbe Mon Sep 17 00:00:00 2001 From: listerily Date: Tue, 16 May 2023 12:06:26 +0800 Subject: [PATCH] [lang] [ir] Add logical and, logical or in ir ghstack-source-id: d50fedb188af7c6f9548a542aba6200e70d7a8f7 Pull Request resolved: https://github.com/taichi-dev/taichi/pull/8008 --- taichi/codegen/spirv/spirv_codegen.cpp | 2 ++ taichi/codegen/spirv/spirv_ir_builder.cpp | 10 ++++++++++ taichi/codegen/spirv/spirv_ir_builder.h | 2 ++ taichi/ir/frontend_ir.cpp | 10 +++++----- taichi/transforms/constant_fold.cpp | 2 -- tests/python/test_matrix.py | 13 ------------- 6 files changed, 19 insertions(+), 20 deletions(-) diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index f03e0fe0feddbb..19f03bfb1846f6 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -1147,6 +1147,8 @@ class TaskCodegen : public IRVisitor { BINARY_OP_TO_SPIRV_LOGICAL(cmp_ge, ge) BINARY_OP_TO_SPIRV_LOGICAL(cmp_eq, eq) BINARY_OP_TO_SPIRV_LOGICAL(cmp_ne, ne) + BINARY_OP_TO_SPIRV_LOGICAL(logical_and, logical_and) + BINARY_OP_TO_SPIRV_LOGICAL(logical_or, logical_or) #undef BINARY_OP_TO_SPIRV_LOGICAL #define FLOAT_BINARY_OP_TO_SPIRV_FLOAT_FUNC(op, instruction, instruction_id, \ diff --git a/taichi/codegen/spirv/spirv_ir_builder.cpp b/taichi/codegen/spirv/spirv_ir_builder.cpp index 65f5cf159ffc12..9829cf0e8bb15f 100644 --- a/taichi/codegen/spirv/spirv_ir_builder.cpp +++ b/taichi/codegen/spirv/spirv_ir_builder.cpp @@ -1107,6 +1107,16 @@ DEFINE_BUILDER_CMP_OP(ge, GreaterThanEqual); DEFINE_BUILDER_CMP_UOP(eq, Equal); DEFINE_BUILDER_CMP_UOP(ne, NotEqual); +#define DEFINE_BUILDER_LOGICAL_OP(_OpName, _Op) \ + Value IRBuilder::_OpName(Value a, Value b) { \ + TI_ASSERT(a.stype.id == b.stype.id); \ + TI_ASSERT(is_integral(a.stype.dt)); \ + return make_value(spv::OpLogical##_Op, t_bool_, a, b); \ + } + +DEFINE_BUILDER_LOGICAL_OP(logical_and, And); +DEFINE_BUILDER_LOGICAL_OP(logical_or, Or); + Value IRBuilder::bit_field_extract(Value base, Value offset, Value count) { TI_ASSERT(is_integral(base.stype.dt)); TI_ASSERT(is_integral(offset.stype.dt)); diff --git a/taichi/codegen/spirv/spirv_ir_builder.h b/taichi/codegen/spirv/spirv_ir_builder.h index 0fec88bc38d28c..8b1d50f0d139fc 100644 --- a/taichi/codegen/spirv/spirv_ir_builder.h +++ b/taichi/codegen/spirv/spirv_ir_builder.h @@ -472,6 +472,8 @@ class IRBuilder { Value le(Value a, Value b); Value gt(Value a, Value b); Value ge(Value a, Value b); + Value logical_and(Value a, Value b); + Value logical_or(Value a, Value b); Value bit_field_extract(Value base, Value offset, Value count); Value select(Value cond, Value a, Value b); Value popcnt(Value x); diff --git a/taichi/ir/frontend_ir.cpp b/taichi/ir/frontend_ir.cpp index 641c9077a25763..8427531c067c0c 100644 --- a/taichi/ir/frontend_ir.cpp +++ b/taichi/ir/frontend_ir.cpp @@ -359,9 +359,8 @@ void BinaryOpExpression::type_check(const CompileConfig *config) { if (binary_is_bitwise(type) && (!is_integral(lhs_type.get_element_type()) || !is_integral(rhs_type.get_element_type()))) error(); - if (binary_is_logical(type) && - (is_tensor_op || lhs_type != PrimitiveType::i32 || - rhs_type != PrimitiveType::i32)) + if (binary_is_logical(type) && !(is_integral(lhs_type.get_element_type()) && + is_integral(rhs_type.get_element_type()))) error(); if (is_comparison(type) || binary_is_logical(type)) { ret_type = make_dt(PrimitiveType::i32); @@ -401,7 +400,8 @@ void BinaryOpExpression::flatten(FlattenContext *ctx) { // return; auto lhs_stmt = flatten_rvalue(lhs, ctx); - if (binary_is_logical(type)) { + if (binary_is_logical(type) && !is_tensor(lhs->ret_type) && + !is_tensor(rhs->ret_type)) { auto result = ctx->push_back(ret_type); ctx->push_back(result, lhs_stmt); auto cond = ctx->push_back(result); @@ -540,7 +540,7 @@ void TernaryOpExpression::type_check(const CompileConfig *config) { is_valid = false; } - if (op1_type != PrimitiveType::i32) { + if (!is_integral(op1_type)) { is_valid = false; } if (!op2_type->is() || !op3_type->is()) { diff --git a/taichi/transforms/constant_fold.cpp b/taichi/transforms/constant_fold.cpp index 051e3afb30ff76..83e48ab842a554 100644 --- a/taichi/transforms/constant_fold.cpp +++ b/taichi/transforms/constant_fold.cpp @@ -136,8 +136,6 @@ class ConstantFold : public BasicStmtVisitor { #undef COMMA case BinaryOpType::truediv: - case BinaryOpType::logical_or: - case BinaryOpType::logical_and: TI_ERROR("{} should have been lowered.", binary_op_type_name(stmt->op_type)); break; diff --git a/tests/python/test_matrix.py b/tests/python/test_matrix.py index ea666a12a8f268..b7cc610b8ba34a 100644 --- a/tests/python/test_matrix.py +++ b/tests/python/test_matrix.py @@ -1080,19 +1080,6 @@ def verify(x): _test_field_and_ndarray(field, ndarray, func, verify) -@test_utils.test() -def test_unsupported_logical_operations(): - @ti.kernel - def test(): - x = ti.Vector([1, 0]) - y = ti.Vector([1, 1]) - - z = x and y - - with pytest.raises(TaichiTypeError, match=r"unsupported operand type\(s\) for "): - test() - - @test_utils.test() def test_vector_transpose(): @ti.kernel