From aeb21a953f8b3f6018dd9aefc8f84ce19f03e3a1 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 8 Sep 2022 17:03:21 +0800 Subject: [PATCH 1/2] [Lang] Support abs(i64) --- taichi/codegen/cc/runtime/base.h | 4 ++++ taichi/codegen/cuda/codegen_cuda.cpp | 2 ++ taichi/codegen/llvm/codegen_llvm.cpp | 2 ++ taichi/runtime/llvm/runtime_module/runtime.cpp | 12 ++++++------ tests/python/test_abs.py | 10 ++++++++++ 5 files changed, 24 insertions(+), 6 deletions(-) diff --git a/taichi/codegen/cc/runtime/base.h b/taichi/codegen/cc/runtime/base.h index 1e9be3fdda950..721a204284900 100644 --- a/taichi/codegen/cc/runtime/base.h +++ b/taichi/codegen/cc/runtime/base.h @@ -65,6 +65,10 @@ static inline Ti_f64 Ti_rsqrt(Ti_f64 x) { return 1 / sqrt(x); } +static inline Ti_i64 Ti_llabs(Ti_i64 x) { + return x >= 0 ? x : -x; +} + ) "\n" STR( static inline Ti_i32 Ti_rand_i32(void) { diff --git a/taichi/codegen/cuda/codegen_cuda.cpp b/taichi/codegen/cuda/codegen_cuda.cpp index 3572e1319f383..4f49e898cd390 100644 --- a/taichi/codegen/cuda/codegen_cuda.cpp +++ b/taichi/codegen/cuda/codegen_cuda.cpp @@ -148,6 +148,8 @@ class TaskCodeGenCUDA : public TaskCodeGenLLVM { llvm_val[stmt] = create_call("__nv_fabs", input); } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) { llvm_val[stmt] = create_call("__nv_abs", input); + } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i64)) { + llvm_val[stmt] = create_call("__nv_llabs", input); } else { TI_NOT_IMPLEMENTED } diff --git a/taichi/codegen/llvm/codegen_llvm.cpp b/taichi/codegen/llvm/codegen_llvm.cpp index a6d1d84961a4d..92c7d65141cbe 100644 --- a/taichi/codegen/llvm/codegen_llvm.cpp +++ b/taichi/codegen/llvm/codegen_llvm.cpp @@ -209,6 +209,8 @@ void TaskCodeGenLLVM::emit_extra_unary(UnaryOpStmt *stmt) { llvm_val[stmt] = create_call(#x "_f64", input); \ } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i32)) { \ llvm_val[stmt] = create_call(#x "_i32", input); \ + } else if (input_taichi_type->is_primitive(PrimitiveTypeID::i64)) { \ + llvm_val[stmt] = create_call(#x "_i64", input); \ } else { \ TI_NOT_IMPLEMENTED \ } \ diff --git a/taichi/runtime/llvm/runtime_module/runtime.cpp b/taichi/runtime/llvm/runtime_module/runtime.cpp index 1dceaff0a8a42..229a106959d3e 100644 --- a/taichi/runtime/llvm/runtime_module/runtime.cpp +++ b/taichi/runtime/llvm/runtime_module/runtime.cpp @@ -221,12 +221,12 @@ DEFINE_UNARY_REAL_FUNC(sin) DEFINE_FAST_POW(i32) DEFINE_FAST_POW(i64) -int abs_i32(int a) { - if (a > 0) { - return a; - } else { - return -a; - } +i32 abs_i32(i32 a) { + return a >= 0 ? a : -a; +} + +i64 abs_i64(i64 a) { + return a >= 0 ? a : -a; } i32 floordiv_i32(i32 a, i32 b) { diff --git a/tests/python/test_abs.py b/tests/python/test_abs.py index 510cd37f8ec03..1f5c376db1d97 100644 --- a/tests/python/test_abs.py +++ b/tests/python/test_abs.py @@ -68,3 +68,13 @@ def sgn(x): for i in range(N): assert x[i] == abs(y[i]) assert x.dual[i] == sgn(y[i]) + + +@test_utils.test(require=ti.extension.data64) +def test_abs_i64(): + @ti.kernel + def foo(x: ti.i64) -> ti.i64: + return abs(x) + + for x in [-2**40, 0, 2**40]: + assert foo(x) == abs(x) From ae5c434b92e7a4674b97c4858de0a86adfe905d8 Mon Sep 17 00:00:00 2001 From: Yi Xu Date: Thu, 8 Sep 2022 17:10:36 +0800 Subject: [PATCH 2/2] Remove redundant line --- taichi/codegen/cc/runtime/base.h | 1 - 1 file changed, 1 deletion(-) diff --git a/taichi/codegen/cc/runtime/base.h b/taichi/codegen/cc/runtime/base.h index 721a204284900..2210729eb8646 100644 --- a/taichi/codegen/cc/runtime/base.h +++ b/taichi/codegen/cc/runtime/base.h @@ -64,7 +64,6 @@ static inline Ti_f32 Ti_rsqrtf(Ti_f32 x) { static inline Ti_f64 Ti_rsqrt(Ti_f64 x) { return 1 / sqrt(x); } - static inline Ti_i64 Ti_llabs(Ti_i64 x) { return x >= 0 ? x : -x; }