diff --git a/python/taichi/tools/__init__.py b/python/taichi/tools/__init__.py
index 0c04f6c81c339..ab64ab5cfb86a 100644
--- a/python/taichi/tools/__init__.py
+++ b/python/taichi/tools/__init__.py
@@ -1,26 +1,14 @@
 from .image import imdisplay, imread, imresize, imshow, imwrite
 from .np2ply import PLYWriter
+from .sort import parallel_sort
 from .util import *
 # Don't import taichi_logo here which will cause circular import.
 # If you need it, just import from taichi.tools.patterns
 from .video import VideoManager
 
 __all__ = [
-    'PLYWriter',
-    'VideoManager',
-    'imdisplay',
-    'imread',
-    'imresize',
-    'imshow',
-    'imwrite',
-    'deprecated',
-    'warning',
-    'dump_dot',
-    'dot_to_pdf',
-    'obsolete',
-    'get_kernel_stats',
-    'get_traceback',
-    'set_gdb_trigger',
-    'print_profile_info',
-    'clear_profile_info',
+    'PLYWriter', 'VideoManager', 'imdisplay', 'imread', 'imresize', 'imshow',
+    'imwrite', 'deprecated', 'warning', 'dump_dot', 'dot_to_pdf', 'obsolete',
+    'get_kernel_stats', 'get_traceback', 'set_gdb_trigger',
+    'print_profile_info', 'clear_profile_info', 'parallel_sort'
 ]
diff --git a/python/taichi/tools/sort.py b/python/taichi/tools/sort.py
new file mode 100644
--- /dev/null
+++ b/python/taichi/tools/sort.py
@@ -0,0 +1,59 @@
+import taichi as ti
+
+
+def parallel_sort(keys, values=None):
+    """Sort a 1D Taichi field in place, in ascending order.
+
+    Implements Batcher's odd-even merge sort: the host loops below
+    launch one compare-exchange kernel per (p, k) stage.
+
+    References:
+        https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting
+        https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort
+
+    Args:
+        keys: Field to sort. Only ``keys.shape[0]`` elements are sorted.
+        values: Optional field indexed like ``keys``. Whenever two keys
+            are swapped, the entries of ``values`` at the same two
+            indices are swapped as well.
+    """
+    N = keys.shape[0]
+
+    @ti.kernel
+    def sort_stage(keys: ti.template(), use_values: int, values: ti.template(),
+                   N: int, p: int, k: int, invocations: int):
+        for inv in range(invocations):
+            j = k % p + inv * 2 * k
+            for i in range(0, min(k, N - j - k)):
+                a = i + j
+                b = a + k
+                # Compare-exchange only within the same block of 2 * p
+                # elements. Integer floor division keeps the block index
+                # exact; true division followed by int() goes through
+                # floating point and can misround large indices.
+                if a // (p * 2) == b // (p * 2):
+                    key_a = keys[a]
+                    key_b = keys[b]
+                    if key_a > key_b:
+                        keys[a] = key_b
+                        keys[b] = key_a
+                        if use_values != 0:
+                            temp = values[a]
+                            values[a] = values[b]
+                            values[b] = temp
+
+    # The kernel always takes a field for `values`; when the caller gave
+    # none, `keys` fills the slot and use_values == 0 leaves it unused.
+    swap_values = 0 if values is None else 1
+    payload = keys if values is None else values
+
+    p = 1
+    while p < N:
+        k = p
+        while k >= 1:
+            # Integer arithmetic only: no float round-trip for large N.
+            invocations = (N - k - k % p) // (2 * k) + 1
+            sort_stage(keys, swap_values, payload, N, p, k, invocations)
+            ti.sync()
+            k //= 2
+        p *= 2
diff --git a/tests/python/test_sort.py b/tests/python/test_sort.py
new file mode 100644
index 0000000000000..615a60886afbb
--- /dev/null
+++ b/tests/python/test_sort.py
@@ -0,0 +1,32 @@
+import taichi as ti
+
+
+@ti.test(exclude=[ti.cc])
+def test_sort():
+    def test_sort_for_dtype(dtype, N):
+        keys = ti.field(dtype, N)
+        values = ti.field(dtype, N)
+
+        @ti.kernel
+        def fill():
+            for i in keys:
+                keys[i] = ti.random() * N
+                values[i] = keys[i]
+
+        fill()
+        ti.parallel_sort(keys, values)
+
+        keys_host = keys.to_numpy()
+        values_host = values.to_numpy()
+
+        for i in range(N):
+            if i < N - 1:
+                assert keys_host[i] <= keys_host[i + 1]
+            assert keys_host[i] == values_host[i]
+
+    test_sort_for_dtype(ti.i32, 1)
+    test_sort_for_dtype(ti.i32, 256)
+    test_sort_for_dtype(ti.i32, 100001)
+    test_sort_for_dtype(ti.f32, 1)
+    test_sort_for_dtype(ti.f32, 256)
+    test_sort_for_dtype(ti.f32, 100001)