From 128335471b05c235368fba5a89f2d9f5085d1a67 Mon Sep 17 00:00:00 2001 From: Dunfan Lu Date: Tue, 14 Dec 2021 00:30:45 +0000 Subject: [PATCH 1/5] Add parallel sort utility --- python/taichi/tools/__init__.py | 22 +++++----------------- python/taichi/tools/sort.py | 29 +++++++++++++++++++++++++++++ tests/python/test_sort.py | 27 +++++++++++++++++++++++++++ 3 files changed, 61 insertions(+), 17 deletions(-) create mode 100644 python/taichi/tools/sort.py create mode 100644 tests/python/test_sort.py diff --git a/python/taichi/tools/__init__.py b/python/taichi/tools/__init__.py index 0c04f6c81c339..ab64ab5cfb86a 100644 --- a/python/taichi/tools/__init__.py +++ b/python/taichi/tools/__init__.py @@ -1,26 +1,14 @@ from .image import imdisplay, imread, imresize, imshow, imwrite from .np2ply import PLYWriter +from .sort import parallel_sort from .util import * # Don't import taichi_logo here which will cause circular import. # If you need it, just import from taichi.tools.patterns from .video import VideoManager __all__ = [ - 'PLYWriter', - 'VideoManager', - 'imdisplay', - 'imread', - 'imresize', - 'imshow', - 'imwrite', - 'deprecated', - 'warning', - 'dump_dot', - 'dot_to_pdf', - 'obsolete', - 'get_kernel_stats', - 'get_traceback', - 'set_gdb_trigger', - 'print_profile_info', - 'clear_profile_info', + 'PLYWriter', 'VideoManager', 'imdisplay', 'imread', 'imresize', 'imshow', + 'imwrite', 'deprecated', 'warning', 'dump_dot', 'dot_to_pdf', 'obsolete', + 'get_kernel_stats', 'get_traceback', 'set_gdb_trigger', + 'print_profile_info', 'clear_profile_info', 'parallel_sort' ] diff --git a/python/taichi/tools/sort.py b/python/taichi/tools/sort.py new file mode 100644 index 0000000000000..6f24788d197b7 --- /dev/null +++ b/python/taichi/tools/sort.py @@ -0,0 +1,29 @@ +import taichi as ti + + +def parallel_sort(x): + N = x.shape[0] + + @ti.kernel + def sort_stage(x: ti.template(), N: int, p: int, k: int, invocations: int): + for inv in range(invocations): + j = k % p + inv * 2 * k + for i in range(0, min(k, N - j - k)): + a = i + j + b = i + j + k + if int(a / (p * 2)) == int(b / (p * 2)): + val_a = x[a] + val_b = x[b] + if val_a > val_b: + x[a] = val_b + x[b] = val_a + + p = 1 + while p < N: + k = p + while k >= 1: + invocations = int((N - k - k % p) / (2 * k)) + 1 + sort_stage(x, N, p, k, invocations) + ti.sync() + k = int(k / 2) + p = int(p * 2) diff --git a/tests/python/test_sort.py b/tests/python/test_sort.py new file mode 100644 index 0000000000000..1d0e64b31e0a6 --- /dev/null +++ b/tests/python/test_sort.py @@ -0,0 +1,27 @@ +import taichi as ti + + +@ti.test() +def test_sort(): + def test_sort_for_dtype(dtype, N): + x = ti.field(dtype, N) + + @ti.kernel + def fill(): + for i in x: + x[i] = ti.random() * N + + fill() + ti.parallel_sort(x) + + x_host = x.to_numpy() + + for i in range(N - 1): + assert x_host[i] <= x_host[i + 1] + + test_sort_for_dtype(ti.i32, 1) + test_sort_for_dtype(ti.i32, 256) + test_sort_for_dtype(ti.i32, 100001) + test_sort_for_dtype(ti.f32, 1) + test_sort_for_dtype(ti.f32, 256) + test_sort_for_dtype(ti.f32, 100001) From 2ec19591a50b04a201c578292713e25be1eff372 Mon Sep 17 00:00:00 2001 From: Dunfan Lu Date: Tue, 14 Dec 2021 00:52:51 +0000 Subject: [PATCH 2/5] add comments --- python/taichi/tools/sort.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/taichi/tools/sort.py b/python/taichi/tools/sort.py index 6f24788d197b7..9ce39c156968c 100644 --- a/python/taichi/tools/sort.py +++ b/python/taichi/tools/sort.py @@ -1,6 +1,10 @@ import taichi as ti +# Odd-even merge sort +# References: +# https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting +# https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort def parallel_sort(x): N = x.shape[0] From cc8534ec54baae49c25ebf19c17cd00d2bab21a3 Mon Sep 17 00:00:00 2001 From: Dunfan Lu Date: Tue, 14 Dec 2021 01:19:24 +0000 Subject: [PATCH 3/5] exclude cc backend --- tests/python/test_sort.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/test_sort.py b/tests/python/test_sort.py index 1d0e64b31e0a6..9d9039dcc83e2 100644 --- a/tests/python/test_sort.py +++ b/tests/python/test_sort.py @@ -1,7 +1,7 @@ import taichi as ti -@ti.test() +@ti.test(exclude=[ti.cc]) def test_sort(): def test_sort_for_dtype(dtype, N): x = ti.field(dtype, N) From c5f5df0ea69055f5efcedc872cc30dbbfb3fc259 Mon Sep 17 00:00:00 2001 From: Dunfan Lu Date: Tue, 14 Dec 2021 17:36:38 +0000 Subject: [PATCH 4/5] allow sorting values by keys --- python/taichi/tools/sort.py | 31 +++++++++++++++++++++---------- tests/python/test_sort.py | 19 ++++++++++++------- 2 files changed, 33 insertions(+), 17 deletions(-) diff --git a/python/taichi/tools/sort.py b/python/taichi/tools/sort.py index 9ce39c156968c..3999ba05e5a9e 100644 --- a/python/taichi/tools/sort.py +++ b/python/taichi/tools/sort.py @@ -1,3 +1,5 @@ +from taichi.core.util import ti_core as _ti_core + import taichi as ti @@ -5,29 +7,38 @@ # References: # https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting # https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort -def parallel_sort(x): - N = x.shape[0] +def parallel_sort(keys, values=None): + N = keys.shape[0] @ti.kernel - def sort_stage(x: ti.template(), N: int, p: int, k: int, invocations: int): + def sort_stage(keys: ti.template(), use_values: int, values: ti.template(), + N: int, p: int, k: int, invocations: int): for inv in range(invocations): j = k % p + inv * 2 * k for i in range(0, min(k, N - j - k)): a = i + j b = i + j + k if int(a / (p * 2)) == int(b / (p * 2)): - val_a = x[a] - val_b = x[b] - if val_a > val_b: - x[a] = val_b - x[b] = val_a + key_a = keys[a] + key_b = keys[b] + if key_a > key_b: + keys[a] = key_b + keys[b] = key_a + if use_values != 0: + temp = values[a] + values[a] = values[b] + values[b] = temp p = 1 while p < N: k = p while k >= 1: invocations = int((N - k - k % p) / (2 * k)) + 1 - sort_stage(x, N, p, k, invocations) - ti.sync() + if values is None: + sort_stage(keys, 0, keys, N, p, k, invocations) + else: + sort_stage(keys, 1, values, N, p, k, invocations) + if _ti_core.current_compile_config() == ti.vulkan: + ti.sync() k = int(k / 2) p = int(p * 2) diff --git a/tests/python/test_sort.py b/tests/python/test_sort.py index 9d9039dcc83e2..615a60886afbb 100644 --- a/tests/python/test_sort.py +++ b/tests/python/test_sort.py @@ -4,20 +4,25 @@ @ti.test(exclude=[ti.cc]) def test_sort(): def test_sort_for_dtype(dtype, N): - x = ti.field(dtype, N) + keys = ti.field(dtype, N) + values = ti.field(dtype, N) @ti.kernel def fill(): - for i in x: - x[i] = ti.random() * N + for i in keys: + keys[i] = ti.random() * N + values[i] = keys[i] fill() - ti.parallel_sort(x) + ti.parallel_sort(keys, values) - x_host = x.to_numpy() + keys_host = keys.to_numpy() + values_host = values.to_numpy() - for i in range(N - 1): - assert x_host[i] <= x_host[i + 1] + for i in range(N): + if i < N - 1: + assert keys_host[i] <= keys_host[i + 1] + assert keys_host[i] == values_host[i] test_sort_for_dtype(ti.i32, 1) test_sort_for_dtype(ti.i32, 256) From 39d1344330e5faf671ec7a2a0c1b73b926940c1c Mon Sep 17 00:00:00 2001 From: Dunfan Lu Date: Tue, 14 Dec 2021 19:09:22 +0000 Subject: [PATCH 5/5] fix --- python/taichi/tools/sort.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/python/taichi/tools/sort.py b/python/taichi/tools/sort.py index 3999ba05e5a9e..283f9e0b1e5c4 100644 --- a/python/taichi/tools/sort.py +++ b/python/taichi/tools/sort.py @@ -1,5 +1,3 @@ -from taichi.core.util import ti_core as _ti_core - import taichi as ti @@ -38,7 +36,6 @@ def sort_stage(keys: ti.template(), use_values: int, values: ti.template(), sort_stage(keys, 0, keys, N, p, k, invocations) else: sort_stage(keys, 1, values, N, p, k, invocations) - if _ti_core.current_compile_config() == ti.vulkan: - ti.sync() + ti.sync() k = int(k / 2) p = int(p * 2)