From 52b24f3e09c093610b1ecf69b5e33cbc66b7bd6d Mon Sep 17 00:00:00 2001 From: Lin Jiang Date: Tue, 20 Feb 2024 10:05:12 +0800 Subject: [PATCH] [bug] Fix abs on unsigned types (#8476) Issue: fixes #8467 Remove abs on unsigned types --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- taichi/transforms/simplify.cpp | 10 ++++++++++ tests/python/test_abs.py | 10 ++++++++++ 2 files changed, 20 insertions(+) diff --git a/taichi/transforms/simplify.cpp b/taichi/transforms/simplify.cpp index 8890394637b36..550ea8220303b 100644 --- a/taichi/transforms/simplify.cpp +++ b/taichi/transforms/simplify.cpp @@ -133,6 +133,16 @@ class BasicBlockSimplify : public IRVisitor { } } + void visit(UnaryOpStmt *stmt) override { + if (stmt->op_type == UnaryOpType::abs) { + auto operand_type = stmt->operand->ret_type; + if (is_integral(operand_type) && is_unsigned(operand_type)) { + // abs(u) -> u + stmt->replace_usages_with(stmt->operand); + modifier.erase(stmt); + } + } + } template static bool identical_vectors(const std::vector &a, const std::vector &b) { diff --git a/tests/python/test_abs.py b/tests/python/test_abs.py index 8ff4bc35aeddd..ab86b392507f6 100644 --- a/tests/python/test_abs.py +++ b/tests/python/test_abs.py @@ -78,3 +78,13 @@ def foo(x: ti.i64) -> ti.i64: for x in [-(2**40), 0, 2**40]: assert foo(x) == abs(x) + + +@test_utils.test() +def test_abs_u32(): + @ti.kernel + def foo(x: ti.u32) -> ti.u32: + return abs(x) + + for x in [0, 2**20]: + assert foo(x) == abs(x)