Skip to content

Commit

Permalink
[Opt] Make merging casts int(int(x)) less aggressive (#7944)
Browse files Browse the repository at this point in the history
Fixes #7915 

### Brief Summary

<!--
copilot:summary
-->
### <samp>🤖 Generated by Copilot at ed6302b</samp>

Fix a bug in `alg_simp` that removed casts between signed and unsigned
integers. Add a test case in `test_optimization` to check the cast
simplification.

### Walkthrough

<!--
copilot:walkthrough
-->
### <samp>🤖 Generated by Copilot at ed6302b</samp>

* Fix a bug in algebraic simplification that incorrectly removed some
casts between signed and unsigned integers
([link](https://github.com/taichi-dev/taichi/pull/7944/files?diff=unified&w=0#diff-77d8ca8e4dc6081988bd6dddb74bb9a5485af28ce3e0b43bc06d123256695513L63-R66))
* Add a test case to verify the correctness of the cast simplification
after the bug fix
([link](https://github.com/taichi-dev/taichi/pull/7944/files?diff=unified&w=0#diff-b8b031f0789413acece482512df4af5b8419a2a2dea3624b26114bbb9b57d334R146-R155))
  • Loading branch information
ailzhang authored May 8, 2023
1 parent 1ee025b commit 975941d
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 2 deletions.
6 changes: 4 additions & 2 deletions taichi/transforms/alg_simp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ class AlgSimp : public BasicStmtVisitor {
data_type_bits(second_cast) <= data_type_bits(first_cast);
}
if (is_integral(first_cast)) {
// int(int(a))
return data_type_bits(second_cast) <= data_type_bits(first_cast);
// int(int(a)), note it's not always equivalent when signedness differ,
// see #7915
return data_type_bits(second_cast) <= data_type_bits(first_cast) &&
is_signed(second_cast) == is_signed(first_cast);
}
// int(float(a))
if (data_type_bits(second_cast) <= data_type_bits(first_cast) * 2) {
Expand Down
10 changes: 10 additions & 0 deletions tests/python/test_optimization.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,13 @@ def func():
for i in range(3):
for j in range(4):
assert mat[i, j] == i + 1


@test_utils.test()
def test_casts_int_uint():
@ti.kernel
def my_cast(x: ti.f32) -> ti.u32:
y = ti.floor(x, ti.i32)
return ti.cast(y, ti.u32)

assert my_cast(-1) == 4294967295

0 comments on commit 975941d

Please sign in to comment.