Skip to content

Commit

Permalink
[Bug] [lang] Fix augmented assign for sar
Browse files Browse the repository at this point in the history
  • Loading branch information
strongoier committed Sep 23, 2022
1 parent c024865 commit a709d32
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 8 deletions.
2 changes: 1 addition & 1 deletion python/taichi/lang/common_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ def __ilshift__(self, other):
def __irshift__(self, other):
if in_python_scope():
return NotImplemented
self._assign(ops.bit_shr(self, other))
self._assign(ops.bit_sar(self, other))
return self

def __ipow__(self, other):
Expand Down
1 change: 0 additions & 1 deletion taichi/codegen/spirv/spirv_codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -841,7 +841,6 @@ class TaskCodegen : public IRVisitor {
BINARY_OP_TO_SPIRV_BITWISE(bit_or, OpBitwiseOr)
BINARY_OP_TO_SPIRV_BITWISE(bit_xor, OpBitwiseXor)
BINARY_OP_TO_SPIRV_BITWISE(bit_shl, OpShiftLeftLogical)
BINARY_OP_TO_SPIRV_BITWISE(bit_shr, OpShiftRightLogical)
// NOTE: `OpShiftRightArithmetic` will treat the first bit as sign bit even
// it's the unsigned type
else if (op_type == BinaryOpType::bit_sar) {
Expand Down
2 changes: 1 addition & 1 deletion taichi/ir/stmt_op_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ enum class BinaryOpType : int {
inline bool binary_is_bitwise(BinaryOpType t) {
return t == BinaryOpType ::bit_and || t == BinaryOpType ::bit_or ||
t == BinaryOpType ::bit_xor || t == BinaryOpType ::bit_shl ||
t == BinaryOpType ::bit_sar;
t == BinaryOpType ::bit_shr || t == BinaryOpType ::bit_sar;
}

inline bool binary_is_logical(BinaryOpType t) {
Expand Down
4 changes: 2 additions & 2 deletions taichi/transforms/demote_operations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ class DemoteOperations : public BasicStmtVisitor {
is_signed(lhs->element_type())) {
// @ti.func
// def bit_shr(a, b):
// signed_a = ti.cast(a, ti.uXX)
// shifted = ti.bit_sar(signed_a, b)
// unsigned_a = ti.cast(a, ti.uXX)
// shifted = ti.bit_sar(unsigned_a, b)
// ret = ti.cast(shifted, ti.iXX)
// return ret
auto unsigned_cast = Stmt::make<UnaryOpStmt>(UnaryOpType::cast_bits, lhs);
Expand Down
18 changes: 15 additions & 3 deletions tests/python/test_bit_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,14 @@ def test_bit_shl():
def shl(a: ti.i32, b: ti.i32) -> ti.i32:
return a << b

@ti.kernel
def shl_assign(a: ti.i32, b: ti.i32) -> ti.i32:
c = a
c <<= b
return c

for i in range(8):
assert shl(3, i) == 3 * 2**i
assert shl(3, i) == shl_assign(3, i) == 3 * 2**i


@test_utils.test()
Expand All @@ -23,14 +29,20 @@ def test_bit_sar():
def sar(a: ti.i32, b: ti.i32) -> ti.i32:
return a >> b

@ti.kernel
def sar_assign(a: ti.i32, b: ti.i32) -> ti.i32:
c = a
c >>= b
return c

n = 8
test_num = 2**n
neg_test_num = -test_num
for i in range(n):
assert sar(test_num, i) == 2**(n - i)
assert sar(test_num, i) == sar_assign(test_num, i) == 2**(n - i)
# for negative number
for i in range(n):
assert sar(neg_test_num, i) == -2**(n - i)
assert sar(neg_test_num, i) == sar_assign(neg_test_num, i) == -2**(n - i)


@test_utils.test()
Expand Down

0 comments on commit a709d32

Please sign in to comment.