From a709d325f9b4329bc908138a3a4c006210096c00 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Fri, 23 Sep 2022 18:08:53 +0800 Subject: [PATCH 1/2] [Bug] [lang] Fix augmented assign for sar --- python/taichi/lang/common_ops.py | 2 +- taichi/codegen/spirv/spirv_codegen.cpp | 1 - taichi/ir/stmt_op_types.h | 2 +- taichi/transforms/demote_operations.cpp | 4 ++-- tests/python/test_bit_operations.py | 18 +++++++++++++++--- 5 files changed, 19 insertions(+), 8 deletions(-) diff --git a/python/taichi/lang/common_ops.py b/python/taichi/lang/common_ops.py index 745ab26b93220..4c926642ccc9d 100644 --- a/python/taichi/lang/common_ops.py +++ b/python/taichi/lang/common_ops.py @@ -242,7 +242,7 @@ def __ilshift__(self, other): def __irshift__(self, other): if in_python_scope(): return NotImplemented - self._assign(ops.bit_shr(self, other)) + self._assign(ops.bit_sar(self, other)) return self def __ipow__(self, other): diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 2f549e86e943a..8b0a374f78197 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -841,7 +841,6 @@ class TaskCodegen : public IRVisitor { BINARY_OP_TO_SPIRV_BITWISE(bit_or, OpBitwiseOr) BINARY_OP_TO_SPIRV_BITWISE(bit_xor, OpBitwiseXor) BINARY_OP_TO_SPIRV_BITWISE(bit_shl, OpShiftLeftLogical) - BINARY_OP_TO_SPIRV_BITWISE(bit_shr, OpShiftRightLogical) // NOTE: `OpShiftRightArithmetic` will treat the first bit as sign bit even // it's the unsigned type else if (op_type == BinaryOpType::bit_sar) { diff --git a/taichi/ir/stmt_op_types.h b/taichi/ir/stmt_op_types.h index f2c9ee681b6c2..64198b2941266 100644 --- a/taichi/ir/stmt_op_types.h +++ b/taichi/ir/stmt_op_types.h @@ -38,7 +38,7 @@ enum class BinaryOpType : int { inline bool binary_is_bitwise(BinaryOpType t) { return t == BinaryOpType ::bit_and || t == BinaryOpType ::bit_or || t == BinaryOpType ::bit_xor || t == BinaryOpType ::bit_shl || - t == BinaryOpType ::bit_sar; + t == BinaryOpType ::bit_shr || t == BinaryOpType ::bit_sar; } inline bool binary_is_logical(BinaryOpType t) { diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp index cc3a65ab2e345..48528ede3d5a3 100644 --- a/taichi/transforms/demote_operations.cpp +++ b/taichi/transforms/demote_operations.cpp @@ -118,8 +118,8 @@ class DemoteOperations : public BasicStmtVisitor { is_signed(lhs->element_type())) { // @ti.func // def bit_shr(a, b): - // signed_a = ti.cast(a, ti.uXX) - // shifted = ti.bit_sar(signed_a, b) + // unsigned_a = ti.cast(a, ti.uXX) + // shifted = ti.bit_sar(unsigned_a, b) // ret = ti.cast(shifted, ti.iXX) // return ret auto unsigned_cast = Stmt::make(UnaryOpType::cast_bits, lhs); diff --git a/tests/python/test_bit_operations.py b/tests/python/test_bit_operations.py index 5c11194dd69ce..c3ac248a1557e 100644 --- a/tests/python/test_bit_operations.py +++ b/tests/python/test_bit_operations.py @@ -13,8 +13,14 @@ def test_bit_shl(): def shl(a: ti.i32, b: ti.i32) -> ti.i32: return a << b + @ti.kernel + def shl_assign(a: ti.i32, b: ti.i32) -> ti.i32: + c = a + c <<= b + return c + for i in range(8): - assert shl(3, i) == 3 * 2**i + assert shl(3, i) == shl_assign(3, i) == 3 * 2**i @test_utils.test() @@ -23,14 +29,20 @@ def test_bit_sar(): def sar(a: ti.i32, b: ti.i32) -> ti.i32: return a >> b + @ti.kernel + def sar_assign(a: ti.i32, b: ti.i32) -> ti.i32: + c = a + c >>= b + return c + n = 8 test_num = 2**n neg_test_num = -test_num for i in range(n): - assert sar(test_num, i) == 2**(n - i) + assert sar(test_num, i) == sar_assign(test_num, i) == 2**(n - i) # for negative number for i in range(n): - assert sar(neg_test_num, i) == -2**(n - i) + assert sar(neg_test_num, i) == sar_assign(neg_test_num, i) == -2**(n - i) @test_utils.test() From abd58e15ba5c7603fe68871ab9904a8422e7042f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 23 Sep 2022 10:12:52 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/python/test_bit_operations.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/python/test_bit_operations.py b/tests/python/test_bit_operations.py index c3ac248a1557e..9cea78323aa09 100644 --- a/tests/python/test_bit_operations.py +++ b/tests/python/test_bit_operations.py @@ -42,7 +42,8 @@ def sar_assign(a: ti.i32, b: ti.i32) -> ti.i32: assert sar(test_num, i) == sar_assign(test_num, i) == 2**(n - i) # for negative number for i in range(n): - assert sar(neg_test_num, i) == sar_assign(neg_test_num, i) == -2**(n - i) + assert sar(neg_test_num, i) == sar_assign(neg_test_num, + i) == -2**(n - i) @test_utils.test()