diff --git a/python/taichi/lang/common_ops.py b/python/taichi/lang/common_ops.py index 745ab26b93220..4c926642ccc9d 100644 --- a/python/taichi/lang/common_ops.py +++ b/python/taichi/lang/common_ops.py @@ -242,7 +242,7 @@ def __ilshift__(self, other): def __irshift__(self, other): if in_python_scope(): return NotImplemented - self._assign(ops.bit_shr(self, other)) + self._assign(ops.bit_sar(self, other)) return self def __ipow__(self, other): diff --git a/taichi/codegen/spirv/spirv_codegen.cpp b/taichi/codegen/spirv/spirv_codegen.cpp index 08f3cce1459da..8dbfa71a70b00 100644 --- a/taichi/codegen/spirv/spirv_codegen.cpp +++ b/taichi/codegen/spirv/spirv_codegen.cpp @@ -840,7 +840,6 @@ class TaskCodegen : public IRVisitor { BINARY_OP_TO_SPIRV_BITWISE(bit_or, OpBitwiseOr) BINARY_OP_TO_SPIRV_BITWISE(bit_xor, OpBitwiseXor) BINARY_OP_TO_SPIRV_BITWISE(bit_shl, OpShiftLeftLogical) - BINARY_OP_TO_SPIRV_BITWISE(bit_shr, OpShiftRightLogical) // NOTE: `OpShiftRightArithmetic` will treat the first bit as sign bit even // it's the unsigned type else if (op_type == BinaryOpType::bit_sar) { diff --git a/taichi/ir/stmt_op_types.h b/taichi/ir/stmt_op_types.h index 30aa7a6425b92..761eb48d2afef 100644 --- a/taichi/ir/stmt_op_types.h +++ b/taichi/ir/stmt_op_types.h @@ -37,7 +37,7 @@ enum class BinaryOpType : int { inline bool binary_is_bitwise(BinaryOpType t) { return t == BinaryOpType ::bit_and || t == BinaryOpType ::bit_or || t == BinaryOpType ::bit_xor || t == BinaryOpType ::bit_shl || - t == BinaryOpType ::bit_sar; + t == BinaryOpType ::bit_shr || t == BinaryOpType ::bit_sar; } inline bool binary_is_logical(BinaryOpType t) { diff --git a/taichi/transforms/demote_operations.cpp b/taichi/transforms/demote_operations.cpp index d81a30b0281aa..5c3118d85277d 100644 --- a/taichi/transforms/demote_operations.cpp +++ b/taichi/transforms/demote_operations.cpp @@ -174,8 +174,8 @@ class DemoteOperations : public BasicStmtVisitor { is_signed(lhs->element_type())) { // @ti.func // def bit_shr(a, b): - // signed_a = ti.cast(a, ti.uXX) - // shifted = ti.bit_sar(signed_a, b) + // unsigned_a = ti.cast(a, ti.uXX) + // shifted = ti.bit_sar(unsigned_a, b) // ret = ti.cast(shifted, ti.iXX) // return ret auto unsigned_cast = Stmt::make(UnaryOpType::cast_bits, lhs); diff --git a/tests/python/test_bit_operations.py b/tests/python/test_bit_operations.py index 5c11194dd69ce..9cea78323aa09 100644 --- a/tests/python/test_bit_operations.py +++ b/tests/python/test_bit_operations.py @@ -13,8 +13,14 @@ def test_bit_shl(): def shl(a: ti.i32, b: ti.i32) -> ti.i32: return a << b + @ti.kernel + def shl_assign(a: ti.i32, b: ti.i32) -> ti.i32: + c = a + c <<= b + return c + for i in range(8): - assert shl(3, i) == 3 * 2**i + assert shl(3, i) == shl_assign(3, i) == 3 * 2**i @test_utils.test() @@ -23,14 +29,21 @@ def test_bit_sar(): def sar(a: ti.i32, b: ti.i32) -> ti.i32: return a >> b + @ti.kernel + def sar_assign(a: ti.i32, b: ti.i32) -> ti.i32: + c = a + c >>= b + return c + n = 8 test_num = 2**n neg_test_num = -test_num for i in range(n): - assert sar(test_num, i) == 2**(n - i) + assert sar(test_num, i) == sar_assign(test_num, i) == 2**(n - i) # for negative number for i in range(n): - assert sar(neg_test_num, i) == -2**(n - i) + assert sar(neg_test_num, i) == sar_assign(neg_test_num, + i) == -2**(n - i) @test_utils.test()