From 22a099dbdb76a2fdd341f7c69c1e34bb77c1690e Mon Sep 17 00:00:00 2001
From: Wimaxs
Date: Tue, 29 Mar 2022 22:52:18 +0800
Subject: [PATCH] [SIMT] Add ballot_sync warp intrinsics (#4641)

---
 python/taichi/lang/simt/warp.py |  8 +++++---
 taichi/runtime/llvm/runtime.cpp |  8 ++++++++
 tests/python/test_simt.py       | 22 ++++++++++++++++++++--
 3 files changed, 33 insertions(+), 5 deletions(-)

diff --git a/python/taichi/lang/simt/warp.py b/python/taichi/lang/simt/warp.py
index 91e11894e8037..0c0d73a57b56c 100644
--- a/python/taichi/lang/simt/warp.py
+++ b/python/taichi/lang/simt/warp.py
@@ -17,9 +17,11 @@ def unique():
     pass
 
 
-def ballot():
-    # TODO
-    pass
+def ballot(predicate):
+    return expr.Expr(
+        _ti_core.insert_internal_func_call("cuda_ballot_i32",
+                                           expr.make_expr_group(predicate),
+                                           False))
 
 
 def shfl_i32(mask, val, offset):
diff --git a/taichi/runtime/llvm/runtime.cpp b/taichi/runtime/llvm/runtime.cpp
index e90ad1a220521..25c329361a481 100644
--- a/taichi/runtime/llvm/runtime.cpp
+++ b/taichi/runtime/llvm/runtime.cpp
@@ -1040,6 +1040,14 @@ int32 cuda_ballot_sync(int32 mask, bool bit) {
   return 0;
 }
 
+int32 cuda_ballot_i32(int32 predicate) {
+  return cuda_ballot_sync(UINT32_MAX, (bool)predicate);
+}
+
+int32 cuda_ballot_sync_i32(u32 mask, int32 predicate) {
+  return cuda_ballot_sync(mask, (bool)predicate);
+}
+
 i32 cuda_match_any_sync_i32(i32 mask, i32 value) {
   return 0;
 }
diff --git a/tests/python/test_simt.py b/tests/python/test_simt.py
index ecb3aaeab79eb..58b1e0032786a 100644
--- a/tests/python/test_simt.py
+++ b/tests/python/test_simt.py
@@ -1,3 +1,5 @@
+import random
+
 from pytest import approx
 
 import taichi as ti
@@ -24,8 +26,24 @@ def test_unique():
 
 @test_utils.test(arch=ti.cuda)
 def test_ballot():
-    # TODO
-    pass
+    a = ti.field(dtype=ti.u32, shape=32)
+    b = ti.field(dtype=ti.u32, shape=32)
+
+    @ti.kernel
+    def foo():
+        ti.loop_config(block_dim=32)
+        for i in range(32):
+            a[i] = ti.simt.warp.ballot(b[i])
+
+    key = 0
+    for i in range(32):
+        b[i] = i % 2
+        key += b[i] * pow(2, i)
+
+    foo()
+
+    for i in range(32):
+        assert a[i] == key
 
 
 @test_utils.test(arch=ti.cuda)