Skip to content

Commit

Permalink
[bug] Fix abs on unsigned types (#8476)
Browse files Browse the repository at this point in the history
Issue: fixes #8467 

Remove abs on unsigned types

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
lin-hitonami and pre-commit-ci[bot] authored Feb 20, 2024
1 parent 4139722 commit 52b24f3
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 0 deletions.
10 changes: 10 additions & 0 deletions taichi/transforms/simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,16 @@ class BasicBlockSimplify : public IRVisitor {
}
}

void visit(UnaryOpStmt *stmt) override {
if (stmt->op_type == UnaryOpType::abs) {
auto operand_type = stmt->operand->ret_type;
if (is_integral(operand_type) && is_unsigned(operand_type)) {
// abs(u) -> u
stmt->replace_usages_with(stmt->operand);
modifier.erase(stmt);
}
}
}
template <typename T>
static bool identical_vectors(const std::vector<T> &a,
const std::vector<T> &b) {
Expand Down
10 changes: 10 additions & 0 deletions tests/python/test_abs.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,13 @@ def foo(x: ti.i64) -> ti.i64:

for x in [-(2**40), 0, 2**40]:
assert foo(x) == abs(x)


@test_utils.test()
def test_abs_u32():
@ti.kernel
def foo(x: ti.u32) -> ti.u32:
return abs(x)

for x in [0, 2**20]:
assert foo(x) == abs(x)

0 comments on commit 52b24f3

Please sign in to comment.