From ed6302b45bded79e1952ad3367cffa9d3cafdf2f Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Sat, 6 May 2023 17:00:05 +0800 Subject: [PATCH] [Opt] Make merging casts int(int(x)) less aggressive --- taichi/transforms/alg_simp.cpp | 6 ++++-- tests/python/test_optimization.py | 10 ++++++++++ 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/taichi/transforms/alg_simp.cpp b/taichi/transforms/alg_simp.cpp index 792f865410a0d..44ff6e9408185 100644 --- a/taichi/transforms/alg_simp.cpp +++ b/taichi/transforms/alg_simp.cpp @@ -60,8 +60,10 @@ class AlgSimp : public BasicStmtVisitor { data_type_bits(second_cast) <= data_type_bits(first_cast); } if (is_integral(first_cast)) { - // int(int(a)) - return data_type_bits(second_cast) <= data_type_bits(first_cast); + // int(int(a)), note it's not always equivalent when signedness differ, + // see #7915 + return data_type_bits(second_cast) <= data_type_bits(first_cast) && + is_signed(second_cast) == is_signed(first_cast); } // int(float(a)) if (data_type_bits(second_cast) <= data_type_bits(first_cast) * 2) { diff --git a/tests/python/test_optimization.py b/tests/python/test_optimization.py index fa7147d7282c1..c966464647da4 100644 --- a/tests/python/test_optimization.py +++ b/tests/python/test_optimization.py @@ -143,3 +143,13 @@ def func(): for i in range(3): for j in range(4): assert mat[i, j] == i + 1 + + +@test_utils.test() +def test_casts_int_uint(): + @ti.kernel + def my_cast(x: ti.f32) -> ti.u32: + y = ti.floor(x, ti.i32) + return ti.cast(y, ti.u32) + + assert my_cast(-1) == 4294967295