Skip to content

Commit

Permalink
[Misc] Add prefix sum executor to avoid multiple field allocations (#…
Browse files Browse the repository at this point in the history
…6132)

1. Add a `PrefixSumExecutor` helper to avoid a memory allocation on every
prefix sum function call.
2. Move `parallel_sort` and `prefix_sum` utilities from `_kernels.py` to
`algorithms.py`.

Maybe we can also integrate @yuanming-hu 's CPU prefix sum into
`PrefixSumExecutor` in the future.

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ailing  <[email protected]>
  • Loading branch information
3 people authored Sep 23, 2022
1 parent 8e9d978 commit 507d025
Show file tree
Hide file tree
Showing 7 changed files with 119 additions and 84 deletions.
2 changes: 1 addition & 1 deletion python/taichi/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# Provide a shortcut to types since they're commonly used.
from taichi.types.primitive_types import *

from taichi import ad, experimental, graph, linalg, math, tools
from taichi import ad, algorithms, experimental, graph, linalg, math, tools
from taichi.ui import GUI, hex_to_rgb, rgb_to_hex, ui

# Issue#2223: Do not reorder, or we're busted with partially initialized module
Expand Down
81 changes: 4 additions & 77 deletions python/taichi/_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,10 @@
from taichi.lang._ndrange import ndrange
from taichi.lang.expr import Expr
from taichi.lang.field import ScalarField
from taichi.lang.impl import current_cfg, field, grouped, static, static_assert
from taichi.lang.impl import grouped, static, static_assert
from taichi.lang.kernel_impl import func, kernel
from taichi.lang.misc import cuda, loop_config, vulkan
from taichi.lang.runtime_ops import sync
from taichi.lang.simt import block, subgroup, warp
from taichi.lang.misc import loop_config
from taichi.lang.simt import block, warp
from taichi.lang.snode import deactivate
from taichi.types import ndarray_type, texture_type, vector
from taichi.types.annotations import template
Expand Down Expand Up @@ -292,9 +291,6 @@ def save_texture_to_numpy(tex: texture_type.rw_texture(num_dimensions=2,


# Odd-even merge sort
# References:
# https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting
# https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort
@kernel
def sort_stage(keys: template(), use_values: int, values: template(), N: int,
p: int, k: int, invocations: int):
Expand All @@ -315,26 +311,7 @@ def sort_stage(keys: template(), use_values: int, values: template(), N: int,
values[b] = temp


def parallel_sort(keys, values=None):
    """Odd-even merge sort of `keys`, optionally permuting `values` alongside.

    Args:
        keys: 1-D Taichi field of sort keys; sorted in place, ascending.
        values: optional 1-D Taichi field permuted in lockstep with `keys`.

    References:
        https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting
        https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort
    """
    N = keys.shape[0]

    p = 1
    while p < N:
        k = p
        while k >= 1:
            invocations = (N - k - k % p) // (2 * k) + 1
            if values is None:
                # Keys-only sort: pass `keys` again as a dummy values arg.
                sort_stage(keys, 0, keys, N, p, k, invocations)
            else:
                sort_stage(keys, 1, values, N, p, k, invocations)
            # Each stage must finish before the next one reads its output.
            sync()
            k //= 2  # was int(k / 2): avoid float round-trip
        p *= 2


# Parallel Prefix Sum (Scan)
@func
def warp_shfl_up_i32(val: template()):
global_tid = block.global_thread_idx()
Expand Down Expand Up @@ -421,53 +398,3 @@ def blit_from_field_to_field(
dst: template(), src: template(), offset: i32, size: i32):
for i in range(size):
dst[i + offset] = src[i]


# Parallel Prefix Sum (Scan)
# Ref[0]: https://developer.download.nvidia.com/compute/cuda/1.1-Beta/x86_website/projects/scan/doc/scan.pdf
# Ref[1]: https://github.com/NVIDIA/cuda-samples/blob/master/Samples/2_Concepts_and_Techniques/shfl_scan/shfl_scan.cu
def prefix_sum_inclusive_inplace(input_arr, length):
    """In-place inclusive prefix sum over the first `length` elements.

    Only ti.i32 fields are supported, and only on the CUDA and Vulkan
    backends.  NOTE: allocates a fresh scratch field on every call.

    Args:
        input_arr: 1-D ti.i32 field; overwritten with its inclusive scan.
        length: number of leading elements to scan.

    Raises:
        RuntimeError: on a non-i32 field or an unsupported backend.
    """
    # Fail fast before computing the buffer layout.
    if input_arr.dtype != i32:
        raise RuntimeError("Only ti.i32 type is supported for prefix sum.")

    BLOCK_SZ = 64

    # Buffer position and length
    # This is a single buffer implementation for ease of aot usage
    ele_num = length
    ele_nums = [ele_num]
    start_pos = 0
    ele_nums_pos = [start_pos]

    while ele_num > 1:
        # Ceil-divide to get the number of per-block partial sums.
        ele_num = (ele_num + BLOCK_SZ - 1) // BLOCK_SZ
        ele_nums.append(ele_num)
        start_pos += BLOCK_SZ * ele_num
        ele_nums_pos.append(start_pos)

    large_arr = field(i32, shape=start_pos)

    if current_cfg().arch == cuda:
        inclusive_add = warp_shfl_up_i32
    elif current_cfg().arch == vulkan:
        inclusive_add = subgroup.inclusive_add
    else:
        raise RuntimeError(
            f"{str(current_cfg().arch)} is not supported for prefix sum.")

    blit_from_field_to_field(large_arr, input_arr, 0, length)

    # Kogge-Stone construction
    for i in range(len(ele_nums) - 1):
        # The last level is flagged so the kernel treats it specially.
        is_last_level = i == len(ele_nums) - 2
        scan_add_inclusive(large_arr, ele_nums_pos[i], ele_nums_pos[i + 1],
                           is_last_level, inclusive_add)

    # Propagate the block-level sums back down the hierarchy.
    for i in range(len(ele_nums) - 3, -1, -1):
        uniform_add(large_arr, ele_nums_pos[i], ele_nums_pos[i + 1])

    blit_from_field_to_field(input_arr, large_arr, 0, length)
1 change: 1 addition & 0 deletions python/taichi/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from ._algorithms import *
101 changes: 101 additions & 0 deletions python/taichi/algorithms/_algorithms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from taichi._kernels import (blit_from_field_to_field, scan_add_inclusive,
sort_stage, uniform_add, warp_shfl_up_i32)
from taichi.lang.impl import current_cfg, field
from taichi.lang.kernel_impl import data_oriented
from taichi.lang.misc import cuda, vulkan
from taichi.lang.runtime_ops import sync
from taichi.lang.simt import subgroup
from taichi.types.primitive_types import i32


def parallel_sort(keys, values=None):
    """Odd-even merge sort

    Sorts `keys` in place (ascending); if `values` is given it is permuted
    in lockstep with `keys`.

    References:
    https://developer.nvidia.com/gpugems/gpugems2/part-vi-simulation-and-numerical-algorithms/chapter-46-improved-gpu-sorting
    https://en.wikipedia.org/wiki/Batcher_odd%E2%80%93even_mergesort
    """
    N = keys.shape[0]

    p = 1
    while p < N:
        k = p
        while k >= 1:
            invocations = (N - k - k % p) // (2 * k) + 1
            if values is None:
                # Keys-only sort: pass `keys` again as a dummy values arg.
                sort_stage(keys, 0, keys, N, p, k, invocations)
            else:
                sort_stage(keys, 1, values, N, p, k, invocations)
            # Each stage must finish before the next one reads its output.
            sync()
            k //= 2  # was int(k / 2): avoid float round-trip
        p *= 2


@data_oriented
class PrefixSumExecutor:
    """Parallel Prefix Sum (Scan) helper.

    Use this helper to perform an inclusive in-place parallel prefix sum.
    The scratch buffer is allocated once for a fixed input length, so
    repeated scans avoid a field allocation per call.

    References:
        https://developer.download.nvidia.com/compute/cuda/1.1-Beta/x86_website/projects/scan/doc/scan.pdf
        https://github.com/NVIDIA/cuda-samples/blob/master/Samples/2_Concepts_and_Techniques/shfl_scan/shfl_scan.cu
    """
    def __init__(self, length):
        # Number of leading elements each run() will scan.
        self.sorting_length = length

        BLOCK_SZ = 64

        # Buffer position and length
        # This is a single buffer implementation for ease of aot usage
        ele_num = length
        # Element count per reduction level.
        self.ele_nums = [ele_num]
        start_pos = 0
        # Start offset of each level inside the single scratch buffer.
        self.ele_nums_pos = [start_pos]

        while ele_num > 1:
            # Ceil-divide to get the number of per-block partial sums.
            ele_num = (ele_num + BLOCK_SZ - 1) // BLOCK_SZ
            self.ele_nums.append(ele_num)
            start_pos += BLOCK_SZ * ele_num
            self.ele_nums_pos.append(start_pos)

        # Scratch buffer shared by all levels, allocated exactly once.
        self.large_arr = field(i32, shape=start_pos)

    def run(self, input_arr):
        """Compute the inclusive prefix sum of `input_arr` in place.

        Args:
            input_arr: 1-D ti.i32 field with at least `sorting_length`
                elements; overwritten with its inclusive scan.

        Raises:
            RuntimeError: on a non-i32 field or an unsupported backend.
        """
        if input_arr.dtype != i32:
            raise RuntimeError("Only ti.i32 type is supported for prefix sum.")

        length = self.sorting_length
        ele_nums = self.ele_nums
        ele_nums_pos = self.ele_nums_pos

        # Pick the backend-specific intra-block inclusive-add primitive.
        if current_cfg().arch == cuda:
            inclusive_add = warp_shfl_up_i32
        elif current_cfg().arch == vulkan:
            inclusive_add = subgroup.inclusive_add
        else:
            raise RuntimeError(
                f"{str(current_cfg().arch)} is not supported for prefix sum.")

        blit_from_field_to_field(self.large_arr, input_arr, 0, length)

        # Kogge-Stone construction
        for i in range(len(ele_nums) - 1):
            # The last level is flagged so the kernel treats it specially.
            is_last_level = i == len(ele_nums) - 2
            scan_add_inclusive(self.large_arr, ele_nums_pos[i],
                               ele_nums_pos[i + 1], is_last_level,
                               inclusive_add)

        # Propagate the block-level sums back down the hierarchy.
        for i in range(len(ele_nums) - 3, -1, -1):
            uniform_add(self.large_arr, ele_nums_pos[i], ele_nums_pos[i + 1])

        blit_from_field_to_field(input_arr, self.large_arr, 0, length)


__all__ = ['parallel_sort', 'PrefixSumExecutor']
9 changes: 5 additions & 4 deletions tests/python/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ def _get_expected_matrix_apis():
'StructField', 'TRACE', 'TaichiAssertionError', 'TaichiCompilationError',
'TaichiNameError', 'TaichiRuntimeError', 'TaichiRuntimeTypeError',
'TaichiSyntaxError', 'TaichiTypeError', 'TetMesh', 'Texture', 'TriMesh',
'Vector', 'VectorNdarray', 'WARN', 'abs', 'acos', 'activate', 'ad', 'aot',
'append', 'arm64', 'asin', 'assume_in_range', 'atan2', 'atomic_add',
'atomic_and', 'atomic_max', 'atomic_min', 'atomic_or', 'atomic_sub',
'atomic_xor', 'axes', 'bit_cast', 'bit_shr', 'block_local',
'Vector', 'VectorNdarray', 'WARN', 'abs', 'acos', 'activate', 'ad',
'algorithms', 'aot', 'append', 'arm64', 'asin', 'assume_in_range', 'atan2',
'atomic_add', 'atomic_and', 'atomic_max', 'atomic_min', 'atomic_or',
'atomic_sub', 'atomic_xor', 'axes', 'bit_cast', 'bit_shr', 'block_local',
'cache_read_only', 'cast', 'cc', 'ceil', 'cos', 'cpu', 'cuda',
'data_oriented', 'dataclass', 'deactivate', 'deactivate_all_snodes',
'dx11', 'eig', 'exp', 'experimental', 'extension', 'f16', 'f32', 'f64',
Expand All @@ -93,6 +93,7 @@ def _get_expected_matrix_apis():
'FwdMode', 'Tape', 'clear_all_gradients', 'grad_for', 'grad_replaced',
'no_grad'
]
user_api[ti.algorithms] = ['PrefixSumExecutor', 'parallel_sort']
user_api[ti.Field] = [
'copy_from', 'dtype', 'fill', 'from_numpy', 'from_paddle', 'from_torch',
'parent', 'shape', 'snode', 'to_numpy', 'to_paddle', 'to_torch'
Expand Down
7 changes: 6 additions & 1 deletion tests/python/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,12 @@ def fill():
arr_aux[i] = arr[i]

fill()
ti._kernels.prefix_sum_inclusive_inplace(arr, N)

    # Performing an inclusive in-place parallel prefix sum:
    # only one executor is needed for a given sorting length.
executor = ti.algorithms.PrefixSumExecutor(N)

executor.run(arr)

cur_sum = 0
for i in range(N):
Expand Down
2 changes: 1 addition & 1 deletion tests/python/test_sort.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def fill():
values[i] = keys[i]

fill()
ti._kernels.parallel_sort(keys, values)
ti.algorithms.parallel_sort(keys, values)

keys_host = keys.to_numpy()
values_host = values.to_numpy()
Expand Down

0 comments on commit 507d025

Please sign in to comment.