diff --git a/.github/actions/docs-build/action.yml b/.github/actions/docs-build/action.yml
index 8b997f4741f..db7f3231742 100644
--- a/.github/actions/docs-build/action.yml
+++ b/.github/actions/docs-build/action.yml
@@ -38,6 +38,8 @@ runs:
cp -rf ./docs/_build/docs/cudax/latest/* _site/cudax
mkdir _site/cuda_cooperative
cp -rf ./docs/_build/docs/cuda_cooperative/latest/* _site/cuda_cooperative
+ mkdir _site/cuda_parallel
+ cp -rf ./docs/_build/docs/cuda_parallel/latest/* _site/cuda_parallel
./docs/scrape_docs.bash ./_site
# Update docs as workflow artifact:
diff --git a/c/include/cccl/types.h b/c/include/cccl/types.h
index 6b19848de4e..99ee71494c0 100644
--- a/c/include/cccl/types.h
+++ b/c/include/cccl/types.h
@@ -44,8 +44,8 @@ struct cccl_type_info
enum class cccl_op_kind_t
{
- stateless,
- stateful
+ stateless = 0,
+ stateful = 1
};
struct cccl_op_t
@@ -61,8 +61,8 @@ struct cccl_op_t
enum class cccl_iterator_kind_t
{
- pointer,
- iterator
+ pointer = 0,
+ iterator = 1
};
struct cccl_value_t
diff --git a/ci/update_version.sh b/ci/update_version.sh
index b1304f37d1c..9184b98e6a9 100755
--- a/ci/update_version.sh
+++ b/ci/update_version.sh
@@ -36,7 +36,8 @@ CUB_CMAKE_VERSION_FILE="cub/cub/cmake/cub-config-version.cmake"
LIBCUDACXX_CMAKE_VERSION_FILE="libcudacxx/lib/cmake/libcudacxx/libcudacxx-config-version.cmake"
THRUST_CMAKE_VERSION_FILE="thrust/thrust/cmake/thrust-config-version.cmake"
CUDAX_CMAKE_VERSION_FILE="cudax/lib/cmake/cudax/cudax-config-version.cmake"
-PYCUDA_VERSION_FILE="python/cuda_cooperative/cuda/cooperative/_version.py"
+CUDA_COOPERATIVE_VERSION_FILE="python/cuda_cooperative/cuda/cooperative/_version.py"
+CUDA_PARALLEL_VERSION_FILE="python/cuda_parallel/cuda/parallel/_version.py"
# Calculated version codes
new_cccl_version=$((major * 1000000 + minor * 1000 + patch)) # MMMmmmppp
@@ -102,7 +103,8 @@ update_file "$CUDAX_CMAKE_VERSION_FILE" "set(cudax_VERSION_MAJOR \([0-9]\+\))" "
update_file "$CUDAX_CMAKE_VERSION_FILE" "set(cudax_VERSION_MINOR \([0-9]\+\))" "set(cudax_VERSION_MINOR $minor)"
update_file "$CUDAX_CMAKE_VERSION_FILE" "set(cudax_VERSION_PATCH \([0-9]\+\))" "set(cudax_VERSION_PATCH $patch)"
-update_file "$PYCUDA_VERSION_FILE" "^__version__ = \"\([0-9.]\+\)\"" "__version__ = \"$pymajor.$pyminor.$major.$minor.$patch\""
+update_file "$CUDA_COOPERATIVE_VERSION_FILE" "^__version__ = \"\([0-9.]\+\)\"" "__version__ = \"$pymajor.$pyminor.$major.$minor.$patch\""
+update_file "$CUDA_PARALLEL_VERSION_FILE" "^__version__ = \"\([0-9.]\+\)\"" "__version__ = \"$pymajor.$pyminor.$major.$minor.$patch\""
if [ "$DRY_RUN" = true ]; then
echo "Dry run completed. No changes made."
diff --git a/docs/cuda_parallel/index.rst b/docs/cuda_parallel/index.rst
new file mode 100644
index 00000000000..5a76a3d35d9
--- /dev/null
+++ b/docs/cuda_parallel/index.rst
@@ -0,0 +1,13 @@
+.. _cuda_parallel-module:
+
+CUDA Parallel
+==================================================
+
+.. warning::
+ Python exposure of parallel algorithms is in public beta.
+ The API is subject to change without notice.
+
+.. automodule:: cuda.parallel.experimental
+ :members:
+ :undoc-members:
+ :imported-members:
diff --git a/docs/python.rst b/docs/python.rst
index b0b9c5b73f9..164fdaf8298 100644
--- a/docs/python.rst
+++ b/docs/python.rst
@@ -8,8 +8,12 @@ CUDA Python Core Libraries
:maxdepth: 3
cuda.cooperative
+ cuda.parallel
Welcome to the CUDA Core Compute Libraries (CCCL) libraries for Python.
- `cuda.cooperative `__
is a still-experimental library exposing cooperative algorithms to Python.
+
+- `cuda.parallel <https://nvidia.github.io/cccl/cuda_parallel>`__
+ is a still-experimental library exposing parallel algorithms to Python.
diff --git a/docs/repo.toml b/docs/repo.toml
index 8a68125bea4..7cadfce6aa1 100644
--- a/docs/repo.toml
+++ b/docs/repo.toml
@@ -25,7 +25,7 @@ sphinx_exclude_patterns = [
"VERSION.md",
]
-project_build_order = [ "libcudacxx", "cudax", "cub", "thrust", "cccl", "cuda_cooperative" ]
+project_build_order = [ "libcudacxx", "cudax", "cub", "thrust", "cccl", "cuda_cooperative", "cuda_parallel" ]
# deps can be used to link to other projects' documentation
deps = [
@@ -33,6 +33,7 @@ deps = [
[ "cub", "_build/docs/cub/latest" ],
[ "thrust", "_build/docs/thrust/latest" ],
[ "cuda_cooperative", "_build/docs/cuda_cooperative/latest" ],
+ [ "cuda_parallel", "_build/docs/cuda_parallel/latest" ],
]
[repo_docs.projects.libcudacxx]
@@ -304,6 +305,29 @@ python_paths = [
"${root}/../python/cuda_cooperative"
]
+[repo_docs.projects.cuda_parallel]
+name = "cuda.parallel"
+docs_root = "cuda_parallel"
+logo = "../img/logo.png"
+
+repo_url = "https://github.com/NVIDIA/cccl/python/cuda"
+social_media_set = ""
+social_media = [
+ [ "github", "https://github.com/NVIDIA/cccl" ],
+]
+
+autodoc.mock_imports = [
+ "numba",
+ "pynvjitlink",
+ "cuda.nvrtc",
+ "llvmlite"
+]
+
+enhanced_search_enabled = true
+python_paths = [
+ "${root}/../python/cuda_parallel"
+]
+
[repo_docs.projects.cudax]
name = "Cudax: Experimental new features"
docs_root = "cudax"
diff --git a/python/cuda_parallel/.gitignore b/python/cuda_parallel/.gitignore
new file mode 100644
index 00000000000..8e0d030ff6a
--- /dev/null
+++ b/python/cuda_parallel/.gitignore
@@ -0,0 +1,4 @@
+cuda/_include
+env
+*egg-info
+*so
diff --git a/python/cuda_parallel/MANIFEST.in b/python/cuda_parallel/MANIFEST.in
new file mode 100644
index 00000000000..848cbfe2e81
--- /dev/null
+++ b/python/cuda_parallel/MANIFEST.in
@@ -0,0 +1 @@
+recursive-include cuda/_include *
diff --git a/python/cuda_parallel/README.md b/python/cuda_parallel/README.md
new file mode 100644
index 00000000000..98a3a3c92d0
--- /dev/null
+++ b/python/cuda_parallel/README.md
@@ -0,0 +1,12 @@
+# `cuda.parallel`: Experimental CUDA Core Compute Library for Python
+
+## Documentation
+
+Please visit the documentation here: https://nvidia.github.io/cccl/python.html.
+
+## Local development
+
+```bash
+pip3 install -e .[test]
+pytest -v ./tests/
+```
diff --git a/python/cuda_parallel/cuda/parallel/__init__.py b/python/cuda_parallel/cuda/parallel/__init__.py
new file mode 100644
index 00000000000..6bb31cc9b46
--- /dev/null
+++ b/python/cuda_parallel/cuda/parallel/__init__.py
@@ -0,0 +1,6 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import cuda.parallel.experimental
+from cuda.parallel._version import __version__
diff --git a/python/cuda_parallel/cuda/parallel/_version.py b/python/cuda_parallel/cuda/parallel/_version.py
new file mode 100644
index 00000000000..aaab7ef4ad5
--- /dev/null
+++ b/python/cuda_parallel/cuda/parallel/_version.py
@@ -0,0 +1,7 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+# This file is generated by ci/update_version.sh
+# Do not edit this file manually.
+__version__ = "0.1.2.6.0"
diff --git a/python/cuda_parallel/cuda/parallel/experimental/__init__.py b/python/cuda_parallel/cuda/parallel/experimental/__init__.py
new file mode 100644
index 00000000000..4a16fc1b67a
--- /dev/null
+++ b/python/cuda_parallel/cuda/parallel/experimental/__init__.py
@@ -0,0 +1,278 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import importlib
+import ctypes
+import shutil
+import numba
+import os
+
+from numba import cuda, types
+from numba.cuda.cudadrv import enums
+
+
+# Should match C++
+class _TypeEnum(ctypes.c_int):
+    """Integer tags naming the element type stored in a ``_TypeInfo``.
+
+    The class derives from ``ctypes.c_int`` so it can be used directly as a
+    structure field type. The numeric values must stay in sync with the
+    corresponding enumeration on the C++ side.
+    """
+    INT8 = 0
+    INT16 = 1
+    INT32 = 2
+    INT64 = 3
+    UINT8 = 4
+    UINT16 = 5
+    UINT32 = 6
+    UINT64 = 7
+    FLOAT32 = 8
+    FLOAT64 = 9
+    # Fallback tag: ``_type_to_enum`` returns it for every type not listed above.
+    STORAGE = 10
+
+
+# Should match C++
+class _CCCLOpKindEnum(ctypes.c_int):
+    """Kind tag stored in ``_CCCLOp.type``.
+
+    The values mirror ``cccl_op_kind_t`` (``stateless = 0``, ``stateful = 1``)
+    and must be kept in sync with the C++ side.
+    """
+    STATELESS = 0
+    STATEFUL = 1
+
+
+# Should match C++
+class _CCCLIteratorKindEnum(ctypes.c_int):
+    """Kind tag stored in ``_CCCLIterator.type``.
+
+    The values mirror ``cccl_iterator_kind_t`` (``pointer = 0``,
+    ``iterator = 1``) and must be kept in sync with the C++ side.
+    """
+    POINTER = 0
+    ITERATOR = 1
+
+
+def _type_to_enum(numba_type):
+    """Translate a numba type into its ``_TypeEnum`` tag.
+
+    Args:
+        numba_type: numba type to classify.
+
+    Returns:
+        The dedicated ``_TypeEnum`` constant for the fixed-width integer and
+        floating-point types listed below, ``_TypeEnum.STORAGE`` for anything
+        else.
+    """
+    known_tags = {
+        types.int8: _TypeEnum.INT8,
+        types.int16: _TypeEnum.INT16,
+        types.int32: _TypeEnum.INT32,
+        types.int64: _TypeEnum.INT64,
+        types.uint8: _TypeEnum.UINT8,
+        types.uint16: _TypeEnum.UINT16,
+        types.uint32: _TypeEnum.UINT32,
+        types.uint64: _TypeEnum.UINT64,
+        types.float32: _TypeEnum.FLOAT32,
+        types.float64: _TypeEnum.FLOAT64,
+    }
+    # Types without a dedicated tag are treated as opaque storage.
+    return known_tags.get(numba_type, _TypeEnum.STORAGE)
+
+
+# TODO Extract into reusable module
+class _TypeInfo(ctypes.Structure):
+    """ctypes layout describing a value type: size, alignment and type tag.
+
+    Instances are produced by ``_type_to_info``, which fills ``size`` and
+    ``alignment`` from the ABI size/alignment reported by the numba CUDA
+    target.
+    """
+    _fields_ = [("size", ctypes.c_int),
+                ("alignment", ctypes.c_int),
+                ("type", _TypeEnum)]
+
+
+class _CCCLOp(ctypes.Structure):
+    """ctypes layout describing an operation passed to the native library.
+
+    ``_Op.handle`` fills it with the operation kind, the encoded function
+    name, the LTO-IR blob and its length; ``size``, ``alignment`` and
+    ``state`` describe the operation's state (``1``, ``1`` and ``None`` for
+    the stateless operations created in this module).
+    """
+    _fields_ = [("type", _CCCLOpKindEnum),
+                ("name", ctypes.c_char_p),
+                ("ltoir", ctypes.c_char_p),
+                ("ltoir_size", ctypes.c_int),
+                ("size", ctypes.c_int),
+                ("alignment", ctypes.c_int),
+                ("state", ctypes.c_void_p)]
+
+
+class _CCCLIterator(ctypes.Structure):
+    """ctypes layout describing an input or output sequence.
+
+    ``_device_array_to_pointer`` is the only producer in this module: it sets
+    the kind to ``POINTER``, leaves ``advance``/``dereference`` as empty
+    ``_CCCLOp`` values and stores the device address of the array in
+    ``state``.
+    """
+    _fields_ = [("size", ctypes.c_int),
+                ("alignment", ctypes.c_int),
+                ("type", _CCCLIteratorKindEnum),
+                ("advance", _CCCLOp),
+                ("dereference", _CCCLOp),
+                ("value_type", _TypeInfo),
+                ("state", ctypes.c_void_p)]
+
+
+class _CCCLValue(ctypes.Structure):
+    """ctypes layout pairing a type description with a pointer to the value.
+
+    ``_host_array_to_value`` builds it from a host array, storing the
+    address of the array's data in ``state``.
+    """
+    _fields_ = [("type", _TypeInfo),
+                ("state", ctypes.c_void_p)]
+
+
+def _type_to_info(numpy_type):
+    """Build the ``_TypeInfo`` record for a NumPy dtype.
+
+    Args:
+        numpy_type: NumPy dtype to describe.
+
+    Returns:
+        A ``_TypeInfo`` holding the ABI size and ABI alignment that the numba
+        CUDA target reports for the type, together with its ``_TypeEnum`` tag.
+    """
+    numba_type = numba.from_dtype(numpy_type)
+    target_context = cuda.descriptor.cuda_target.target_context
+    target_data = target_context.target_data
+    value_type = target_context.get_value_type(numba_type)
+    abi_size = value_type.get_abi_size(target_data)
+    abi_alignment = value_type.get_abi_alignment(target_data)
+    return _TypeInfo(abi_size, abi_alignment, _type_to_enum(numba_type))
+
+
+def _device_array_to_pointer(array):
+    """Describe a device array as a pointer-kind ``_CCCLIterator``.
+
+    Args:
+        array: device array exposing ``dtype`` and ``device_ctypes_pointer``.
+
+    Returns:
+        A ``_CCCLIterator`` of kind ``POINTER`` whose ``value_type`` describes
+        the array's dtype and whose ``state`` is the array's device address.
+    """
+    value_type = _type_to_info(array.dtype)
+    device_address = array.device_ctypes_pointer.value
+    # A raw pointer needs no advance/dereference operations, so both stay empty.
+    return _CCCLIterator(size=1,
+                         alignment=1,
+                         type=_CCCLIteratorKindEnum.POINTER,
+                         advance=_CCCLOp(),
+                         dereference=_CCCLOp(),
+                         value_type=value_type,
+                         state=device_address)
+
+
+def _host_array_to_value(array):
+    """Describe a host NumPy array as a ``_CCCLValue``.
+
+    Args:
+        array: host array exposing ``dtype`` and ``ctypes``.
+
+    Returns:
+        A ``_CCCLValue`` whose ``type`` describes the array's dtype and whose
+        ``state`` points at the array's data. The pointer does not keep the
+        array alive, so the caller must hold on to ``array`` while it is used.
+    """
+    value_type = _type_to_info(array.dtype)
+    data_pointer = array.ctypes.data_as(ctypes.c_void_p)
+    return _CCCLValue(value_type, data_pointer)
+
+
+class _Op:
+    """A binary Python callable compiled to LTO-IR for use by the native library."""
+
+    def __init__(self, dtype, op):
+        """Compile ``op`` for two operands of ``dtype`` returning ``dtype``.
+
+        Args:
+            dtype: NumPy dtype of both operands and of the result.
+            op: Python callable taking two values and returning one.
+        """
+        operand_type = numba.from_dtype(dtype)
+        signature = operand_type(operand_type, operand_type)
+        ltoir, _ = cuda.compile(op, sig=signature, output='ltoir')
+        self.ltoir = ltoir
+        self.name = op.__name__.encode('utf-8')
+
+    def handle(self):
+        """Return a stateless ``_CCCLOp`` that refers to the compiled LTO-IR."""
+        return _CCCLOp(type=_CCCLOpKindEnum.STATELESS,
+                       name=self.name,
+                       ltoir=ctypes.c_char_p(self.ltoir),
+                       ltoir_size=len(self.ltoir),
+                       size=1,
+                       alignment=1,
+                       state=None)
+
+
+def _get_cuda_path():
+    """Locate the root directory of the CUDA toolkit.
+
+    The candidates are tried in this order:
+
+    1. the ``CUDA_PATH`` environment variable, if it names an existing path;
+    2. two directory levels above the ``nvcc`` executable found on ``PATH``;
+    3. ``/usr/local/cuda``, if it exists.
+
+    Returns:
+        The toolkit root as a string, or ``None`` when no candidate matches.
+    """
+    from_environment = os.environ.get('CUDA_PATH', '')
+    if os.path.exists(from_environment):
+        return from_environment
+
+    nvcc_executable = shutil.which('nvcc')
+    if nvcc_executable is not None:
+        bin_directory = os.path.dirname(nvcc_executable)
+        return os.path.dirname(bin_directory)
+
+    fallback = '/usr/local/cuda'
+    return fallback if os.path.exists(fallback) else None
+
+
+_bindings = None
+_paths = None
+
+
+def _get_bindings():
+    """Load the CCCL C library on first use and return the cached handle.
+
+    The shared library ``libcccl.c.so`` is looked up in the ``cccl``
+    directory of the ``cuda.parallel.experimental`` package. Return types
+    are declared for the build, reduce and cleanup entry points, and the
+    argument types for ``cccl_device_reduce``.
+
+    Returns:
+        The ``ctypes.CDLL`` object, shared by all callers.
+    """
+    global _bindings
+    if _bindings is None:
+        # ``import importlib`` alone does not guarantee that the ``resources``
+        # submodule is loaded, so import it explicitly before using it.
+        import importlib.resources
+
+        library_dir = importlib.resources.files(
+            'cuda.parallel.experimental').joinpath('cccl')
+        library_path = os.path.join(library_dir, 'libcccl.c.so')
+        library = ctypes.CDLL(library_path)
+        library.cccl_device_reduce_build.restype = ctypes.c_int
+        library.cccl_device_reduce.restype = ctypes.c_int
+        library.cccl_device_reduce.argtypes = [_CCCLDeviceReduceBuildResult, ctypes.c_void_p, ctypes.POINTER(
+            ctypes.c_ulonglong), _CCCLIterator, _CCCLIterator, ctypes.c_ulonglong, _CCCLOp, _CCCLValue, ctypes.c_void_p]
+        library.cccl_device_reduce_cleanup.restype = ctypes.c_int
+        # Publish the handle only once it is fully configured, so a failure
+        # above cannot leave a half-initialised library in the cache.
+        _bindings = library
+    return _bindings
+
+
+def _get_paths():
+    """Compute, once, the ``-I`` include options passed to the build step.
+
+    Returns:
+        A 4-tuple of UTF-8 encoded ``-I<dir>`` options, in the order
+        ``(cub, thrust, libcudacxx, cuda_include)``. CUB and Thrust share the
+        ``_include`` directory of the ``cuda`` package, libcudacxx uses its
+        ``libcudacxx`` subdirectory, and the last entry is the ``include``
+        directory of the CUDA toolkit.
+
+    Raises:
+        RuntimeError: if the CUDA toolkit cannot be located.
+    """
+    global _paths
+    if _paths is None:
+        # ``import importlib`` alone does not guarantee that the ``resources``
+        # submodule is loaded, so import it explicitly before using it.
+        import importlib.resources
+
+        cuda_path = _get_cuda_path()
+        if cuda_path is None:
+            # Fail with an actionable message instead of the TypeError that
+            # os.path.join(None, ...) would raise.
+            raise RuntimeError(
+                'Unable to locate the CUDA toolkit: set CUDA_PATH or make '
+                'nvcc available on PATH')
+
+        include_dir = str(importlib.resources.files('cuda').joinpath('_include'))
+        cub_path = ('-I' + include_dir).encode('utf-8')
+        thrust_path = cub_path
+        libcudacxx_dir = os.path.join(include_dir, 'libcudacxx')
+        libcudacxx_path = ('-I' + libcudacxx_dir).encode('utf-8')
+        cuda_include_dir = os.path.join(cuda_path, 'include')
+        cuda_include_path = ('-I' + cuda_include_dir).encode('utf-8')
+        _paths = cub_path, thrust_path, libcudacxx_path, cuda_include_path
+    return _paths
+
+
+class _CCCLDeviceReduceBuildResult(ctypes.Structure):
+    """ctypes layout of the result produced by ``cccl_device_reduce_build``.
+
+    ``_Reduce`` passes an instance by reference to the build call, by value
+    to ``cccl_device_reduce`` and by reference to
+    ``cccl_device_reduce_cleanup``.
+    """
+    # NOTE(review): the field names suggest the compute capability, the
+    # compiled cubin with its size, and library/kernel handles -- confirm
+    # against the C header, which is not visible from this module.
+    _fields_ = [("cc", ctypes.c_int),
+                ("cubin", ctypes.c_void_p),
+                ("cubin_size", ctypes.c_size_t),
+                ("library", ctypes.c_void_p),
+                ("single_tile_kernel", ctypes.c_void_p),
+                ("single_tile_second_kernel", ctypes.c_void_p),
+                ("reduction_kernel", ctypes.c_void_p)]
+
+
+class _Reduce:
+    """Device-wide reduction built once for a given operator and element types.
+
+    Construction compiles ``op`` and asks the native library to build the
+    reduction for the current device. Calling the instance either queries the
+    required temporary storage size or runs the reduction.
+    """
+
+    def __init__(self, d_in, d_out, op, init):
+        """Build the reduction.
+
+        Args:
+            d_in: device array describing the input sequence.
+            d_out: device array describing the output.
+            op: binary Python callable, compiled for operands of ``init.dtype``.
+            init: NumPy array holding the initial value.
+
+        Raises:
+            ValueError: if the native build step does not report success.
+        """
+        # Define the attribute before anything can fail so that ``__del__``
+        # can tell whether there is a build result to release.
+        self.build_result = None
+        cc_major, cc_minor = cuda.get_current_device().compute_capability
+        cub_path, thrust_path, libcudacxx_path, cuda_include_path = _get_paths()
+        bindings = _get_bindings()
+        accum_t = init.dtype
+        self.op_wrapper = _Op(accum_t, op)
+        d_in_ptr = _device_array_to_pointer(d_in)
+        d_out_ptr = _device_array_to_pointer(d_out)
+        self.build_result = _CCCLDeviceReduceBuildResult()
+
+        # TODO Figure out caching
+        error = bindings.cccl_device_reduce_build(ctypes.byref(self.build_result),
+                                                  d_in_ptr,
+                                                  d_out_ptr,
+                                                  self.op_wrapper.handle(),
+                                                  _host_array_to_value(init),
+                                                  cc_major,
+                                                  cc_minor,
+                                                  ctypes.c_char_p(cub_path),
+                                                  ctypes.c_char_p(thrust_path),
+                                                  ctypes.c_char_p(
+                                                      libcudacxx_path),
+                                                  ctypes.c_char_p(cuda_include_path))
+        if error != enums.CUDA_SUCCESS:
+            raise ValueError('Error building reduce')
+
+    def __call__(self, temp_storage, d_in, d_out, init):
+        """Query the temporary storage size or run the reduction.
+
+        Args:
+            temp_storage: device array used as scratch space, or ``None`` to
+                only query how many bytes are required.
+            d_in: device array holding the input items.
+            d_out: device array receiving the result.
+            init: NumPy array holding the initial value.
+
+        Returns:
+            The temporary storage size in bytes, as reported back by the
+            native call.
+
+        Raises:
+            ValueError: if the native call does not report success.
+        """
+        # TODO Assert that types match the ones used in the constructor
+        bindings = _get_bindings()
+        if temp_storage is None:
+            temp_storage_bytes = ctypes.c_size_t()
+            d_temp_storage = None
+        else:
+            temp_storage_bytes = ctypes.c_size_t(temp_storage.nbytes)
+            d_temp_storage = temp_storage.device_ctypes_pointer.value
+        d_in_ptr = _device_array_to_pointer(d_in)
+        d_out_ptr = _device_array_to_pointer(d_out)
+        num_items = ctypes.c_ulonglong(d_in.size)
+        error = bindings.cccl_device_reduce(self.build_result,
+                                            d_temp_storage,
+                                            ctypes.byref(temp_storage_bytes),
+                                            d_in_ptr,
+                                            d_out_ptr,
+                                            num_items,
+                                            self.op_wrapper.handle(),
+                                            _host_array_to_value(init),
+                                            None)
+        if error != enums.CUDA_SUCCESS:
+            raise ValueError('Error reducing')
+
+        return temp_storage_bytes.value
+
+    def __del__(self):
+        """Release the native build result, if one was ever created."""
+        # ``__init__`` may have failed before the build result existed; in
+        # that case there is nothing to clean up.
+        build_result = getattr(self, 'build_result', None)
+        if build_result is None:
+            return
+        bindings = _get_bindings()
+        bindings.cccl_device_reduce_cleanup(ctypes.byref(build_result))
+
+
+# TODO Figure out iterators
+# TODO Figure out `sum` without operator and initial value
+# TODO Accept stream
+def reduce_into(d_in, d_out, op, init):
+    """Computes a device-wide reduction using the specified binary ``op`` functor and initial value ``init``.
+
+    Example:
+        The code snippet below illustrates a user-defined min-reduction of a
+        device vector of ``int`` data elements.
+
+        .. literalinclude:: ../../python/cuda_parallel/tests/test_reduce_api.py
+            :language: python
+            :dedent:
+            :start-after: example-begin imports
+            :end-before: example-end imports
+
+        Below is the code snippet that demonstrates the usage of the ``reduce_into`` API:
+
+        .. literalinclude:: ../../python/cuda_parallel/tests/test_reduce_api.py
+            :language: python
+            :dedent:
+            :start-after: example-begin reduce-min
+            :end-before: example-end reduce-min
+
+    Args:
+        d_in: CUDA device array storing the input sequence of data items
+        d_out: CUDA device array storing the output aggregate
+        op: Binary reduction operator: a Python callable that takes two
+            values and returns their combination
+        init: Numpy array storing initial value of the reduction
+
+    Returns:
+        A callable object that can be used to perform the reduction. Call it
+        as ``reduce(temp_storage, d_in, d_out, init)``: with
+        ``temp_storage=None`` it only returns the number of bytes of
+        temporary device storage required; with a device array of at least
+        that size it performs the reduction.
+    """
+    return _Reduce(d_in, d_out, op, init)
diff --git a/python/cuda_parallel/pyproject.toml b/python/cuda_parallel/pyproject.toml
new file mode 100644
index 00000000000..4ab52c80318
--- /dev/null
+++ b/python/cuda_parallel/pyproject.toml
@@ -0,0 +1,7 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+[build-system]
+requires = ["packaging", "setuptools>=61.0.0", "wheel"]
+build-backend = "setuptools.build_meta"
diff --git a/python/cuda_parallel/setup.py b/python/cuda_parallel/setup.py
new file mode 100644
index 00000000000..c29a5237fc0
--- /dev/null
+++ b/python/cuda_parallel/setup.py
@@ -0,0 +1,129 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import os
+import shutil
+import subprocess
+
+from setuptools import Command, Extension, setup, find_packages, find_namespace_packages
+from setuptools.command.build_py import build_py
+from setuptools.command.build_ext import build_ext
+from wheel.bdist_wheel import bdist_wheel
+
+
+# Absolute locations of this package and of the CCCL repository root.
+project_path = os.path.abspath(os.path.dirname(__file__))
+cccl_path = os.path.abspath(os.path.join(project_path, "..", '..'))
+# (project directory, header subdirectory) pairs copied into the package.
+cccl_headers = [
+    ['cub', 'cub'],
+    ['libcudacxx', 'include'],
+    ['thrust', 'thrust']
+]
+
+# Execute the generated version file in a scratch namespace so that it
+# cannot overwrite names defined in this script.
+version_namespace = {}
+with open(os.path.join(project_path, 'cuda', 'parallel', '_version.py')) as f:
+    exec(f.read(), version_namespace)
+ver = version_namespace['__version__']
+
+
+# Resolve the README next to this file so the build does not depend on the
+# current working directory.
+with open(os.path.join(project_path, "README.md"), encoding="utf-8") as f:
+    long_description = f.read()
+
+
+class CustomBuildCommand(build_py):
+    """``build_py`` variant that packages the CCCL headers first."""
+
+    def run(self):
+        # The headers must be in place before the regular build collects files.
+        self.run_command('package_cccl')
+        super().run()
+
+
+class CustomWheelBuild(bdist_wheel):
+    """``bdist_wheel`` variant that packages the CCCL headers first."""
+
+    def run(self):
+        # Run the header-copy command before the regular wheel build.
+        self.run_command('package_cccl')
+        super().run()
+
+
+class PackageCCCLCommand(Command):
+    """Copy the CCCL header directories into ``cuda/_include`` of this package."""
+
+    description = 'Generate additional files'
+    user_options = []
+
+    def initialize_options(self):
+        """No options to initialise."""
+
+    def finalize_options(self):
+        """No options to finalise."""
+
+    def run(self):
+        """Replace each packaged header directory with a fresh copy."""
+        # TODO Extract cccl headers into a standalone package
+        include_root = os.path.join(project_path, 'cuda', '_include')
+        for proj_dir, header_dir in cccl_headers:
+            source = os.path.abspath(
+                os.path.join(cccl_path, proj_dir, header_dir))
+            destination = os.path.join(include_root, proj_dir)
+            # Remove any previous copy first so stale headers do not survive.
+            if os.path.exists(destination):
+                shutil.rmtree(destination)
+            shutil.copytree(source, destination)
+
+
+class CMakeExtension(Extension):
+    """Extension placeholder with no sources; the build is done by CMake."""
+
+    def __init__(self, name):
+        # An empty source list: ``BuildCMakeExtension`` produces the binary.
+        super().__init__(name, sources=[])
+
+
+class BuildCMakeExtension(build_ext):
+    """``build_ext`` replacement that builds the ``cccl.c`` target with CMake."""
+
+    def run(self):
+        """Build every registered extension through ``build_extension``."""
+        for ext in self.extensions:
+            self.build_extension(ext)
+
+    def build_extension(self, ext):
+        """Configure CCCL with CMake and build the ``cccl.c`` target.
+
+        The library output directory is the directory where setuptools
+        expects the extension named ``ext.name``.
+
+        Raises:
+            subprocess.CalledProcessError: if a CMake invocation fails.
+        """
+        extdir = os.path.abspath(os.path.dirname(
+            self.get_ext_fullpath(ext.name)))
+        cmake_args = [
+            '-DCCCL_ENABLE_CUB=YES',
+            '-DCCCL_ENABLE_THRUST=YES',
+            '-DCCCL_ENABLE_C=YES',
+            '-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir,
+            '-DCMAKE_BUILD_TYPE=Release',
+        ]
+
+        # exist_ok avoids the check-then-create race of a separate
+        # os.path.exists() test.
+        os.makedirs(self.build_temp, exist_ok=True)
+
+        subprocess.check_call(['cmake', cccl_path] +
+                              cmake_args, cwd=self.build_temp)
+        subprocess.check_call(
+            ['cmake', '--build', '.', '--target', 'cccl.c'], cwd=self.build_temp)
+
+
+# Package metadata and build wiring for ``cuda-parallel``.
+setup(
+    name="cuda-parallel",
+    # Version string read from cuda/parallel/_version.py above.
+    version=ver,
+    description="Experimental Core Library for CUDA Python",
+    long_description=long_description,
+    long_description_content_type="text/markdown",
+    author="NVIDIA Corporation",
+    classifiers=[
+        "Programming Language :: Python :: 3 :: Only",
+        "Environment :: GPU :: NVIDIA CUDA",
+    ],
+    packages=find_namespace_packages(include=['cuda.*']),
+    python_requires='>=3.9',
+    install_requires=[
+        "numba>=0.60.0",
+        "cuda-python",
+        "jinja2"
+    ],
+    extras_require={
+        "test": [
+            "pytest",
+        ]
+    },
+    # Custom commands: header packaging runs before build_py and bdist_wheel,
+    # and build_ext delegates to CMake.
+    cmdclass={
+        'package_cccl': PackageCCCLCommand,
+        'build_py': CustomBuildCommand,
+        'bdist_wheel': CustomWheelBuild,
+        'build_ext': BuildCMakeExtension
+    },
+    # Source-less extension; its binary is produced by BuildCMakeExtension.
+    ext_modules=[CMakeExtension('cuda.parallel.experimental.cccl.c')],
+    include_package_data=True,
+    license="Apache-2.0 with LLVM exception",
+    license_files=('../../LICENSE',),
+)
diff --git a/python/cuda_parallel/tests/test_reduce.py b/python/cuda_parallel/tests/test_reduce.py
new file mode 100644
index 00000000000..9f59f8efcec
--- /dev/null
+++ b/python/cuda_parallel/tests/test_reduce.py
@@ -0,0 +1,68 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import numpy
+import pytest
+from numba import cuda
+import cuda.parallel.experimental as cudax
+
+
+def random_int(shape, dtype):
+    """Return an array of ``shape`` with random integers in ``[0, 5)`` cast to ``dtype``."""
+    return numpy.random.randint(0, 5, size=shape).astype(dtype)
+
+
+def type_to_problem_sizes(dtype):
+    """Return the power-of-two exponents used as problem sizes for ``dtype``.
+
+    Args:
+        dtype: one of the NumPy 8/16/32/64-bit signed or unsigned integer types.
+
+    Returns:
+        A list of exponents; the tests use ``2 ** exponent`` items.
+
+    Raises:
+        ValueError: if ``dtype`` is not one of the supported integer types.
+    """
+    exponents_by_width = (
+        ((numpy.uint8, numpy.int8), [2, 4, 5, 6]),
+        ((numpy.uint16, numpy.int16), [4, 8, 12, 14]),
+        ((numpy.uint32, numpy.int32), [16, 20, 24, 28]),
+        ((numpy.uint64, numpy.int64), [16, 20, 24, 28]),
+    )
+    for candidates, exponents in exponents_by_width:
+        if dtype in candidates:
+            return exponents
+    raise ValueError("Unsupported dtype")
+
+
+@pytest.mark.parametrize('dtype', [numpy.uint8, numpy.uint16, numpy.uint32, numpy.uint64])
+def test_device_reduce(dtype):
+    """Sum-reduce random inputs of several sizes and compare with the host sum."""
+    def op(a, b):
+        return a + b
+
+    init_value = 42
+    h_init = numpy.array([init_value], dtype=dtype)
+    d_output = cuda.device_array(1, dtype=dtype)
+    # The reduction is built once; ``d_output`` doubles as the input
+    # description here and the real inputs are supplied per call below.
+    reduce_into = cudax.reduce_into(d_output, d_output, op, h_init)
+
+    for num_items_pow2 in type_to_problem_sizes(dtype):
+        num_items = 2 ** num_items_pow2
+        h_input = random_int(num_items, dtype)
+        d_input = cuda.to_device(h_input)
+        # Calling without temporary storage returns the required size in bytes.
+        temp_storage_size = reduce_into(None, d_input, d_output, h_init)
+        d_temp_storage = cuda.device_array(
+            temp_storage_size, dtype=numpy.uint8)
+        reduce_into(d_temp_storage, d_input, d_output, h_init)
+        h_output = d_output.copy_to_host()
+        assert h_output[0] == sum(h_input) + init_value
+
+
+def test_complex_device_reduce():
+    """Sum-reduce complex inputs and compare with ``numpy.sum`` approximately."""
+    def op(a, b):
+        return a + b
+
+    h_init = numpy.array([40.0 + 2.0j], dtype=complex)
+    d_output = cuda.device_array(1, dtype=complex)
+    # The reduction is built once; ``d_output`` doubles as the input
+    # description here and the real inputs are supplied per call below.
+    reduce_into = cudax.reduce_into(d_output, d_output, op, h_init)
+
+    for num_items in [42, 420000]:
+        h_input = numpy.random.random(
+            num_items) + 1j * numpy.random.random(num_items)
+        d_input = cuda.to_device(h_input)
+        # Calling without temporary storage returns the required size in bytes.
+        temp_storage_bytes = reduce_into(None, d_input, d_output, h_init)
+        d_temp_storage = cuda.device_array(temp_storage_bytes, numpy.uint8)
+        reduce_into(d_temp_storage, d_input, d_output, h_init)
+
+        result = d_output.copy_to_host()[0]
+        expected = numpy.sum(h_input, initial=h_init[0])
+        # Floating-point summation order differs between host and device.
+        assert result == pytest.approx(expected)
diff --git a/python/cuda_parallel/tests/test_reduce_api.py b/python/cuda_parallel/tests/test_reduce_api.py
new file mode 100644
index 00000000000..6ed35831218
--- /dev/null
+++ b/python/cuda_parallel/tests/test_reduce_api.py
@@ -0,0 +1,40 @@
+# Copyright (c) 2024, NVIDIA CORPORATION & AFFILIATES. ALL RIGHTS RESERVED.
+#
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import numpy
+import pytest
+from numba import cuda
+
+# example-begin imports
+import cuda.parallel.experimental as cudax
+# example-end imports
+
+
+def test_device_reduce():
+    """Documented min-reduction example; the marked region is included in the docs."""
+    # example-begin reduce-min
+    def op(a, b):
+        return a if a < b else b
+
+    dtype = numpy.int32
+    h_init = numpy.array([42], dtype)
+    # Give the input the same dtype as the output and the initial value.
+    h_input = numpy.array([8, 6, 7, 5, 3, 0, 9], dtype)
+    d_output = cuda.device_array(1, dtype)
+    d_input = cuda.to_device(h_input)
+
+    # Instantiate reduction for the given operator and initial value
+    reduce_into = cudax.reduce_into(d_input, d_output, op, h_init)
+
+    # Determine temporary device storage requirements
+    temp_storage_size = reduce_into(None, d_input, d_output, h_init)
+
+    # Allocate temporary storage
+    d_temp_storage = cuda.device_array(temp_storage_size, dtype=numpy.uint8)
+
+    # Run reduction
+    reduce_into(d_temp_storage, d_input, d_output, h_init)
+
+    expected_output = 0
+    # example-end reduce-min
+    h_output = d_output.copy_to_host()
+    assert h_output[0] == expected_output