diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 41b6a639d8..32aab5656b 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -52,6 +52,7 @@ jobs: branch: ${{ inputs.branch }} date: ${{ inputs.date }} sha: ${{ inputs.sha }} + skip_upload_pkgs: libraft-template docs-build: if: github.ref_type == 'branch' && github.event_name == 'push' needs: python-build diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7606914589..630b8788f8 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -101,7 +101,7 @@ repos: args: ["--toml", "pyproject.toml"] exclude: (?x)^(^CHANGELOG.md$) - repo: https://github.com/rapidsai/dependency-file-generator - rev: v1.4.0 + rev: v1.5.1 hooks: - id: rapids-dependency-file-generator args: ["--clean"] diff --git a/README.md b/README.md index 8519ebcae1..81973ff82e 100755 --- a/README.md +++ b/README.md @@ -1,6 +1,5 @@ #
 RAFT: Reusable Accelerated Functions and Tools
- ![Navigating the canyons of accelerated possibilities](img/raft.png) ## Resources @@ -35,12 +34,16 @@ While not exhaustive, the following general categories help summarize the accele | **Tools & Utilities** | common utilities for developing CUDA applications, multi-node multi-gpu infrastructure | -All of RAFT's C++ APIs can be accessed header-only and optional pre-compiled shared libraries can 1) speed up compile times and 2) enable the APIs to be used without CUDA-enabled compilers. +RAFT is a C++ header-only template library with an optional shared library that +1) can speed up compile times for common template types, and +2) provides host-accessible "runtime" APIs, which don't require a CUDA compiler to use -In addition to the C++ library, RAFT also provides 2 Python libraries: -- `pylibraft` - lightweight low-level Python wrappers around RAFT's host-accessible "runtime" APIs. +In addition being a C++ library, RAFT also provides 2 Python libraries: +- `pylibraft` - lightweight Python wrappers around RAFT's host-accessible "runtime" APIs. - `raft-dask` - multi-node multi-GPU communicator infrastructure for building distributed algorithms on the GPU with Dask. +![RAFT is a C++ header-only template library with optional shared library and lightweight Python wrappers](img/arch.png) + ## Getting started ### RAPIDS Memory Manager (RMM) @@ -81,9 +84,9 @@ raft::device_resources handle; int n_samples = 5000; int n_features = 50; -auto input = raft::make_device_matrix(handle, n_samples, n_features); -auto labels = raft::make_device_vector(handle, n_samples); -auto output = raft::make_device_matrix(handle, n_samples, n_samples); +auto input = raft::make_device_matrix(handle, n_samples, n_features); +auto labels = raft::make_device_vector(handle, n_samples); +auto output = raft::make_device_matrix(handle, n_samples, n_samples); raft::random::make_blobs(handle, input.view(), labels.view()); @@ -218,52 +221,15 @@ pip install raft-dask-cu11 --extra-index-url=https://pypi.ngc.nvidia.com ### CMake & CPM -RAFT uses the [RAPIDS-CMake](https://github.com/rapidsai/rapids-cmake) library, which makes it simple to include in downstream cmake projects. RAPIDS CMake provides a convenience layer around CPM. - -After [installing](https://github.com/rapidsai/rapids-cmake#installation) rapids-cmake in your project, you can begin using RAFT by placing the code snippet below in a file named `get_raft.cmake` and including it in your cmake build with `include(get_raft.cmake)`. This will make available several targets to add to configure the link libraries for your artifacts. - -```cmake - -set(RAFT_VERSION "22.12") -set(RAFT_FORK "rapidsai") -set(RAFT_PINNED_TAG "branch-${RAFT_VERSION}") - -function(find_and_configure_raft) - set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARIES) - cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" - "${multiValueArgs}" ${ARGN} ) - - #----------------------------------------------------- - # Invoke CPM find_package() - #----------------------------------------------------- - - rapids_cpm_find(raft ${PKG_VERSION} - GLOBAL_TARGETS raft::raft - BUILD_EXPORT_SET projname-exports - INSTALL_EXPORT_SET projname-exports - CPM_ARGS - GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git - GIT_TAG ${PKG_PINNED_TAG} - SOURCE_SUBDIR cpp - OPTIONS - "BUILD_TESTS OFF" - "BUILD_BENCH OFF" - "RAFT_COMPILE_LIBRARIES ${PKG_COMPILE_LIBRARIES}" - ) - -endfunction() - -# Change pinned tag here to test a commit in CI -# To use a different RAFT locally, set the CMake variable -# CPM_raft_SOURCE=/path/to/local/raft -find_and_configure_raft(VERSION ${RAFT_VERSION}.00 - FORK ${RAFT_FORK} - PINNED_TAG ${RAFT_PINNED_TAG} - COMPILE_LIBRARIES NO -) -``` +RAFT uses the [RAPIDS-CMake](https://github.com/rapidsai/rapids-cmake) library, which makes it easy to include in downstream cmake projects. RAPIDS-CMake provides a convenience layer around CPM. Please refer to [these instructions](https://github.com/rapidsai/rapids-cmake#installation) to install and use rapids-cmake in your project. + +#### Example Template Project + +You can find an [example RAFT](cpp/template/README.md) project template in the `cpp/template` directory, which demonstrates how to build a new application with RAFT or incorporate RAFT into an existing cmake project. + +#### CMake Targets -Several CMake targets can be made available by adding components in the table below to the `RAFT_COMPONENTS` list above, separated by spaces. The `raft::raft` target will always be available. RAFT headers require, at a minimum, the CUDA toolkit libraries and RMM dependencies. +Additional CMake targets can be made available by adding components in the table below to the `RAFT_COMPONENTS` list above, separated by spaces. The `raft::raft` target will always be available. RAFT headers require, at a minimum, the CUDA toolkit libraries and RMM dependencies. | Component | Target | Description | Base Dependencies | |-------------|---------------------|-----------------------------------------------------------|---------------------------------------| @@ -317,6 +283,7 @@ The folder structure mirrors other RAPIDS repos, with the following folders: - `internal`: A private header-only component that hosts the code shared between benchmarks and tests. - `scripts`: Helpful scripts for development - `src`: Compiled APIs and template specializations for the shared libraries + - `template`: A skeleton template containing the bare-bones file structure and cmake configuration for writing applications with RAFT. - `test`: Googletests source code - `docs`: Source code and scripts for building library documentation (Uses breath, doxygen, & pydocs) - `python`: Source code for Python libraries. diff --git a/build.sh b/build.sh index b5a72f4205..9468d2cab0 100755 --- a/build.sh +++ b/build.sh @@ -18,7 +18,7 @@ ARGS=$* # script, and that this script resides in the repo dir! REPODIR=$(cd $(dirname $0); pwd) -VALIDARGS="clean libraft pylibraft raft-dask docs tests bench clean --uninstall -v -g -n --compile-lib --allgpuarch --no-nvtx --show_depr_warn -h" +VALIDARGS="clean libraft pylibraft raft-dask docs tests bench template clean --uninstall -v -g -n --compile-lib --allgpuarch --no-nvtx --show_depr_warn -h" HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool=] [--limit-tests=] [--limit-bench=] where is: clean - remove all existing build artifacts and configuration (start over) @@ -29,6 +29,7 @@ HELP="$0 [ ...] [ ...] [--cmake-args=\"\"] [--cache-tool= is: -v - verbose build mode @@ -354,13 +355,12 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has -DCMAKE_INSTALL_PREFIX=${INSTALL_PREFIX} \ -DCMAKE_CUDA_ARCHITECTURES=${RAFT_CMAKE_CUDA_ARCHITECTURES} \ -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ - -DRAFT_COMPILE_LIBRARIES=${COMPILE_LIBRARIES} \ + -DRAFT_COMPILE_LIBRARY=${COMPILE_LIBRARY} \ -DRAFT_NVTX=${NVTX} \ -DDISABLE_DEPRECATION_WARNINGS=${DISABLE_DEPRECATION_WARNINGS} \ -DBUILD_TESTS=${BUILD_TESTS} \ -DBUILD_BENCH=${BUILD_BENCH} \ -DCMAKE_MESSAGE_LOG_LEVEL=${CMAKE_LOG_LEVEL} \ - -DRAFT_COMPILE_LIBRARY=${COMPILE_LIBRARY} \ ${CACHE_ARGS} \ ${EXTRA_CMAKE_ARGS} @@ -410,3 +410,12 @@ if hasArg docs; then cd ${SPHINX_BUILD_DIR} sphinx-build -b html source _html fi + +################################################################################ +# Initiate build for example RAFT application template (if needed) + +if hasArg template; then + pushd cpp/template + ./build.sh + popd +fi diff --git a/ci/release/apply_wheel_modifications.sh b/ci/release/apply_wheel_modifications.sh index ed3d2a15fd..efc8f0c77c 100755 --- a/ci/release/apply_wheel_modifications.sh +++ b/ci/release/apply_wheel_modifications.sh @@ -6,10 +6,6 @@ VERSION=${1} CUDA_SUFFIX=${2} -# __init__.py versions -sed -i "s/__version__ = .*/__version__ = \"${VERSION}\"/g" python/pylibraft/pylibraft/__init__.py -sed -i "s/__version__ = .*/__version__ = \"${VERSION}\"/g" python/raft-dask/raft_dask/__init__.py - # pyproject.toml versions sed -i "s/^version = .*/version = \"${VERSION}\"/g" python/pylibraft/pyproject.toml sed -i "s/^version = .*/version = \"${VERSION}\"/g" python/raft-dask/pyproject.toml diff --git a/ci/wheel_smoke_test_raft_dask.py b/ci/wheel_smoke_test_raft_dask.py index 32c13e61ca..5709ac901c 100644 --- a/ci/wheel_smoke_test_raft_dask.py +++ b/ci/wheel_smoke_test_raft_dask.py @@ -1,4 +1,19 @@ -from dask.distributed import Client, wait +# Copyright (c) 2019-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +from dask.distributed import Client, get_worker, wait from dask_cuda import LocalCUDACluster, initialize from raft_dask.common import ( @@ -23,12 +38,12 @@ def func_test_send_recv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_send_recv(handle, n_trials) def func_test_collective(func, sessionId, root): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return func(handle, root) diff --git a/conda/environments/all_cuda-118_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml index 39f1fef4d5..1afebc98e6 100644 --- a/conda/environments/all_cuda-118_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -18,13 +18,14 @@ dependencies: - cupy - cxx-compiler - cython>=0.29,<0.30 -- dask-cuda=23.04 +- dask-cuda==23.4.* - dask>=2023.1.1 - distributed>=2023.1.1 - doxygen>=1.8.20 - gcc_linux-64=11.* - graphviz - ipython +- joblib>=0.11 - libcublas-dev=11.11.3.6 - libcublas=11.11.3.6 - libcurand-dev=10.3.0.86 @@ -33,13 +34,16 @@ dependencies: - libcusolver=11.4.1.48 - libcusparse-dev=11.7.5.86 - libcusparse=11.7.5.86 +- nccl>=2.9.9 - ninja +- numba>=0.49 +- numpy>=1.21 - numpydoc - pydata-sphinx-theme - pytest - pytest-cov - recommonmark -- rmm=23.04 +- rmm==23.4.* - scikit-build>=0.13.1 - scikit-learn - scipy @@ -47,6 +51,6 @@ dependencies: - sphinx-markdown-tables - sysroot_linux-64==2.17 - ucx-proc=*=gpu -- ucx-py=0.31.* +- ucx-py==0.31.* - ucx>=1.13.0 name: all_cuda-118_arch-x86_64 diff --git a/conda/recipes/libraft/build_libraft_template.sh b/conda/recipes/libraft/build_libraft_template.sh new file mode 100644 index 0000000000..9759402884 --- /dev/null +++ b/conda/recipes/libraft/build_libraft_template.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash +# Copyright (c) 2022-2023, NVIDIA CORPORATION. + +# Just building template so we verify it uses libraft.so and fail if it doesn't build +./build.sh template \ No newline at end of file diff --git a/conda/recipes/libraft/meta.yaml b/conda/recipes/libraft/meta.yaml index 2a724672ab..f911166a9a 100644 --- a/conda/recipes/libraft/meta.yaml +++ b/conda/recipes/libraft/meta.yaml @@ -150,3 +150,39 @@ outputs: home: https://rapids.ai/ license: Apache-2.0 summary: libraft tests + - name: libraft-template + version: {{ version }} + script: build_libraft_template.sh + build: + script_env: *script_env + number: {{ GIT_DESCRIBE_NUMBER }} + string: cuda{{ cuda_major }}_{{ date_string }}_{{ GIT_DESCRIBE_HASH }}_{{ GIT_DESCRIBE_NUMBER }} + ignore_run_exports_from: + - {{ compiler('cuda') }} + requirements: + build: + - {{ compiler('c') }} + - {{ compiler('cuda') }} {{ cuda_version }} + - {{ compiler('cxx') }} + - cmake {{ cmake_version }} + - ninja + - sysroot_{{ target_platform }} {{ sysroot_version }} + host: + - {{ pin_subpackage('libraft', exact=True) }} + - {{ pin_subpackage('libraft-headers', exact=True) }} + - cuda-profiler-api {{ cuda_profiler_api_host_version }} + - libcublas {{ libcublas_host_version }} + - libcublas-dev {{ libcublas_host_version }} + - libcurand {{ libcurand_host_version }} + - libcurand-dev {{ libcurand_host_version }} + - libcusolver {{ libcusolver_host_version }} + - libcusolver-dev {{ libcusolver_host_version }} + - libcusparse {{ libcusparse_host_version }} + - libcusparse-dev {{ libcusparse_host_version }} + run: + - {{ pin_subpackage('libraft', exact=True) }} + - {{ pin_subpackage('libraft-headers', exact=True) }} + about: + home: https://rapids.ai/ + license: Apache-2.0 + summary: libraft template diff --git a/conda/recipes/pylibraft/meta.yaml b/conda/recipes/pylibraft/meta.yaml index a528064348..7730801801 100644 --- a/conda/recipes/pylibraft/meta.yaml +++ b/conda/recipes/pylibraft/meta.yaml @@ -36,6 +36,7 @@ requirements: - cython >=0.29,<0.30 - libraft {{ version }} - libraft-headers {{ version }} + - numpy >=1.21 - python x.x - rmm ={{ minor_version }} - scikit-build >=0.13.1 diff --git a/conda/recipes/raft-dask/conda_build_config.yaml b/conda/recipes/raft-dask/conda_build_config.yaml index ef22522116..778b187870 100644 --- a/conda/recipes/raft-dask/conda_build_config.yaml +++ b/conda/recipes/raft-dask/conda_build_config.yaml @@ -11,7 +11,7 @@ sysroot_version: - "2.17" ucx_version: - - "1.13.0" + - ">=1.13.0,<1.15.0" ucx_py_version: - "0.31.*" diff --git a/conda/recipes/raft-dask/meta.yaml b/conda/recipes/raft-dask/meta.yaml index b387f0f47c..59a67fe148 100644 --- a/conda/recipes/raft-dask/meta.yaml +++ b/conda/recipes/raft-dask/meta.yaml @@ -54,7 +54,7 @@ requirements: - pylibraft {{ version }} - python x.x - rmm ={{ minor_version }} - - ucx >={{ ucx_version }} + - ucx {{ ucx_version }} - ucx-proc=*=gpu - ucx-py {{ ucx_py_version }} diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index bdaacb4a85..c1704552ec 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -304,6 +304,7 @@ if(RAFT_COMPILE_LIBRARY) # These are somehow missing a kernel definition which is causing a compile error. # src/distance/specializations/detail/kernels/rbf_kernel_double.cu # src/distance/specializations/detail/kernels/rbf_kernel_float.cu + src/neighbors/brute_force_knn_int64_t_float.cu src/distance/specializations/detail/kernels/tanh_kernel_double.cu src/distance/specializations/detail/kernels/tanh_kernel_float.cu src/distance/specializations/detail/kl_divergence_float_float_float_int.cu @@ -312,10 +313,6 @@ if(RAFT_COMPILE_LIBRARY) src/distance/specializations/detail/l1_double_double_double_int.cu src/distance/specializations/detail/l2_expanded_float_float_float_int.cu src/distance/specializations/detail/l2_expanded_double_double_double_int.cu - src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu - src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu - src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu src/distance/specializations/detail/l_inf_double_double_double_int.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 8049074c09..d92ccba8e3 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -72,6 +72,11 @@ if(BUILD_BENCH) OPTIONAL LIB ) + ConfigureBench( + NAME TUNE_DISTANCE PATH bench/distance/tune_pairwise/kernel.cu + bench/distance/tune_pairwise/bench.cu bench/main.cpp + ) + ConfigureBench( NAME DISTANCE_BENCH diff --git a/cpp/bench/distance/tune_pairwise/bench.cu b/cpp/bench/distance/tune_pairwise/bench.cu new file mode 100644 index 0000000000..87159ab1b1 --- /dev/null +++ b/cpp/bench/distance/tune_pairwise/bench.cu @@ -0,0 +1,151 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Tuning benchmarks. +// +// Goals: +// +// 1. Fast compile times to maintain iteration speed. +// 2. Create benchmarks that can inform the design of the kernels. +// +// Non-goals: +// +// 1. Measure every distance operation. Instead measures just one distance +// operation at the same time. +// 2. Be useful for finding performance regressions. This is handled by the +// normal benchmarks. +// +// So far, both goals are partly achieved. +// +// RE (1), COMPILE TIMES: kernel.cu is fast to compile. This file is not. +// When the internals of a pairwise distance kernel is changed, this file is not +// recompiled. +// +// RE 2, benchmarks with intent: this file contains a benchmark to check the +// maximal throughput of a kernel. Measuring other things, like performance on +// skinny or wide matrices is not yet implemented. + +#include "kernel.cuh" // launch_kernel +#include // std::min +#include // RAFT_BENCH_REGISTER +#include // pairwise_matrix_params +#include // rmm::device_uvector +#include // std::vector + +namespace raft::bench::distance::tune { + +// Max throughput benchmark. +// +// Goal: Measure the maximum distances/sec that can be computed. +// +// To achieve this, we make sure that: +// +// - Input data size is a multiple of the block tile size. +// +// - Perfect distribution of work between SMs, i.e. the number of block tiles is +// a large multiple (num_waves) of the number of blocks (#SMs * occupancy). +// +// - Multiple iterations over Kblk are executed (num_k_iters). +struct throughput_param { + int num_waves; + int occupancy; + int num_k_iters; +}; + +const std::vector throughput_params{ + // 32 waves, requested occupancy of 4, and 32 k iterations typically achieves + // maximum throughput. No need to pick higher values. + {32, 4, 32}, +}; + +struct throughput_bench : public fixture { + const throughput_param p; + + throughput_bench(const throughput_param& p_) : p(p_) {} + + void run_benchmark(::benchmark::State& state) override + { + // Get block size: + int block_m, block_n, block_k; + get_block_size(block_m, block_n, block_k); + + // Determine number of blocks that will be launched. This informs the size + // of the inputs as well as the grid size. + const int num_sms = raft::getMultiProcessorCount(); + const int max_occupancy = get_max_occupancy(); + const int occupancy = std::min(p.occupancy, max_occupancy); + const int num_blocks = occupancy * num_sms; + dim3 grid(num_blocks); + + // Create input sizes that are a multiple of the block tile size. + size_t m = block_m; + size_t n = block_n * p.num_waves * num_blocks; + size_t k = block_k * p.num_k_iters; + + // DataT, OutT, IdxT, etc, are defined in tuned_kernel.cuh + rmm::device_uvector x_vec(m * k, stream); + rmm::device_uvector y_vec(n * k, stream); + rmm::device_uvector x_norm_vec(m, stream); + rmm::device_uvector y_norm_vec(n, stream); + rmm::device_uvector out_vec(m * n, stream); + + auto x = x_vec.data(); + auto y = y_vec.data(); + auto x_norm = x_norm_vec.data(); + auto y_norm = y_norm_vec.data(); + auto out = out_vec.data(); + FinOpT fin_op{}; + + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = row_major ? k : m; + IdxT ldy = row_major ? k : n; + IdxT ld_out = row_major ? n : m; + + // Template parameters of pairwise_matrix_params are defined in kernel.cuh + pairwise_matrix_params kparams{ + IdxT(m), IdxT(n), IdxT(k), ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, row_major}; + + // Run benchmark + loop_on_state(state, [&]() { launch_kernel(kparams, grid, stream); }); + + // Report metrics. We don't report flop/s because we do not know for each + // distance operation how many flops it costs. For L2_unexp and l1, we can + // double this number to get the flop/s. For l2 expanded, core_ops/s should + // equal flop/s (modulo the sqrt and subtracting from the norm). + size_t num_core_ops = m * n * k; + size_t read_elts = n * k + m * k; + size_t write_elts = m * n; + + state.counters["m"] = benchmark::Counter(m); + state.counters["n"] = benchmark::Counter(n); + state.counters["k"] = benchmark::Counter(k); + state.counters["occupancy"] = benchmark::Counter(occupancy); + state.counters["# waves"] = benchmark::Counter(p.num_waves); + state.counters["# k iters"] = benchmark::Counter(p.num_k_iters); + + state.counters["core_ops/s"] = benchmark::Counter(num_core_ops, + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + + state.counters["BW"] = benchmark::Counter(write_elts * sizeof(OutT) + read_elts * sizeof(DataT), + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + } +}; + +RAFT_BENCH_REGISTER(throughput_bench, "", throughput_params); + +} // namespace raft::bench::distance::tune diff --git a/cpp/bench/distance/tune_pairwise/kernel.cu b/cpp/bench/distance/tune_pairwise/kernel.cu new file mode 100644 index 0000000000..3112e1ea9a --- /dev/null +++ b/cpp/bench/distance/tune_pairwise/kernel.cu @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "kernel.cuh" +#include // pairwise_matrix_sm60_wrapper +#include // raft::linalg::Policy4x4 +#include // raft::util::arch::SM_compute_arch + +namespace raft::bench::distance::tune { + +// Distance op +using OpT = raft::distance::detail::ops::lp_unexp_distance_op; +constexpr float metric_arg = 2.0; +OpT distance_op{metric_arg}; + +// Kernel policy +constexpr int vec_len = 1; +using Policy = typename raft::linalg::Policy4x4::Policy; + +// Architecture +namespace arch = raft::util::arch; +constexpr auto sm_compat_range = arch::SM_range(arch::SM_min(), arch::SM_future()); + +void launch_kernel(pairwise_matrix_params params, dim3 grid, cudaStream_t stream) +{ + dim3 block(Policy::Nthreads); + int smem_size = OpT::shared_mem_size(); + + // Obtain function pointer to kernel + auto kernel = raft::distance::detail::pairwise_matrix_kernel; + + kernel<<>>(distance_op, params); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +void get_block_size(int& m, int& n, int& k) +{ + m = Policy::Mblk; + n = Policy::Nblk; + k = Policy::Kblk; +} + +void* get_kernel_ptr() +{ + auto kernel = raft::distance::detail::pairwise_matrix_kernel; + return reinterpret_cast(kernel); +} + +int get_max_occupancy() +{ + void* kernel_ptr = get_kernel_ptr(); + int max_occupancy; + int smem_size = OpT::shared_mem_size(); + + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_occupancy, kernel_ptr, Policy::Nthreads, smem_size)); + + return max_occupancy; +} + +} // namespace raft::bench::distance::tune diff --git a/cpp/bench/distance/tune_pairwise/kernel.cuh b/cpp/bench/distance/tune_pairwise/kernel.cuh new file mode 100644 index 0000000000..5da54a343c --- /dev/null +++ b/cpp/bench/distance/tune_pairwise/kernel.cuh @@ -0,0 +1,44 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // lp_unexp_distance_op +#include // pairwise_matrix_params + +namespace raft::bench::distance::tune { + +// Launch one specific kernel with the following template parameters +constexpr bool row_major = true; +using DataT = float; +using AccT = float; +using OutT = DataT; +using IdxT = int; + +using FinOpT = raft::identity_op; + +using pairwise_matrix_params = + raft::distance::detail::pairwise_matrix_params; + +// Launches kernel +void launch_kernel(pairwise_matrix_params, dim3, cudaStream_t); + +// Describes the block size that is decided by the policy +void get_block_size(int& m, int& n, int& k); + +int get_max_occupancy(); + +} // namespace raft::bench::distance::tune diff --git a/cpp/bench/matrix/select_k.cu b/cpp/bench/matrix/select_k.cu index d4873e2640..870119db52 100644 --- a/cpp/bench/matrix/select_k.cu +++ b/cpp/bench/matrix/select_k.cu @@ -35,6 +35,10 @@ #include #include +#include +#include +#include + namespace raft::matrix { using namespace raft::bench; // NOLINT @@ -50,7 +54,23 @@ struct selection : public fixture { { raft::sparse::iota_fill(in_ids_.data(), IdxT(p.batch_size), IdxT(p.len), stream); raft::random::RngState state{42}; - raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0)); + + KeyT min_value = -1.0; + KeyT max_value = 1.0; + if (p.use_same_leading_bits) { + if constexpr (std::is_same_v) { + uint32_t min_bits = 0x3F800000; // 1.0 + uint32_t max_bits = 0x3F8000FF; // 1.00003 + memcpy(&min_value, &min_bits, sizeof(KeyT)); + memcpy(&max_value, &max_bits, sizeof(KeyT)); + } else if constexpr (std::is_same_v) { + uint64_t min_bits = 0x3FF0000000000000; // 1.0 + uint64_t max_bits = 0x3FF0000FFFFFFFFF; // 1.000015 + memcpy(&min_value, &min_bits, sizeof(KeyT)); + memcpy(&max_value, &max_bits, sizeof(KeyT)); + } + } + raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), min_value, max_value); } void run_benchmark(::benchmark::State& state) override // NOLINT @@ -60,6 +80,7 @@ struct selection : public fixture { try { std::ostringstream label_stream; label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k; + if (params_.use_same_leading_bits) { label_stream << "#same-leading-bits"; } state.SetLabel(label_stream.str()); loop_on_state(state, [this, &handle]() { select::select_k_impl(handle, @@ -85,21 +106,55 @@ struct selection : public fixture { }; const std::vector kInputs{ - {20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true}, - {20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true}, - {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true}, - - {1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true}, - {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, - {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, - - {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, - {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, - {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, - - {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, - {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, - {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true}, + {20000, 500, 1, true}, + {20000, 500, 2, true}, + {20000, 500, 4, true}, + {20000, 500, 8, true}, + {20000, 500, 16, true}, + {20000, 500, 32, true}, + {20000, 500, 64, true}, + {20000, 500, 128, true}, + {20000, 500, 256, true}, + + {1000, 10000, 1, true}, + {1000, 10000, 2, true}, + {1000, 10000, 4, true}, + {1000, 10000, 8, true}, + {1000, 10000, 16, true}, + {1000, 10000, 32, true}, + {1000, 10000, 64, true}, + {1000, 10000, 128, true}, + {1000, 10000, 256, true}, + + {100, 100000, 1, true}, + {100, 100000, 2, true}, + {100, 100000, 4, true}, + {100, 100000, 8, true}, + {100, 100000, 16, true}, + {100, 100000, 32, true}, + {100, 100000, 64, true}, + {100, 100000, 128, true}, + {100, 100000, 256, true}, + + {10, 1000000, 1, true}, + {10, 1000000, 2, true}, + {10, 1000000, 4, true}, + {10, 1000000, 8, true}, + {10, 1000000, 16, true}, + {10, 1000000, 32, true}, + {10, 1000000, 64, true}, + {10, 1000000, 128, true}, + {10, 1000000, 256, true}, + + {10, 1000000, 1, true, false, true}, + {10, 1000000, 2, true, false, true}, + {10, 1000000, 4, true, false, true}, + {10, 1000000, 8, true, false, true}, + {10, 1000000, 16, true, false, true}, + {10, 1000000, 32, true, false, true}, + {10, 1000000, 64, true, false, true}, + {10, 1000000, 128, true, false, true}, + {10, 1000000, 256, true, false, true}, }; #define SELECTION_REGISTER(KeyT, IdxT, A) \ @@ -109,24 +164,27 @@ const std::vector kInputs{ RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ } -SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT -SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT -SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpAuto); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpImmediate); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpFiltered); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpDistributed); // NOLINT -SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm); // NOLINT - -SELECTION_REGISTER(double, uint32_t, kRadix8bits); // NOLINT -SELECTION_REGISTER(double, uint32_t, kRadix11bits); // NOLINT -SELECTION_REGISTER(double, uint32_t, kWarpAuto); // NOLINT - -SELECTION_REGISTER(double, int64_t, kRadix8bits); // NOLINT -SELECTION_REGISTER(double, int64_t, kRadix11bits); // NOLINT -SELECTION_REGISTER(double, int64_t, kWarpImmediate); // NOLINT -SELECTION_REGISTER(double, int64_t, kWarpFiltered); // NOLINT -SELECTION_REGISTER(double, int64_t, kWarpDistributed); // NOLINT -SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT +SELECTION_REGISTER(float, uint32_t, kPublicApi); // NOLINT +SELECTION_REGISTER(float, uint32_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(float, uint32_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(float, uint32_t, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpAuto); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(float, uint32_t, kWarpDistributedShm); // NOLINT + +SELECTION_REGISTER(double, uint32_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, uint32_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, uint32_t, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(double, uint32_t, kWarpAuto); // NOLINT + +SELECTION_REGISTER(double, int64_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, int64_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, int64_t, kRadix11bitsExtraPass); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpDistributed); // NOLINT +SELECTION_REGISTER(double, int64_t, kWarpDistributedShm); // NOLINT } // namespace raft::matrix diff --git a/cpp/include/raft/core/kvp.hpp b/cpp/include/raft/core/kvp.hpp index 8d3321eb77..192d160d45 100644 --- a/cpp/include/raft/core/kvp.hpp +++ b/cpp/include/raft/core/kvp.hpp @@ -20,7 +20,7 @@ #ifdef _RAFT_HAS_CUDA #include -#include +#include // raft::shfl_xor #endif namespace raft { /** diff --git a/cpp/include/raft/distance/detail/distance.cuh b/cpp/include/raft/distance/detail/distance.cuh index f469250b45..7493c4e558 100644 --- a/cpp/include/raft/distance/detail/distance.cuh +++ b/cpp/include/raft/distance/detail/distance.cuh @@ -16,25 +16,18 @@ #pragma once -#include -#include - -#include -#include -#include -#include -#include - #include - +#include #include #include - +#include +#include #include #include -#include -#include -#include +#include +#include +#include +#include namespace raft { namespace distance { @@ -140,14 +133,14 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - AccT* norm_col_vec = workspace; - AccT* norm_row_vec = workspace; - AccT* sq_norm_col_vec = workspace; - AccT* sq_norm_row_vec = workspace; + AccT* x_norm = workspace; + AccT* y_norm = workspace; + AccT* sq_x_norm = workspace; + AccT* sq_y_norm = workspace; if (x != y) { - norm_row_vec += m; + y_norm += m; - raft::linalg::reduce(norm_col_vec, + raft::linalg::reduce(x_norm, x, k, m, @@ -158,7 +151,7 @@ void distance_impl(raft::resources const& handle, false, raft::identity_op(), raft::add_op()); - raft::linalg::reduce(norm_row_vec, + raft::linalg::reduce(y_norm, y, k, n, @@ -170,12 +163,12 @@ void distance_impl(raft::resources const& handle, raft::identity_op(), raft::add_op()); - sq_norm_col_vec += (m + n); - sq_norm_row_vec = sq_norm_col_vec + m; - raft::linalg::rowNorm(sq_norm_col_vec, x, k, m, raft::linalg::L2Norm, is_row_major, stream); - raft::linalg::rowNorm(sq_norm_row_vec, y, k, n, raft::linalg::L2Norm, is_row_major, stream); + sq_x_norm += (m + n); + sq_y_norm = sq_x_norm + m; + raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); + raft::linalg::rowNorm(sq_y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream); } else { - raft::linalg::reduce(norm_col_vec, + raft::linalg::reduce(x_norm, x, k, m, @@ -186,15 +179,15 @@ void distance_impl(raft::resources const& handle, false, raft::identity_op(), raft::add_op()); - sq_norm_col_vec += m; - sq_norm_row_vec = sq_norm_col_vec; - raft::linalg::rowNorm(sq_norm_col_vec, x, k, m, raft::linalg::L2Norm, is_row_major, stream); + sq_x_norm += m; + sq_y_norm = sq_x_norm; + raft::linalg::rowNorm(sq_x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream); } using OpT = ops::correlation_distance_op; - OpT corr_op(is_row_major, sq_norm_col_vec, sq_norm_row_vec, m, n, k); + OpT corr_op(is_row_major, sq_x_norm, sq_y_norm, m, n, k); pairwise_matrix_dispatch( - corr_op, m, n, k, x, y, norm_col_vec, norm_row_vec, out, fin_op, stream, is_row_major); + corr_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -223,22 +216,22 @@ void distance_impl(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); - DataT* norm_A = workspace; - DataT* norm_B = workspace; + DataT* x_norm = workspace; + DataT* y_norm = workspace; if (x != y) { - norm_B += m; + y_norm += m; raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); raft::linalg::rowNorm( - norm_B, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } else { raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::sqrt_op{}); } ops::cosine_distance_op distance_op{}; pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -389,10 +382,6 @@ void distance_impl(raft::resources const& handle, return (!x_zero) * raft::exp(input); }; - // This op takes some shortcuts when x equals y. So its behavior changes based - // on this. - ops::kl_divergence_op kl_divergence{is_row_major, x == y}; - if (x != y) { raft::linalg::unaryOp( (DataT*)y, y, n * k, unaryOp_lambda, stream); @@ -401,8 +390,12 @@ void distance_impl(raft::resources const& handle, const DataT* x_norm = nullptr; const DataT* y_norm = nullptr; - pairwise_matrix_dispatch( - kl_divergence, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); + // This op takes some shortcuts when x equals y. So its behavior changes based + // on this. + ops::kl_divergence_op distance_op{is_row_major, x == y}; + + pairwise_matrix_dispatch( + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); if (x != y) { // Now reverse previous log (x) back to x using (e ^ log(x)) @@ -464,22 +457,22 @@ void distance_impl_l2_expanded( // NOTE: different name "workspace size error"); ASSERT(workspace != nullptr, "workspace is null"); - DataT* norm_A = workspace; - DataT* norm_B = workspace; + DataT* x_norm = workspace; + DataT* y_norm = workspace; if (x != y) { - norm_B += m; + y_norm += m; raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); raft::linalg::rowNorm( - norm_B, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + y_norm, y, k, n, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); } else { raft::linalg::rowNorm( - norm_A, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); + x_norm, x, k, m, raft::linalg::L2Norm, is_row_major, stream, raft::identity_op{}); } ops::l2_exp_distance_op distance_op{perform_sqrt}; pairwise_matrix_dispatch( - distance_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + distance_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -543,13 +536,13 @@ void distance_impl(raft::resources const& handle, ops::l2_unexp_distance_op l2_op(perform_sqrt); // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* norm_A = nullptr; - const DataT* norm_B = nullptr; + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template @@ -571,13 +564,13 @@ void distance_impl(raft::resources const& handle, ops::l2_unexp_distance_op l2_op(perform_sqrt); // The unexpanded L2 does not require the norms of a and b to be calculated. - const DataT* norm_A = nullptr; - const DataT* norm_B = nullptr; + const DataT* x_norm = nullptr; + const DataT* y_norm = nullptr; cudaStream_t stream = raft::resource::get_cuda_stream(handle); pairwise_matrix_dispatch( - l2_op, m, n, k, x, y, norm_A, norm_B, out, fin_op, stream, is_row_major); + l2_op, m, n, k, x, y, x_norm, y_norm, out, fin_op, stream, is_row_major); } template diff --git a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh index 930294ce31..eaf37b7e9c 100644 --- a/cpp/include/raft/distance/detail/distance_ops/canberra.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/canberra.cuh @@ -16,7 +16,8 @@ #pragma once -#include +#include // raft::abs +#include // DI namespace raft::distance::detail::ops { @@ -42,7 +43,7 @@ struct canberra_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh index 289b69070a..4fc4bb8297 100644 --- a/cpp/include/raft/distance/detail/distance_ops/correlation.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/correlation.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -61,7 +61,7 @@ struct correlation_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize + (2 * (Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } diff --git a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh index 7c37c27b4e..0883136c9f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cosine.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cosine.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -26,7 +26,7 @@ struct cosine_cutlass_op { __device__ cosine_cutlass_op() noexcept {} __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept { - return static_cast(1.0) - (AccT)(accVal / (aNorm * bNorm)); + return static_cast(1.0) - static_cast(accVal / (aNorm * bNorm)); } __device__ AccT operator()(DataT aData) const noexcept { return aData; } }; @@ -53,7 +53,7 @@ struct cosine_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } @@ -76,7 +76,10 @@ struct cosine_distance_op { } } - cosine_cutlass_op get_cutlass_op() { return cosine_cutlass_op(); } + constexpr cosine_cutlass_op get_cutlass_op() const + { + return cosine_cutlass_op(); + } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh index d3eb90467b..7a4fe0ce83 100644 --- a/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/cutlass.cuh @@ -16,7 +16,8 @@ #pragma once -#include +#include // std::false_type +#include // std::declval namespace raft::distance::detail::ops { @@ -34,7 +35,8 @@ struct has_cutlass_op : std::false_type { // Specialization recognizes types that do support CUTLASS template -struct has_cutlass_op> : std::true_type { +struct has_cutlass_op().get_cutlass_op())>> + : std::true_type { }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh index 1cfdcfdc73..475b8892e9 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hamming.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hamming.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -45,7 +45,7 @@ struct hamming_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh index c4aecc7a6f..0489b45854 100644 --- a/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/hellinger.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -42,7 +42,7 @@ struct hellinger_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh index 41eeb9dd83..e46c63734c 100644 --- a/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/jensen_shannon.cuh @@ -15,7 +15,8 @@ */ #pragma once -#include +#include // raft::log +#include // DI namespace raft::distance::detail::ops { @@ -44,7 +45,7 @@ struct jensen_shannon_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh index d046b62c30..d083c5ddcc 100644 --- a/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/kl_divergence.cuh @@ -15,7 +15,8 @@ */ #pragma once -#include +#include // raft::log +#include // DI namespace raft::distance::detail::ops { @@ -49,7 +50,7 @@ struct kl_divergence_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/l1.cuh b/cpp/include/raft/distance/detail/distance_ops/l1.cuh index 8ec4000827..7e86fd3603 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l1.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l1.cuh @@ -15,7 +15,7 @@ */ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -41,7 +41,7 @@ struct l1_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 2a7af53813..95577fd311 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -54,7 +54,7 @@ struct l2_exp_distance_op { using AccT = AccType; using IdxT = IdxType; - bool sqrt; + const bool sqrt; l2_exp_distance_op(bool sqrt_) noexcept : sqrt(sqrt_) {} @@ -67,7 +67,7 @@ struct l2_exp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); } @@ -102,7 +102,10 @@ struct l2_exp_distance_op { } } - l2_exp_cutlass_op get_cutlass_op() { return l2_exp_cutlass_op(sqrt); } + constexpr l2_exp_cutlass_op get_cutlass_op() const + { + return l2_exp_cutlass_op(sqrt); + } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh index f0ea591eaf..62c212ee8f 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_unexp.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -46,7 +46,7 @@ struct l2_unexp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh index fb21fb1a21..88853a3083 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l_inf.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -42,7 +42,7 @@ struct l_inf_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh index 71dfd51a6e..290f4af1b4 100644 --- a/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/lp_unexp.cuh @@ -15,7 +15,8 @@ */ #pragma once -#include +#include // raft::pow, raft::abs +#include // DI namespace raft::distance::detail::ops { @@ -45,7 +46,7 @@ struct lp_unexp_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh index ea09e4d1db..63dbf350d1 100644 --- a/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/russel_rao.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -47,7 +47,7 @@ struct russel_rao_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. template - constexpr size_t shared_mem_size() + static constexpr size_t shared_mem_size() { return Policy::SmemSize; } diff --git a/cpp/include/raft/distance/detail/distance_ops/template.cuh b/cpp/include/raft/distance/detail/distance_ops/template.cuh index 6998f3cad4..4320068361 100644 --- a/cpp/include/raft/distance/detail/distance_ops/template.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/template.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include // DI namespace raft::distance::detail::ops { @@ -42,8 +42,8 @@ struct template_distance_op { // Size of shared memory. This is normally decided by the kernel policy, but // some ops such as correlation_distance_op use more. - template - constexpr size_t shared_mem_size() + template + static constexpr size_t shared_mem_size() { return Policy::SmemSize + TODO; } @@ -59,6 +59,10 @@ struct template_distance_op { { TODO; } + + // If exist, returns a cutlass op that performs the same operation. + // See cosine and l2_exp distance ops for an example. + constexpr l2_exp_cutlass_op get_cutlass_op() const { TODO; } }; } // namespace raft::distance::detail::ops diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 8fbd7a9c69..be6fed9f10 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -16,23 +16,20 @@ #pragma once -#include -#include -#include -#include -#include -#include +#include // size_t +#include // std::numeric_limits +#include // raft::KeyValuePair +#include // raft::identity_op +#include // ops::l2_exp_distance_op +#include // PairwiseDistances +#include // Policy +#include // raft::ceildiv, raft::shfl namespace raft { namespace distance { namespace detail { -#if (ENABLE_MEMCPY_ASYNC == 1) -#include -using namespace nvcuda::experimental; -#endif - template struct KVPMinReduceImpl { typedef raft::KeyValuePair KVP; @@ -124,11 +121,10 @@ DI void updateReducedVal( template __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, const DataT* x, @@ -142,7 +138,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, int* mutex, ReduceOpT redOp, KVPReduceOpT pairRedOp, - CoreLambda core_op, + OpT distance_op, FinalLambda fin_op) { extern __shared__ char smem[]; @@ -163,24 +159,6 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, IdxT gridStrideY) { KVPReduceOpT pairRed_op(pairRedOp); -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } - } - if (Sqrt) { -#pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto acc_ij = acc[i][j]; - acc[i][j] = acc_ij > DataT{0} ? raft::sqrt(acc_ij) : DataT{0}; - } - } - } - // intra thread reduce const auto acccolid = threadIdx.x % P::AccThCols; const auto accrowid = threadIdx.x / P::AccThCols; @@ -229,18 +207,18 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, }; IdxT lda = k, ldb = k, ldd = n; - PairwiseDistances + row_major, + write_out> obj(x, y, m, @@ -251,9 +229,9 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, ldd, xn, yn, - nullptr, + nullptr, // Output pointer smem, - core_op, + distance_op, epilog_lambda, fin_op, rowEpilog_lambda); @@ -289,9 +267,6 @@ void fusedL2NNImpl(OutT* min, constexpr auto maxVal = std::numeric_limits::max(); typedef KeyValuePair KVPair; - // Accumulation operation lambda - auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; - RAFT_CUDA_TRY(cudaMemsetAsync(workspace, 0, sizeof(int) * m, stream)); if (initOutBuffer) { initKernel @@ -300,59 +275,25 @@ void fusedL2NNImpl(OutT* min, } constexpr size_t shmemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); - if (sqrt) { - auto fusedL2NNSqrt = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NNSqrt); - - fusedL2NNSqrt<<>>(min, - x, - y, - xn, - yn, - m, - n, - k, - maxVal, - workspace, - redOp, - pairRedOp, - core_lambda, - raft::identity_op{}); - } else { - auto fusedL2NN = fusedL2NNkernel; - dim3 grid = launchConfigGenerator

(m, n, shmemSize, fusedL2NN); - fusedL2NN<<>>(min, - x, - y, - xn, - yn, - m, - n, - k, - maxVal, - workspace, - redOp, - pairRedOp, - core_lambda, - raft::identity_op{}); - } + using AccT = DataT; + ops::l2_exp_distance_op distance_op{sqrt}; + + raft::identity_op fin_op{}; + + auto kernel = fusedL2NNkernel; + + dim3 grid = launchConfigGenerator

(m, n, shmemSize, kernel); + + kernel<<>>( + min, x, y, xn, yn, m, n, k, maxVal, workspace, redOp, pairRedOp, distance_op, fin_op); RAFT_CUDA_TRY(cudaGetLastError()); } diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 0293f10c29..c6b09be31e 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -14,14 +14,11 @@ * limitations under the License. */ #pragma once -#include -#include -#include -#include -#include -#include +#include // raft::linalg::Contractions_NT +#include // ceildiv +#include // RAFT_CUDA_TRY -#include +#include // size_t namespace raft { namespace distance { @@ -29,16 +26,12 @@ namespace detail { /** * @brief Device class for L1, L2 and cosine distance metrics. - * @tparam useNorms whether norms are needed * @tparam DataT input data-type (for A and B matrices) * @tparam AccT accumulation data-type * @tparam OutT output data-type (for C and D matrices) * @tparam IdxT index data-type * @tparam Policy struct which tunes the Contraction kernel - * @tparam CoreLambda tells how to accumulate an x and y into - acc. its signature: - template void core_lambda(AccT& acc, - const DataT& x, const DataT& y) + * @tparam OpT A distance operation, e.g., cosine_distance_op. * @tparam EpilogueLambda applies an elementwise function to compute final values. Its signature is: template void epilogue_lambda @@ -56,19 +49,17 @@ namespace detail { * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine * @param[output] pD output matrix * @param[in] smem shared mem buffer for intermediate storage of A, B, xn & yn. - * @param core_op the core accumulation operation lambda + * @param distance_op the distance operation, e.g. cosine_distance_op * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ -template > struct PairwiseDistances : public BaseClass { + // Get accumulation type from distance_op + using AccT = typename OpT::AccT; + private: typedef Policy P; const DataT* xn; @@ -83,7 +77,7 @@ struct PairwiseDistances : public BaseClass { const DataT* const yBase; OutT* dOutput; char* smem; - CoreLambda core_op; + OpT distance_op; EpilogueLambda epilog_op; FinalLambda fin_op; rowEpilogueLambda rowEpilog_op; @@ -109,7 +103,7 @@ struct PairwiseDistances : public BaseClass { const DataT* _yn, OutT* _dOutput, char* _smem, - CoreLambda _core_op, + OpT _distance_op, EpilogueLambda _epilog_op, FinalLambda _fin_op, rowEpilogueLambda _rowEpilog_op) @@ -119,7 +113,7 @@ struct PairwiseDistances : public BaseClass { yBase(_y), dOutput(_dOutput), smem(_smem), - core_op(_core_op), + distance_op(_distance_op), epilog_op(_epilog_op), fin_op(_fin_op), rowEpilog_op(_rowEpilog_op), @@ -159,15 +153,25 @@ struct PairwiseDistances : public BaseClass { this->switch_read_buffer(); // Epilog: - if (useNorms) { + if (distance_op.use_norms) { DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; load_norms(tile_idx_m, tile_idx_n, regxn, regyn); // Overlap ldg with epilog computation ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, regxn, regyn, tile_idx_n, tile_idx_m); + // And any possible additional epilogs epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); } else { // Overlap ldg with epilog computation ldgNextGridStride(tile_idx_m, tile_idx_n); + // Calculate distance_op epilog. + // Use .template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + distance_op.template epilog(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + // And any possible additional epilogs epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); } if (writeOut) { store_output(tile_idx_m, tile_idx_n); } @@ -201,24 +205,41 @@ struct PairwiseDistances : public BaseClass { } } - DI void accumulate() + DI void accumulate_reg_tile(DataT (®_x)[P::AccRowsPerTh][P::Veclen], + DataT (®_y)[P::AccColsPerTh][P::Veclen]) { #pragma unroll - for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { - this->ldsXY(ki); + for (int v = 0; v < P::Veclen; ++v) { #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { -#pragma unroll - for (int v = 0; v < P::Veclen; ++v) { - core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); - } + distance_op.core(acc[i][j], reg_x[i][v], reg_y[j][v]); } } } } + DI void accumulate() + { + // We have a separate ldsXY and accumulate_reg_tile outside the loop body, + // so that these separated calls can be interspersed with preceding and + // following instructions, thereby hiding latency. + this->ldsXY(0); + + // If expensive inner loop, do not unroll loop. + constexpr int num_iterations = P::Kblk / P::Veclen - 1; + constexpr int unroll_count = decltype(distance_op)::expensive_inner_loop ? 1 : num_iterations; +#pragma unroll unroll_count + for (int ki = P::Veclen; ki < P::Kblk; ki += P::Veclen) { + accumulate_reg_tile(this->regx, this->regy); + this->ldsXY(ki); + } + + // Accumulate last loaded tile. + accumulate_reg_tile(this->regx, this->regy); + } + DI void load_norms(IdxT tile_idx_m, IdxT tile_idx_n, DataT (®xn)[P::AccRowsPerTh], @@ -274,7 +295,11 @@ struct PairwiseDistances : public BaseClass { template dim3 launchConfigGenerator(IdxT m, IdxT n, std::size_t sMemSize, T func) { - const auto numSMs = raft::getMultiProcessorCount(); + int devId; + RAFT_CUDA_TRY(cudaGetDevice(&devId)); + int numSMs; + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&numSMs, cudaDevAttrMultiProcessorCount, devId)); + int numBlocksPerSm = 0; dim3 grid; diff --git a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh index c5fdd28117..efcd5d9389 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_cutlass_base.cuh @@ -64,21 +64,20 @@ template -typename std::enable_if::value>::type cutlassDistanceKernel( - const DataT* x, - const DataT* y, - const DataT* xn, - const DataT* yn, - IdxT m, - IdxT n, - IdxT k, - IdxT lda, - IdxT ldb, - IdxT ldd, - OutT* dOutput, - FinalLambda fin_op, - OpT distance_op, - cudaStream_t stream) +std::enable_if_t::value> cutlassDistanceKernel(const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + IdxT lda, + IdxT ldb, + IdxT ldd, + OutT* dOutput, + FinalLambda fin_op, + OpT distance_op, + cudaStream_t stream) { static_assert(!(std::is_same::value), "OutType bool is not supported use uint8_t instead"); diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh index 8524ce6fdf..e04b56ee8a 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -15,63 +15,74 @@ */ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include +/* This file has two responsibilities: + * + * 1. Dispatch to the correct implementation of a kernel based on the + * architecture of the device on which the kernel will be launched. For + * instance, the cosine distance has a CUTLASS-based implementation that can + * be used on SM80+ and the normal implementation that is used on older + * architectures. + * + * 2. Provide concise function templates that can be instantiated in + * src/distance/distance/specializations/detail/. Previously, + * raft::distance::detail::distance was instantiated. The function + * necessarily required a large set of include files, which slowed down the + * build. The raft::distance::detail::pairwise_matrix_arch_dispatch functions + * do not require as large an include files set, which speeds up the build. + */ + +#include // ops::has_cutlass_op +#include // dispatch_sm60 +#include // pairwise_matrix_params +#include // raft::util::arch::SM_* + +// NOTE: to minimize compile times, we do not include dispatch_sm80.cuh. +// Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). +// Therefore, it is the including file's responsibility to include the correct +// dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh +// and the specializations in src/distance/distance/specializations/detail/. namespace raft::distance::detail { +// This forward-declaration ensures that we do not need to include +// dispatch_sm80.cuh if we are not calling it in practice. This makes compiling +// all the non-CUTLASS based distance specializations faster. For CUTLASS-based +// distances, dispatch_sm80.cuh has to be included by the file including this +// file. template -void pairwise_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // Create kernel parameter struct. Flip x and y if column major. - IdxT ldx = is_row_major ? k : m; - IdxT ldy = is_row_major ? k : n; - IdxT ld_out = is_row_major ? n : m; - - pairwise_matrix_params params{ - m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; - - if (!params.is_row_major) { params.flip_x_and_y(); } + typename SM_compat_t> +void pairwise_matrix_sm80_dispatch(OpT, + pairwise_matrix_params, + SM_compat_t, + cudaStream_t); +template +void pairwise_matrix_instantiation_point(OpT distance_op, + pairwise_matrix_params params, + cudaStream_t stream) +{ // On CUDA 12: // - always execute normal kernel // // On CUDA 11 and below: // - execute CUTLASS-based kernel on SM_80 and above // - execute normal kernel below SM_80 + namespace arch = raft::util::arch; constexpr bool is_ctk_12 = __CUDACC_VER_MAJOR__ == 12; constexpr bool cutlass_op_unavailable = !ops::has_cutlass_op(); if constexpr (is_ctk_12 || cutlass_op_unavailable) { // Always execute legacy kernels on CUDA 12 - auto any_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_future()); + auto any_range = arch::SM_range(arch::SM_min(), arch::SM_future()); pairwise_matrix_sm60_dispatch(distance_op, params, any_range, stream); } else { - auto cutlass_range = raft::arch::SM_range(raft::arch::SM_80(), raft::arch::SM_future()); - auto legacy_range = raft::arch::SM_range(raft::arch::SM_min(), raft::arch::SM_80()); + auto cutlass_range = arch::SM_range(arch::SM_80(), arch::SM_future()); + auto legacy_range = arch::SM_range(arch::SM_min(), arch::SM_80()); // Get pointer to SM60 kernel to determine the runtime architecture of the // current system. Other methods to determine the architecture (that do not @@ -79,7 +90,7 @@ void pairwise_matrix_dispatch(OpT distance_op, // https://github.com/NVIDIA/cub/issues/545 auto sm60_wrapper = pairwise_matrix_sm60_get_wrapper(distance_op, params, legacy_range); void* kernel_ptr = reinterpret_cast(sm60_wrapper.kernel_ptr); - auto runtime_arch = raft::arch::kernel_runtime_arch(kernel_ptr); + auto runtime_arch = arch::kernel_runtime_arch(kernel_ptr); if (cutlass_range.contains(runtime_arch)) { // If device is SM_80 or later, use CUTLASS-based kernel. @@ -92,4 +103,35 @@ void pairwise_matrix_dispatch(OpT distance_op, } } +template +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) +{ + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = is_row_major ? k : m; + IdxT ldy = is_row_major ? k : n; + IdxT ld_out = is_row_major ? n : m; + + pairwise_matrix_params params{ + m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; + + if (!params.is_row_major) { params.flip_x_and_y(); } + pairwise_matrix_instantiation_point(distance_op, params, stream); +} + }; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh index c1e4c08af4..f2b0e59822 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_layout.cuh @@ -15,10 +15,11 @@ */ #pragma once -#include "kernel_sm60.cuh" -#include -#include - +#include // std::min +#include // size_t +#include // RAFT_EXPECTS +#include // pairwise_matrix_params +#include // std::integral_constant namespace raft::distance::detail { /** @@ -99,15 +100,15 @@ auto dispatch_layout(bool row_major, int vec_len, F&& f) { if (row_major) { switch (vec_len) { - case 4: return f(std::bool_constant(), vec_len_constant<4>()); - case 2: return f(std::bool_constant(), vec_len_constant<2>()); - default: return f(std::bool_constant(), vec_len_constant<1>()); + case 4: return f(std::true_type(), vec_len_constant<4>()); + case 2: return f(std::true_type(), vec_len_constant<2>()); + default: return f(std::true_type(), vec_len_constant<1>()); } } else { switch (vec_len) { - case 4: return f(std::bool_constant(), vec_len_constant<4>()); - case 2: return f(std::bool_constant(), vec_len_constant<2>()); - default: return f(std::bool_constant(), vec_len_constant<1>()); + case 4: return f(std::false_type(), vec_len_constant<4>()); + case 2: return f(std::false_type(), vec_len_constant<2>()); + default: return f(std::false_type(), vec_len_constant<1>()); } } } diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh index 6e284007ea..2080fbe9cd 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch_sm60.cuh @@ -15,10 +15,10 @@ */ #pragma once -#include -#include -#include -#include +#include // std::min +#include // dispatch_layout +#include // pairwise_matrix_sm60_wrapper +#include // raft::linalg::Policy4x4 namespace raft::distance::detail { @@ -35,7 +35,11 @@ pairwise_matrix_sm60_wrapper pairwise_matrix_sm6 { int vec_len = determine_vec_len(params); - return dispatch_layout(params.is_row_major, vec_len, [&](auto row_major, auto vec_len_aligned) { + // f takes compile-time constants row_major and vec_len aligned and returns + // the corresponding kernel wrapper. The wrapper contains the launch + // parameters of the kernel: a pointer to the kernel function, grid size, + // block size, and shared memory size. + auto f = [&](auto row_major, auto vec_len_aligned) { // row_major and vec_len are std::integral_constants of type bool and int // respectively. @@ -46,15 +50,19 @@ pairwise_matrix_sm60_wrapper pairwise_matrix_sm6 // Prevent double, vec_len=4 combination (this is not supported) constexpr int vec_len = std::min(vec_len_op, static_cast(16 / sizeof(DataT))); - typedef typename raft::linalg::Policy4x4::Policy RowPolicy; - typedef typename raft::linalg::Policy4x4::ColPolicy ColPolicy; - typedef typename std::conditional::type Policy; + using RowPolicy = typename raft::linalg::Policy4x4::Policy; + using ColPolicy = typename raft::linalg::Policy4x4::ColPolicy; + using Policy = typename std::conditional::type; auto wrapper = make_pairwise_matrix_sm60_wrapper(distance_op, params, sm_compat_range); return wrapper; - }); + }; + + // Dispatch_layout calls f with appropriate compile time constants based on + // the runtime values of params.is_row_major and vec_len. + return dispatch_layout(params.is_row_major, vec_len, f); } template // std::min -#include -#include +#include // std::min +#include // cutlassDistanceKernel +#include // dispatch_layout namespace raft::distance::detail { @@ -34,7 +34,9 @@ void pairwise_matrix_sm80_dispatch(OpT distance_op, { int vec_len = determine_vec_len(params); - dispatch_layout(params.is_row_major, vec_len, [&](auto row_major, auto vec_len_aligned) { + // f takes compile-time constants row_major and vec_len aligned and runs the + // corresponding cutlass launch code. + auto f = [&](auto row_major, auto vec_len_aligned) { // row_major and vec_len are std::integral_constants of type bool and int // respectively. @@ -56,7 +58,11 @@ void pairwise_matrix_sm80_dispatch(OpT distance_op, params.fin_op, distance_op, stream); - }); + }; + + // Dispatch_layout calls f with appropriate compile time constants based on + // the runtime values of params.is_row_major and vec_len. + dispatch_layout(params.is_row_major, vec_len, f); } }; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh index 6e3ab7b26b..2d0a98862e 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/kernel_sm60.cuh @@ -15,11 +15,11 @@ */ #pragma once -#include -#include -#include -#include -#include +#include // assert +#include // raft::void_op +#include // PairwiseDistances +#include // pairwise_matrix_params +#include // raft::util::arch::SM_compute_arch namespace raft::distance::detail { @@ -36,43 +36,27 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void pairwise_matrix_kernel( { // Early exit to minimize the size of the kernel when it is not supposed to be compiled. constexpr SM_compat_t sm_compat_range{}; - if constexpr (!sm_compat_range.contains(raft::arch::SM_compute_arch())) { + if constexpr (!sm_compat_range.contains(raft::util::arch::SM_compute_arch())) { assert(false); return; } extern __shared__ char smem[]; - using AccT = typename OpT::AccT; - - // Wrap operator back into lambdas. This is temporary and should be removed. - // See: https://github.com/rapidsai/raft/issues/1323 - auto core_op = [distance_op] __device__(AccT & acc, DataT & x, DataT & y) { - distance_op.core(acc, x, y); - }; - auto epilog_op = [distance_op] __device__(AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - distance_op.template epilog(acc, regxn, regyn, gridStrideX, gridStrideY); - }; - + // The epilog is already provided by distance_op. Do not provide additional + // epilogs. + auto epilog_op = raft::void_op(); // No support for row_epilog_op. auto row_epilog_op = raft::void_op(); // Always write output constexpr bool write_out = true; constexpr bool use_norms = distance_op.use_norms; - PairwiseDistances -void pairwise_matrix(OpT distance_op, - pairwise_matrix_params params, - cudaStream_t stream) -{ - dim3 blk(Policy::Nthreads); - // Use .template to disambiguate (See: - // https://en.cppreference.com/w/cpp/language/dependent_name) - size_t smem_size = distance_op.template shared_mem_size(); - // Obtain function pointer to kernel - auto kernel = - pairwise_matrix_kernel; - dim3 grid = launchConfigGenerator(params.m, params.n, smem_size, kernel); - - kernel<<>>(distance_op, params); - RAFT_CUDA_TRY(cudaGetLastError()); -} - // The type of a pointer to the pairwise matrix kernel. The following template // arguments are type-erased: // @@ -181,9 +140,9 @@ pairwise_matrix_sm60_wrapper make_pairwise_matri SM_compat_t sm_compat_range) { dim3 block(Policy::Nthreads); - // Use .template to disambiguate (See: + // Use ::template to disambiguate (See: // https://en.cppreference.com/w/cpp/language/dependent_name) - int smem_size = distance_op.template shared_mem_size(); + int smem_size = OpT::template shared_mem_size(); // Obtain function pointer to kernel auto kernel = pairwise_matrix_kernel; diff --git a/cpp/include/raft/distance/specializations/detail/00_write_template.py b/cpp/include/raft/distance/specializations/detail/00_write_template.py new file mode 100644 index 0000000000..63ae6580b4 --- /dev/null +++ b/cpp/include/raft/distance/specializations/detail/00_write_template.py @@ -0,0 +1,148 @@ +#!/usr/bin/env python3 + +# This template manages all files in this directory, apart from +# inner_product.cuh and kernels.cuh. + + +# NOTE: this template is not perfectly formatted. Use pre-commit to get +# everything in shape again. +start_template = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::distance::detail { + +""" + +extern_template = """ +extern template void pairwise_matrix_instantiation_point( + OpT, + pairwise_matrix_params, + cudaStream_t); +""" + +end_template = """} // namespace raft::distance::detail +""" + +data_type_instances = [ + dict( + DataT="float", + AccT="float", + OutT="float", + IdxT="int", + ), + dict( + DataT="double", + AccT="double", + OutT="double", + IdxT="int", + ), +] + + + + +op_instances = [ + dict( + path_prefix="canberra", + OpT="ops::canberra_distance_op", + ), + dict( + path_prefix="correlation", + OpT="ops::correlation_distance_op", + ), + dict( + path_prefix="cosine", + OpT="ops::cosine_distance_op", + # cosine uses CUTLASS for SM80+ + ), + dict( + path_prefix="hamming_unexpanded", + OpT="ops::hamming_distance_op", + ), + dict( + path_prefix="hellinger_expanded", + OpT="ops::hellinger_distance_op", + ), + # inner product is handled by cublas. + dict( + path_prefix="jensen_shannon", + OpT="ops::jensen_shannon_distance_op", + ), + dict( + path_prefix="kl_divergence", + OpT="ops::kl_divergence_op", + ), + dict( + path_prefix="l1", + OpT="ops::l1_distance_op", + ), + dict( + path_prefix="l2_expanded", + OpT="ops::l2_exp_distance_op", + # L2 expanded uses CUTLASS for SM80+ + ), + dict( + path_prefix="l2_unexpanded", + OpT="ops::l2_unexp_distance_op", + ), + dict( + path_prefix="l_inf", + OpT="ops::l_inf_distance_op", + ), + dict( + path_prefix="lp_unexpanded", + OpT="ops::lp_unexp_distance_op", + ), + dict( + path_prefix="russel_rao", + OpT="ops::russel_rao_distance_op", + ), +] + +def fill_in(s, template): + for k, v in template.items(): + s = s.replace(k, v) + return s + +for op_instance in op_instances: + path = fill_in("path_prefix.cuh", op_instance) + with open(path, "w") as f: + f.write(start_template) + + for data_type_instance in data_type_instances: + op_data_instance = { + k : fill_in(v, data_type_instance) + for k, v in op_instance.items() + } + instance = { + **op_data_instance, + **data_type_instance, + "FinopT": "raft::identity_op", + } + + text = fill_in(extern_template, instance) + + f.write(text) + + f.write(end_template) diff --git a/cpp/include/raft/distance/specializations/detail/canberra.cuh b/cpp/include/raft/distance/specializations/detail/canberra.cuh index badce715a5..276c85e5f6 100644 --- a/cpp/include/raft/distance/specializations/detail/canberra.cuh +++ b/cpp/include/raft/distance/specializations/detail/canberra.cuh @@ -16,37 +16,25 @@ #pragma once -#include #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::canberra_distance_op, + int, + float, + float, + raft::identity_op>(ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point< + ops::canberra_distance_op, + int, + double, + double, + raft::identity_op>(ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/correlation.cuh b/cpp/include/raft/distance/specializations/detail/correlation.cuh index 013a0d43a3..f019f678df 100644 --- a/cpp/include/raft/distance/specializations/detail/correlation.cuh +++ b/cpp/include/raft/distance/specializations/detail/correlation.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::correlation_distance_op, + int, + float, + float, + raft::identity_op>(ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point< + ops::correlation_distance_op, + int, + double, + double, + raft::identity_op>(ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/cosine.cuh b/cpp/include/raft/distance/specializations/detail/cosine.cuh index c88bd1b0f6..dcde4ec286 100644 --- a/cpp/include/raft/distance/specializations/detail/cosine.cuh +++ b/cpp/include/raft/distance/specializations/detail/cosine.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point< + ops::cosine_distance_op, + int, + double, + double, + raft::identity_op>(ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh index 3c5cad3315..1d6964fbce 100644 --- a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::hamming_distance_op, + int, + float, + float, + raft::identity_op>(ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); + +extern template void pairwise_matrix_instantiation_point< + ops::hamming_distance_op, + int, + double, + double, + raft::identity_op>(ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh index bf214c046f..f96a06f919 100644 --- a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh @@ -18,37 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::hellinger_distance_op, + int, + float, + float, + raft::identity_op>(ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::hellinger_distance_op, + int, + double, + double, + raft::identity_op>(ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh index 145834fb70..0b58646582 100644 --- a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh @@ -18,37 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::jensen_shannon_distance_op, + int, + float, + float, + raft::identity_op>(ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::jensen_shannon_distance_op, + int, + double, + double, + raft::identity_op>(ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh index f0928916cd..5c164e0fd4 100644 --- a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point, + int, + double, + double, + raft::identity_op>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l1.cuh b/cpp/include/raft/distance/specializations/detail/l1.cuh index 23261a2571..870627d909 100644 --- a/cpp/include/raft/distance/specializations/detail/l1.cuh +++ b/cpp/include/raft/distance/specializations/detail/l1.cuh @@ -18,35 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point, + int, + double, + double, + raft::identity_op>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh index f953018b7d..ee3207bcce 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::l2_exp_distance_op, + int, + double, + double, + raft::identity_op>(ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh deleted file mode 100644 index 9f5f6a3706..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_sqrt_expanded.cuh +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); - -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh deleted file mode 100644 index 94531ddc33..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_sqrt_unexpanded.cuh +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); - -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh index 224b21fce8..1fbf57632b 100644 --- a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::l2_unexp_distance_op, + int, + float, + float, + raft::identity_op>(ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::l2_unexp_distance_op, + int, + double, + double, + raft::identity_op>(ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l_inf.cuh b/cpp/include/raft/distance/specializations/detail/l_inf.cuh index 9a46d7b488..388d3bf439 100644 --- a/cpp/include/raft/distance/specializations/detail/l_inf.cuh +++ b/cpp/include/raft/distance/specializations/detail/l_inf.cuh @@ -18,35 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point, + int, + float, + float, + raft::identity_op>( + ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::l_inf_distance_op, + int, + double, + double, + raft::identity_op>(ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh index e05ef02c42..d8e86ce6f2 100644 --- a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh +++ b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh @@ -18,36 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::lp_unexp_distance_op, + int, + float, + float, + raft::identity_op>(ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::lp_unexp_distance_op, + int, + double, + double, + raft::identity_op>(ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh index afc87997c0..4803fb8ab0 100644 --- a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh +++ b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh @@ -18,37 +18,23 @@ #include -namespace raft { -namespace distance { -namespace detail { -extern template void -distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); +extern template void pairwise_matrix_instantiation_point< + ops::russel_rao_distance_op, + int, + float, + float, + raft::identity_op>(ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +extern template void pairwise_matrix_instantiation_point< + ops::russel_rao_distance_op, + int, + double, + double, + raft::identity_op>(ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); +} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh index 8daa398b49..a34f696e9e 100644 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ b/cpp/include/raft/distance/specializations/distance.cuh @@ -27,8 +27,6 @@ #include #include #include -#include -#include #include #include #include diff --git a/cpp/include/raft/matrix/detail/select_k.cuh b/cpp/include/raft/matrix/detail/select_k.cuh index ac1ba3dfa3..20c2fb119d 100644 --- a/cpp/include/raft/matrix/detail/select_k.cuh +++ b/cpp/include/raft/matrix/detail/select_k.cuh @@ -84,7 +84,7 @@ void select_k(const T* in_val, in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); } else { select::radix::select_k= 4 ? 11 : 8), 512>( - in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, true, stream, mr); } } diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 643a63d9db..7ac40ac0eb 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -16,11 +16,11 @@ #pragma once -#include - -#include #include #include +#include +#include +#include #include #include #include @@ -35,8 +35,8 @@ #include namespace raft::matrix::detail::select::radix { +namespace impl { -constexpr int ITEM_PER_THREAD = 32; constexpr int VECTORIZED_READ_SIZE = 16; template @@ -51,13 +51,6 @@ _RAFT_HOST_DEVICE constexpr int calc_num_passes() return ceildiv(sizeof(T) * 8, BitsPerPass); } -// Minimum reasonable block size for the given radix size. -template -_RAFT_HOST_DEVICE constexpr int calc_min_block_size() -{ - return 1 << std::max(BitsPerPass - 4, Pow2::Log2 + 1); -} - /** * Bit 0 is the least significant (rightmost); * this implementation processes input from the most to the least significant bit. @@ -82,23 +75,43 @@ _RAFT_DEVICE constexpr unsigned calc_mask(int pass) } /** - * Use cub to twiddle bits - so that we can correctly compare bits of floating-point values as well + * Use CUB to twiddle bits - so that we can correctly compare bits of floating-point values as well * as of integers. */ template -_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater) +_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool select_min) { auto bits = reinterpret_cast::UnsignedBits&>(key); bits = cub::Traits::TwiddleIn(bits); - if (greater) { bits = ~bits; } + if (!select_min) { bits = ~bits; } return bits; } +template +_RAFT_DEVICE T twiddle_out(typename cub::Traits::UnsignedBits bits, bool select_min) +{ + if (!select_min) { bits = ~bits; } + bits = cub::Traits::TwiddleOut(bits); + return reinterpret_cast(bits); +} + template -_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool greater) +_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool select_min) +{ + static_assert(BitsPerPass <= sizeof(int) * 8 - 1, + "BitsPerPass is too large that the result type could not be int"); + return (twiddle_in(x, select_min) >> start_bit) & mask; +} + +template +_RAFT_HOST_DEVICE IdxT calc_buf_len(IdxT len) { - static_assert(BitsPerPass <= sizeof(int) * 8 - 1); // so return type can be int - return (twiddle_in(x, greater) >> start_bit) & mask; + // When writing is skipped, only read `in`(type T). + // When writing is not skipped, read `in_buf`(T) and `in_idx_buf`(IdxT), and write `out_buf`(T) + // and `out_idx_buf`(IdxT). + // The ratio between these cases determines whether to skip writing and hence the buffer size. + constexpr float ratio = 2 + sizeof(IdxT) * 2.0 / sizeof(T); + return len / ratio; } /** @@ -111,17 +124,18 @@ _RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool greater) * @tparam IdxT indexing type * @tparam Func void (T x, IdxT idx) * + * @param thread_rank rank of the calling thread among all participating threads + * @param num_threads number of the threads that participate in processing * @param in the input data * @param len the number of elements to read * @param f the lambda taking two arguments (T x, IdxT idx) */ template -_RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) +_RAFT_DEVICE void vectorized_process( + size_t thread_rank, size_t num_threads, const T* in, IdxT len, Func f) { - const IdxT stride = blockDim.x * gridDim.x; - const int tid = blockIdx.x * blockDim.x + threadIdx.x; if constexpr (sizeof(T) >= VECTORIZED_READ_SIZE || VECTORIZED_READ_SIZE % sizeof(T) != 0) { - for (IdxT i = tid; i < len; i += stride) { + for (IdxT i = thread_rank; i < len; i += num_threads) { f(in[i], i); } } else { @@ -134,8 +148,8 @@ _RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) const IdxT skip_cnt_left = std::min((IdxT)(align_bytes::roundUp(in) - in), len); // The main loop: process all aligned data - for (IdxT i = tid * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len; - i += stride * wide_t::Ratio) { + for (IdxT i = thread_rank * wide_t::Ratio + skip_cnt_left; i + wide_t::Ratio <= len; + i += num_threads * wide_t::Ratio) { wide.load(in, i); #pragma unroll for (int j = 0; j < wide_t::Ratio; ++j) { @@ -145,30 +159,55 @@ _RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f) static_assert(WarpSize >= wide_t::Ratio); // Processes the skipped elements on the left - if (tid < skip_cnt_left) { f(in[tid], tid); } + if (thread_rank < skip_cnt_left) { f(in[thread_rank], thread_rank); } // Processes the skipped elements on the right const IdxT skip_cnt_right = align_elems::mod(len - skip_cnt_left); - const IdxT remain_i = len - skip_cnt_right + tid; + const IdxT remain_i = len - skip_cnt_right + thread_rank; if (remain_i < len) { f(in[remain_i], remain_i); } } } template -struct Counter { +struct alignas(128) Counter { + // We are processing the values in multiple passes, from most significant to least significant. In + // each pass, we keep the length of input (`len`) and the `k` of current pass, and update them at + // the end of the pass. IdxT k; IdxT len; + + // `previous_len` is the length of input in previous pass. Note that `previous_len` rather + // than `len` is used for the filtering step because filtering is indeed for previous pass (see + // comments before `radix_kernel`). IdxT previous_len; - int bucket; - IdxT filter_cnt; - unsigned int finished_block_cnt; - IdxT out_cnt; - IdxT out_back_cnt; + // We determine the bits of the k_th value inside the mask processed by the pass. The + // already known bits are stored in `kth_value_bits`. It's used to discriminate a element is a + // result (written to `out`), a candidate for next pass (written to `out_buf`), or not useful + // (discarded). The bits that are not yet processed do not matter for this purpose. + typename cub::Traits::UnsignedBits kth_value_bits; + + // Record how many elements have passed filtering. It's used to determine the position in the + // `out_buf` where an element should be written. + alignas(128) IdxT filter_cnt; + + // For a row inside a batch, we may launch multiple thread blocks. This counter is used to + // determine if the current block is the last running block. If so, this block will execute scan() + // and choose_bucket(). + alignas(128) unsigned int finished_block_cnt; + + // Record how many elements have been written to the front of `out`. Elements less (if + // select_min==true) than the k-th value are written from front to back. + alignas(128) IdxT out_cnt; + + // Record how many elements have been written to the back of `out`. Elements equal to the k-th + // value are written from back to front. We need to keep count of them separately because the + // number of elements that <= the k-th value might exceed k. + alignas(128) IdxT out_back_cnt; }; /** - * Fused filtering of the current phase and building histogram for the next phase - * (see steps 4-1 in `radix_kernel` description). + * Fused filtering of the current pass and building histogram for the next pass + * (see steps 4 & 1 in `radix_kernel` description). */ template _RAFT_DEVICE void filter_and_histogram(const T* in_buf, @@ -177,12 +216,12 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, IdxT* out_idx_buf, T* out, IdxT* out_idx, - IdxT len, + IdxT previous_len, Counter* counter, IdxT* histogram, - bool greater, + bool select_min, int pass, - int k) + bool early_stop) { constexpr int num_buckets = calc_num_buckets(); __shared__ IdxT histogram_smem[num_buckets]; @@ -198,19 +237,20 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, // Passed to vectorized_process, this function executes in all blocks in parallel, // i.e. the work is split along the input (both, in batches and chunks of a single row). // Later, the histograms are merged using atomicAdd. - auto f = [greater, start_bit, mask](T value, IdxT) { - int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram_smem + bucket, IdxT(1)); + auto f = [select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); }; - vectorized_process(in_buf, len, f); + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); } else { - const IdxT previous_len = counter->previous_len; - const int want_bucket = counter->bucket; - IdxT& filter_cnt = counter->filter_cnt; - IdxT& out_cnt = counter->out_cnt; - const IdxT counter_len = counter->len; + IdxT* p_filter_cnt = &counter->filter_cnt; + IdxT* p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; const int previous_start_bit = calc_start_bit(pass - 1); - const unsigned previous_mask = calc_mask(pass - 1); // See the remark above on the distributed execution of `f` using vectorized_process. auto f = [in_idx_buf, @@ -218,38 +258,50 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, out_idx_buf, out, out_idx, - greater, - k, + select_min, start_bit, mask, previous_start_bit, - previous_mask, - want_bucket, - &filter_cnt, - &out_cnt, - counter_len](T value, IdxT i) { - int prev_bucket = - calc_bucket(value, previous_start_bit, previous_mask, greater); - if (prev_bucket == want_bucket) { - IdxT pos = atomicAdd(&filter_cnt, IdxT(1)); - out_buf[pos] = value; - if (out_idx_buf) { out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; } - int bucket = calc_bucket(value, start_bit, mask, greater); - atomicAdd(histogram_smem + bucket, IdxT(1)); - - if (counter_len == 1) { - out[k - 1] = value; - out_idx[k - 1] = in_idx_buf ? in_idx_buf[i] : i; + kth_value_bits, + p_filter_cnt, + p_out_cnt, + early_stop](T value, IdxT i) { + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { + if (early_stop) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else { + if (out_buf) { + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram_smem + bucket, static_cast(1)); } - } else if (prev_bucket < want_bucket) { - IdxT pos = atomicAdd(&out_cnt, IdxT(1)); + } + // the condition `(out_buf || early_stop)` is a little tricky: + // If we skip writing to `out_buf` (when `out_buf` is nullptr), we should skip writing to + // `out` too. So we won't write the same value to `out` multiple times in different passes. + // And if we keep skipping the writing, values will be written in `last_filter_kernel()` at + // last. But when `early_stop` is true, we need to write to `out` since it's the last chance. + else if ((out_buf || early_stop) && previous_bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); out[pos] = value; out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } }; - - vectorized_process(in_buf, previous_len, f); + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); } + if (early_stop) { return; } __syncthreads(); // merge histograms produced by individual blocks @@ -259,69 +311,184 @@ _RAFT_DEVICE void filter_and_histogram(const T* in_buf, } /** - * Replace a part of the histogram with its own prefix sum, starting from the `start` and adding - * `current` to each entry of the result. + * Replace histogram with its own prefix sum * (step 2 in `radix_kernel` description) */ template -_RAFT_DEVICE void scan(volatile IdxT* histogram, - const int start, - const int num_buckets, - const IdxT current) +_RAFT_DEVICE void scan(volatile IdxT* histogram) { - typedef cub::BlockScan BlockScan; - __shared__ typename BlockScan::TempStorage temp_storage; + constexpr int num_buckets = calc_num_buckets(); + if constexpr (num_buckets >= BlockSize) { + static_assert(num_buckets % BlockSize == 0); + constexpr int items_per_thread = num_buckets / BlockSize; + typedef cub::BlockLoad BlockLoad; + typedef cub::BlockStore + BlockStore; + typedef cub::BlockScan BlockScan; - IdxT thread_data = 0; - int index = start + threadIdx.x; - if (index < num_buckets) { thread_data = histogram[index]; } + __shared__ union { + typename BlockLoad::TempStorage load; + typename BlockScan::TempStorage scan; + typename BlockStore::TempStorage store; + } temp_storage; + IdxT thread_data[items_per_thread]; - BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); - __syncthreads(); - if (index < num_buckets) { histogram[index] = thread_data + current; } - __syncthreads(); // This sync is necessary, as the content of histogram needs - // to be read after + BlockLoad(temp_storage.load).Load(histogram, thread_data); + __syncthreads(); + + BlockScan(temp_storage.scan).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + BlockStore(temp_storage.store).Store(histogram, thread_data); + } else { + typedef cub::BlockScan BlockScan; + __shared__ typename BlockScan::TempStorage temp_storage; + + IdxT thread_data = 0; + if (threadIdx.x < num_buckets) { thread_data = histogram[threadIdx.x]; } + + BlockScan(temp_storage).InclusiveSum(thread_data, thread_data); + __syncthreads(); + + if (threadIdx.x < num_buckets) { histogram[threadIdx.x] = thread_data; } + } } /** * Calculate in which bucket the k-th value will fall - * (steps 2-3 in `radix_kernel` description) + * (steps 3 in `radix_kernel` description) */ -template -_RAFT_DEVICE void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k) +template +_RAFT_DEVICE void choose_bucket(Counter* counter, + const IdxT* histogram, + const IdxT k, + const int pass) { constexpr int num_buckets = calc_num_buckets(); - int index = threadIdx.x; - IdxT last_prefix_sum = 0; - int num_pass = 1; - if constexpr (num_buckets >= BlockSize) { - static_assert(num_buckets % BlockSize == 0); - num_pass = num_buckets / BlockSize; + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + IdxT prev = (i == 0) ? 0 : histogram[i - 1]; + IdxT cur = histogram[i]; + + // one and only one thread will satisfy this condition, so counter is written by only one thread + if (prev < k && cur >= k) { + counter->k = k - prev; // how many values still are there to find + counter->len = cur - prev; // number of values in next pass + typename cub::Traits::UnsignedBits bucket = i; + int start_bit = calc_start_bit(pass); + counter->kth_value_bits |= bucket << start_bit; + } } +} - for (int i = 0; i < num_pass && (last_prefix_sum < k); i++) { - // Turn the i-th chunk of the histogram into its prefix sum. - scan(histogram, i * BlockSize, num_buckets, last_prefix_sum); - if (index < num_buckets) { - // Number of values in the previous `index-1` buckets (see the `scan` op above) - IdxT prev = (index == 0) ? 0 : histogram[index - 1]; - // Number of values in `index` buckets - IdxT cur = histogram[index]; - - // one and only one thread will satisfy this condition, so only write once - if (prev < k && cur >= k) { - counter->k = k - prev; // how many values still are there to find - counter->previous_len = counter->len; - counter->len = cur - prev; // number of values in `index` bucket - counter->bucket = index; +// For one-block version, last_filter() could be called when pass < num_passes - 1. +// So `pass` could not be constexpr +template +_RAFT_DEVICE void last_filter(const T* in_buf, + const IdxT* in_idx_buf, + T* out, + IdxT* out_idx, + IdxT current_len, + IdxT k, + Counter* counter, + const bool select_min, + const int pass) +{ + const auto kth_value_bits = counter->kth_value_bits; + const int start_bit = calc_start_bit(pass); + + // changed in choose_bucket(); need to reload + const IdxT needed_num_of_kth = counter->k; + IdxT* p_out_cnt = &counter->out_cnt; + IdxT* p_out_back_cnt = &counter->out_back_cnt; + for (IdxT i = threadIdx.x; i < current_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + // For one-block version, `in_idx_buf` could be nullptr at pass 0. + // For non one-block version, if writing has been skipped, `in_idx_buf` could be nullptr if + // `in_buf` is `in` + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if (back_pos < needed_num_of_kth) { + IdxT pos = k - 1 - back_pos; + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } } - index += BlockSize; - // this will break the loop when the counter is set (cur >= k), because last_prefix_sum >= cur - last_prefix_sum = histogram[(i + 1) * BlockSize - 1]; } } +template +__global__ void last_filter_kernel(const T* in, + const IdxT* in_idx, + const T* in_buf, + const IdxT* in_idx_buf, + T* out, + IdxT* out_idx, + IdxT len, + IdxT k, + Counter* counters, + const bool select_min) +{ + const size_t batch_id = blockIdx.y; // size_t to avoid multiplication overflow + + Counter* counter = counters + batch_id; + IdxT previous_len = counter->previous_len; + if (previous_len == 0) { return; } + const IdxT buf_len = calc_buf_len(len); + if (previous_len > buf_len || in_buf == in) { + in_buf = in + batch_id * len; + in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; + previous_len = len; + } else { + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; + } + out += batch_id * k; + out_idx += batch_id * k; + + constexpr int pass = calc_num_passes() - 1; + constexpr int start_bit = calc_start_bit(pass); + + const auto kth_value_bits = counter->kth_value_bits; + const IdxT needed_num_of_kth = counter->k; + IdxT* p_out_cnt = &counter->out_cnt; + IdxT* p_out_back_cnt = &counter->out_back_cnt; + + auto f = [k, + select_min, + kth_value_bits, + needed_num_of_kth, + p_out_cnt, + p_out_back_cnt, + in_idx_buf, + out, + out_idx](T value, IdxT i) { + const auto bits = (twiddle_in(value, select_min) >> start_bit) << start_bit; + if (bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } else if (bits == kth_value_bits) { + IdxT back_pos = atomicAdd(p_out_back_cnt, static_cast(1)); + if (back_pos < needed_num_of_kth) { + IdxT pos = k - 1 - back_pos; + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; + } + } + }; + + vectorized_process(static_cast(blockIdx.x) * blockDim.x + threadIdx.x, + static_cast(blockDim.x) * gridDim.x, + in_buf, + previous_len, + f); +} + /** * * It is expected to call this kernel multiple times (passes), in each pass we process a radix, @@ -350,35 +517,79 @@ _RAFT_DEVICE void choose_bucket(Counter* counter, IdxT* histogram, cons * * In the implementation, the filtering step is delayed to the next pass so the filtering and * histogram computation are fused. In this way, inputs are read once rather than twice. + * + * During the filtering step, we won't write candidates (elements in bucket j) to `out_buf` if the + * number of candidates is larger than the length of `out_buf` (this could happen when the leading + * bits of input values are almost the same). And then in the next pass, inputs are read from `in` + * rather than from `in_buf`. The benefit is that we can save the cost of writing candidates and + * their indices. */ -template -__global__ void __launch_bounds__(BlockSize) radix_kernel(const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, - T* out, - IdxT* out_idx, - Counter* counters, - IdxT* histograms, - const IdxT len, - const int k, - const bool greater, - const int pass) +template +__global__ void radix_kernel(const T* in, + const IdxT* in_idx, + const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + Counter* counters, + IdxT* histograms, + const IdxT len, + const IdxT k, + const bool select_min, + const int pass) { - __shared__ bool isLastBlockDone; + const size_t batch_id = blockIdx.y; + auto counter = counters + batch_id; + IdxT current_k; + IdxT previous_len; + IdxT current_len; + if (pass == 0) { + current_k = k; + previous_len = len; + // Need to do this so setting counter->previous_len for the next pass is correct. + // This value is meaningless for pass 0, but it's fine because pass 0 won't be the + // last pass in this implementation so pass 0 won't hit the "if (pass == + // num_passes - 1)" branch. + // Maybe it's better to reload counter->previous_len and use it rather than + // current_len in last_filter() + current_len = len; + } else { + current_k = counter->k; + current_len = counter->len; + previous_len = counter->previous_len; + } + if (current_len == 0) { return; } - constexpr int num_buckets = calc_num_buckets(); - constexpr int num_passes = calc_num_passes(); - const int batch_id = blockIdx.y; - in_buf += batch_id * len; - out_buf += batch_id * len; + // When k=len, early_stop will be true at pass 0. It means filter_and_histogram() should handle + // correctly the case that pass=0 and early_stop=true. However, this special case of k=len is + // handled in other way in select_k() so such case is not possible here. + const bool early_stop = (current_len == current_k); + const IdxT buf_len = calc_buf_len(len); + + // "previous_len > buf_len" means previous pass skips writing buffer + if (pass == 0 || pass == 1 || previous_len > buf_len) { + in_buf = in + batch_id * len; + in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; + previous_len = len; + } else { + in_buf += batch_id * buf_len; + in_idx_buf += batch_id * buf_len; + } + // "current_len > buf_len" means current pass will skip writing buffer + if (pass == 0 || current_len > buf_len) { + out_buf = nullptr; + out_idx_buf = nullptr; + } else { + out_buf += batch_id * buf_len; + out_idx_buf += batch_id * buf_len; + } out += batch_id * k; out_idx += batch_id * k; - if (in_idx_buf) { in_idx_buf += batch_id * len; } - if (out_idx_buf) { out_idx_buf += batch_id * len; } - auto counter = counters + batch_id; - auto histogram = histograms + batch_id * num_buckets; + constexpr int num_buckets = calc_num_buckets(); + auto histogram = histograms + batch_id * num_buckets; filter_and_histogram(in_buf, in_idx_buf, @@ -386,126 +597,468 @@ __global__ void __launch_bounds__(BlockSize) radix_kernel(const T* in_buf, out_idx_buf, out, out_idx, - len, + previous_len, counter, histogram, - greater, + select_min, pass, - k); + early_stop); __threadfence(); + bool isLastBlock = false; if (threadIdx.x == 0) { unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); - isLastBlockDone = (finished == (gridDim.x - 1)); + isLastBlock = (finished == (gridDim.x - 1)); } - // Synchronize to make sure that each thread reads the correct value of - // isLastBlockDone. - __syncthreads(); - if (isLastBlockDone) { - if (counter->len == 1 && threadIdx.x == 0) { - counter->previous_len = 0; - counter->len = 0; - } - // init counter, other members of counter is initialized with 0 by - // cudaMemset() - if (pass == 0 && threadIdx.x == 0) { - counter->k = k; - counter->len = len; - counter->out_back_cnt = 0; + if (__syncthreads_or(isLastBlock)) { + if (early_stop) { + if (threadIdx.x == 0) { + // `last_filter_kernel()` requires setting previous_len + counter->previous_len = 0; + counter->len = 0; + } + return; } + + scan(histogram); __syncthreads(); + choose_bucket(counter, histogram, current_k, pass); + __syncthreads(); + + constexpr int num_passes = calc_num_passes(); + // reset for next pass + if (pass != num_passes - 1) { + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + } + if (threadIdx.x == 0) { + // `last_filter_kernel()` requires setting previous_len even in the last pass + counter->previous_len = current_len; + // not necessary for the last pass, but put it here anyway + counter->filter_cnt = 0; + } + + if constexpr (fused_last_filter) { + if (pass == num_passes - 1) { + last_filter(out_buf ? out_buf : in_buf, + out_idx_buf ? out_idx_buf : in_idx_buf, + out, + out_idx, + out_buf ? current_len : len, + k, + counter, + select_min, + pass); + } + } + } +} + +template +int calc_chunk_size(int batch_size, IdxT len, int sm_cnt, Kernel kernel) +{ + int active_blocks; + RAFT_CUDA_TRY( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&active_blocks, kernel, BlockSize, 0)); + + constexpr int items_per_thread = 32; + constexpr int num_waves = 10; + int chunk_size = + std::max(1, num_waves * sm_cnt * active_blocks * BlockSize * items_per_thread / len); + return std::min(chunk_size, batch_size); +} - IdxT ori_k = counter->k; +template +unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) +{ + static_assert(VECTORIZED_READ_SIZE / sizeof(T) >= 1); + + int active_blocks; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &active_blocks, radix_kernel, BlockSize, 0)); + active_blocks *= sm_cnt; - if (counter->len > 0) { - choose_bucket(counter, histogram, ori_k); + IdxT best_num_blocks = 0; + float best_tail_wave_penalty = 1.0f; + const IdxT max_num_blocks = ceildiv(len, VECTORIZED_READ_SIZE / sizeof(T) * BlockSize); + for (int num_waves = 1;; ++num_waves) { + IdxT num_blocks = std::min( + max_num_blocks, static_cast(std::max(num_waves * active_blocks / batch_size, 1))); + IdxT items_per_thread = ceildiv(len, num_blocks * BlockSize); + items_per_thread = alignTo(items_per_thread, VECTORIZED_READ_SIZE / sizeof(T)); + num_blocks = ceildiv(len, items_per_thread * BlockSize); + float actual_num_waves = static_cast(num_blocks) * batch_size / active_blocks; + float tail_wave_penalty = + (ceilf(actual_num_waves) - actual_num_waves) / ceilf(actual_num_waves); + + // 0.15 is determined experimentally. It also ensures breaking the loop early, + // e.g. when num_waves > 7, tail_wave_penalty will always <0.15 + if (tail_wave_penalty < 0.15) { + best_num_blocks = num_blocks; + break; + } else if (tail_wave_penalty < best_tail_wave_penalty) { + best_num_blocks = num_blocks; + best_tail_wave_penalty = tail_wave_penalty; } - __syncthreads(); - if (pass == num_passes - 1) { - const IdxT previous_len = counter->previous_len; - const int want_bucket = counter->bucket; - int start_bit = calc_start_bit(pass); - unsigned mask = calc_mask(pass); - - // radix topk - IdxT& out_cnt = counter->out_cnt; - for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { - const T value = out_buf[i]; - int bucket = calc_bucket(value, start_bit, mask, greater); - if (bucket < want_bucket) { - IdxT pos = atomicAdd(&out_cnt, IdxT(1)); - out[pos] = value; - out_idx[pos] = out_idx_buf[i]; - } else if (bucket == want_bucket) { - IdxT needed_num_of_kth = counter->k; - IdxT back_pos = atomicAdd(&(counter->out_back_cnt), IdxT(1)); - if (back_pos < needed_num_of_kth) { - IdxT pos = k - 1 - back_pos; - out[pos] = value; - out_idx[pos] = out_idx_buf[i]; - } - } + if (num_blocks == max_num_blocks) { break; } + } + return best_num_blocks; +} + +template +_RAFT_HOST_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + T* buf1, + IdxT* idx_buf1, + T* buf2, + IdxT* idx_buf2, + int pass, + const T*& in_buf, + const IdxT*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) +{ + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = in_idx; + out_buf = buf1; + out_idx_buf = idx_buf1; + } else if (pass % 2 == 0) { + in_buf = buf1; + in_idx_buf = idx_buf1; + out_buf = buf2; + out_idx_buf = idx_buf2; + } else { + in_buf = buf2; + in_idx_buf = idx_buf2; + out_buf = buf1; + out_idx_buf = idx_buf1; + } +} + +template +void radix_topk(const T* in, + const IdxT* in_idx, + int batch_size, + IdxT len, + IdxT k, + T* out, + IdxT* out_idx, + bool select_min, + bool fused_last_filter, + unsigned grid_dim, + int sm_cnt, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + // TODO: is it possible to relax this restriction? + static_assert(calc_num_passes() > 1); + constexpr int num_buckets = calc_num_buckets(); + + auto kernel = radix_kernel; + const size_t max_chunk_size = + calc_chunk_size(batch_size, len, sm_cnt, kernel); + if (max_chunk_size != static_cast(batch_size)) { + grid_dim = calc_grid_dim(max_chunk_size, len, sm_cnt); + } + const IdxT buf_len = calc_buf_len(len); + + size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); + size_t req_buf = max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + size_t mem_req = req_aux + req_buf + 256 * 6; // might need extra memory for alignment + + auto pool_guard = raft::get_pool_memory_resource(mr, mem_req); + if (pool_guard) { + RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + rmm::device_uvector> counters(max_chunk_size, stream, mr); + rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); + rmm::device_uvector buf1(max_chunk_size * buf_len, stream, mr); + rmm::device_uvector idx_buf1(max_chunk_size * buf_len, stream, mr); + rmm::device_uvector buf2(max_chunk_size * buf_len, stream, mr); + rmm::device_uvector idx_buf2(max_chunk_size * buf_len, stream, mr); + + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { + int chunk_size = std::min(max_chunk_size, batch_size - offset); + RAFT_CUDA_TRY( + cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); + + const T* chunk_in = in + offset * len; + const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr; + T* chunk_out = out + offset * k; + IdxT* chunk_out_idx = out_idx + offset * k; + + const T* in_buf = nullptr; + const IdxT* in_idx_buf = nullptr; + T* out_buf = nullptr; + IdxT* out_idx_buf = nullptr; + + dim3 blocks(grid_dim, chunk_size); + constexpr int num_passes = calc_num_passes(); + + for (int pass = 0; pass < num_passes; ++pass) { + set_buf_pointers(chunk_in, + chunk_in_idx, + buf1.data(), + idx_buf1.data(), + buf2.data(), + idx_buf2.data(), + pass, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf); + + if (fused_last_filter && pass == num_passes - 1) { + kernel = radix_kernel; } - __syncthreads(); - } else { - // reset for next pass - for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { - histogram[i] = 0; + + kernel<<>>(chunk_in, + chunk_in_idx, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + chunk_out, + chunk_out_idx, + counters.data(), + histograms.data(), + len, + k, + select_min, + pass); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + + if (!fused_last_filter) { + last_filter_kernel<<>>(chunk_in, + chunk_in_idx, + out_buf, + out_idx_buf, + chunk_out, + chunk_out_idx, + len, + k, + counters.data(), + select_min); + RAFT_CUDA_TRY(cudaPeekAtLastError()); + } + } +} + +// The following a few functions are for the one-block version, which uses single thread block for +// each row of a batch. +template +_RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, + const IdxT* in_idx_buf, + T* out_buf, + IdxT* out_idx_buf, + T* out, + IdxT* out_idx, + Counter* counter, + IdxT* histogram, + bool select_min, + int pass) +{ + constexpr int num_buckets = calc_num_buckets(); + for (int i = threadIdx.x; i < num_buckets; i += blockDim.x) { + histogram[i] = 0; + } + IdxT* p_filter_cnt = &counter->filter_cnt; + if (threadIdx.x == 0) { *p_filter_cnt = 0; } + __syncthreads(); + + const int start_bit = calc_start_bit(pass); + const unsigned mask = calc_mask(pass); + const IdxT previous_len = counter->previous_len; + + if (pass == 0) { + auto f = [histogram, select_min, start_bit, mask](T value, IdxT) { + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + }; + vectorized_process(threadIdx.x, blockDim.x, in_buf, previous_len, f); + } else { + // not use vectorized_process here because it increases #registers a lot + IdxT* p_out_cnt = &counter->out_cnt; + const auto kth_value_bits = counter->kth_value_bits; + const int previous_start_bit = calc_start_bit(pass - 1); + + for (IdxT i = threadIdx.x; i < previous_len; i += blockDim.x) { + const T value = in_buf[i]; + const auto previous_bits = (twiddle_in(value, select_min) >> previous_start_bit) + << previous_start_bit; + if (previous_bits == kth_value_bits) { +#if CUDART_VERSION < 12000 + // Avoiding potential compiler bug in CUDA 11 + volatile +#endif + IdxT pos = atomicAdd(p_filter_cnt, static_cast(1)); + out_buf[pos] = value; + out_idx_buf[pos] = in_idx_buf ? in_idx_buf[i] : i; + + int bucket = calc_bucket(value, start_bit, mask, select_min); + atomicAdd(histogram + bucket, static_cast(1)); + } else if (previous_bits < kth_value_bits) { + IdxT pos = atomicAdd(p_out_cnt, static_cast(1)); + out[pos] = value; + out_idx[pos] = in_idx_buf ? in_idx_buf[i] : i; } - if (threadIdx.x == 0) { counter->filter_cnt = 0; } } } } -/** - * Calculate the minimal batch size, such that GPU is still fully occupied. - */ template -inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len) +__global__ void radix_topk_one_block_kernel(const T* in, + const IdxT* in_idx, + const IdxT len, + const IdxT k, + T* out, + IdxT* out_idx, + const bool select_min, + T* buf1, + IdxT* idx_buf1, + T* buf2, + IdxT* idx_buf2) { - int dev_id, sm_count, occupancy, max_grid_dim_y; - RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_count, cudaDevAttrMultiProcessorCount, dev_id)); - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&max_grid_dim_y, cudaDevAttrMaxGridDimY, dev_id)); - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &occupancy, radix_kernel, BlockSize, 0)); - - // number of block we'd use if the batch size is enough to occupy the gpu in any case - size_t blocks_per_row = ceildiv(len, BlockSize * ITEM_PER_THREAD); - - // fully occupy GPU - size_t opt_batch_size = ceildiv(sm_count * occupancy, blocks_per_row); - // round it up to the closest pow-of-two for better data alignment - opt_batch_size = isPo2(opt_batch_size) ? opt_batch_size : (1 << (log2(opt_batch_size) + 1)); - // Take a max possible pow-of-two grid_dim_y - max_grid_dim_y = isPo2(max_grid_dim_y) ? max_grid_dim_y : (1 << log2(max_grid_dim_y)); - // If the optimal batch size is very small compared to the requested batch size, we know - // the extra required memory is not significant and we can increase the batch size for - // better occupancy when the grid size is not multiple of the SM count. - // Also don't split the batch size when there is not much work overall. - const size_t safe_enlarge_factor = 9; - const size_t min_grid_size = 1024; - while ((opt_batch_size << safe_enlarge_factor) < req_batch_size || - blocks_per_row * opt_batch_size < min_grid_size) { - opt_batch_size <<= 1; + constexpr int num_buckets = calc_num_buckets(); + __shared__ Counter counter; + __shared__ IdxT histogram[num_buckets]; + + if (threadIdx.x == 0) { + counter.k = k; + counter.len = len; + counter.previous_len = len; + counter.kth_value_bits = 0; + counter.out_cnt = 0; + counter.out_back_cnt = 0; } + __syncthreads(); + + const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow + in += batch_id * len; + if (in_idx) { in_idx += batch_id * len; } + out += batch_id * k; + out_idx += batch_id * k; + buf1 += batch_id * len; + idx_buf1 += batch_id * len; + buf2 += batch_id * len; + idx_buf2 += batch_id * len; + const T* in_buf = nullptr; + const IdxT* in_idx_buf = nullptr; + T* out_buf = nullptr; + IdxT* out_idx_buf = nullptr; + + constexpr int num_passes = calc_num_passes(); + for (int pass = 0; pass < num_passes; ++pass) { + set_buf_pointers( + in, in_idx, buf1, idx_buf1, buf2, idx_buf2, pass, in_buf, in_idx_buf, out_buf, out_idx_buf); + + IdxT current_len = counter.len; + IdxT current_k = counter.k; + + filter_and_histogram_for_one_block(in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + out, + out_idx, + &counter, + histogram, + select_min, + pass); + __syncthreads(); + + scan(histogram); + __syncthreads(); + + choose_bucket(&counter, histogram, current_k, pass); + if (threadIdx.x == 0) { counter.previous_len = current_len; } + __syncthreads(); - // Do not exceed the max grid size. - opt_batch_size = std::min(opt_batch_size, size_t(max_grid_dim_y)); - // Don't do more work than needed - opt_batch_size = std::min(opt_batch_size, req_batch_size); - // Let more blocks share one row if the required batch size is too small. - while (opt_batch_size * blocks_per_row < size_t(sm_count * occupancy) && - // Ensure we still can read data somewhat efficiently - len * sizeof(T) > 2 * VECTORIZED_READ_SIZE * BlockSize * blocks_per_row) { - blocks_per_row <<= 1; + if (counter.len == counter.k || pass == num_passes - 1) { + last_filter(pass == 0 ? in : out_buf, + pass == 0 ? in_idx : out_idx_buf, + out, + out_idx, + current_len, + k, + &counter, + select_min, + pass); + break; + } } +} - return dim3(blocks_per_row, opt_batch_size); +// radix_topk() might use multiple thread blocks for one row of a batch. In contrast, the following +// one-block version uses single thread block for one row of a batch, so intermediate data, like +// counters and global histograms, can be kept in shared memory and cheap sync operations can be +// used. It's used when len is relatively small or when the number of blocks per row calculated by +// `calc_grid_dim()` is 1. +template +void radix_topk_one_block(const T* in, + const IdxT* in_idx, + int batch_size, + IdxT len, + IdxT k, + T* out, + IdxT* out_idx, + bool select_min, + int sm_cnt, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr) +{ + static_assert(calc_num_passes() > 1); + + auto kernel = radix_topk_one_block_kernel; + const size_t max_chunk_size = + calc_chunk_size(batch_size, len, sm_cnt, kernel); + + auto pool_guard = + raft::get_pool_memory_resource(mr, + max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)) + + 256 * 4 // might need extra memory for alignment + ); + if (pool_guard) { + RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + rmm::device_uvector buf1(len * max_chunk_size, stream, mr); + rmm::device_uvector idx_buf1(len * max_chunk_size, stream, mr); + rmm::device_uvector buf2(len * max_chunk_size, stream, mr); + rmm::device_uvector idx_buf2(len * max_chunk_size, stream, mr); + + for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { + int chunk_size = std::min(max_chunk_size, batch_size - offset); + kernel<<>>(in + offset * len, + in_idx ? (in_idx + offset * len) : nullptr, + len, + k, + out + offset * k, + out_idx + offset * k, + select_min, + buf1.data(), + idx_buf1.data(), + buf2.data(), + idx_buf2.data()); + } } +} // namespace impl + /** * Select k smallest or largest key/values from each row in the input data. * @@ -546,6 +1099,12 @@ inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len) * the payload selected together with `out`. * @param select_min * whether to select k smallest (true) or largest (false) keys. + * @param fused_last_filter + * when it's true, the last filter is fused into the kernel in the last pass and only one thread + * block will do the filtering; when false, a standalone filter kernel with multiple thread + * blocks is called. The later case is preferable when leading bits of input data are almost the + * same. That is, when the value range of input data is narrow. In such case, there could be a + * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. * @param stream * @param mr an optional memory resource to use across the calls (you can provide a large enough * memory pool here to avoid memory allocations within the call). @@ -553,109 +1112,65 @@ inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len) template void select_k(const T* in, const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, + int batch_size, + IdxT len, + IdxT k, T* out, IdxT* out_idx, bool select_min, + bool fused_last_filter, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr = nullptr) { - // reduce the block size if the input length is too small. - if constexpr (BlockSize > calc_min_block_size()) { - if (BlockSize * ITEM_PER_THREAD > len) { - return select_k( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + if (k == len) { + RAFT_CUDA_TRY( + cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); + if (in_idx) { + RAFT_CUDA_TRY(cudaMemcpyAsync( + out_idx, in_idx, sizeof(IdxT) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); + } else { + auto out_idx_view = + raft::make_device_vector_view(out_idx, static_cast(len) * batch_size); + raft::device_resources handle(stream); + raft::linalg::map_offset(handle, out_idx_view, raft::mod_const_op(len)); } + return; } - // TODO: is it possible to relax this restriction? - static_assert(calc_num_passes() > 1); - constexpr int num_buckets = calc_num_buckets(); - - dim3 blocks = get_optimal_grid_size(batch_size, len); - size_t max_chunk_size = blocks.y; - - size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT)); - size_t req_buf = max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT)); - size_t mem_req = req_aux + req_buf; - size_t mem_free, mem_total; - RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total)); - std::optional managed_memory; - rmm::mr::device_memory_resource* mr_buf = nullptr; - if (mem_req > mem_free) { - // if there's not enough memory for buffers on the device, resort to the managed memory. - mem_req = req_aux; - managed_memory.emplace(); - mr_buf = &managed_memory.value(); - } - - auto pool_guard = raft::get_pool_memory_resource(mr, mem_req); - if (pool_guard) { - RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes", - pool_guard->pool_size()); + // TODO: use device_resources::get_device_properties() instead; should change it when we refactor + // resource management + int sm_cnt; + { + int dev; + RAFT_CUDA_TRY(cudaGetDevice(&dev)); + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&sm_cnt, cudaDevAttrMultiProcessorCount, dev)); } - if (mr_buf == nullptr) { mr_buf = mr; } - - rmm::device_uvector> counters(max_chunk_size, stream, mr); - rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); - rmm::device_uvector buf1(max_chunk_size * len, stream, mr_buf); - rmm::device_uvector idx_buf1(max_chunk_size * len, stream, mr_buf); - rmm::device_uvector buf2(max_chunk_size * len, stream, mr_buf); - rmm::device_uvector idx_buf2(max_chunk_size * len, stream, mr_buf); - for (size_t offset = 0; offset < batch_size; offset += max_chunk_size) { - blocks.y = std::min(max_chunk_size, batch_size - offset); + constexpr int items_per_thread = 32; - RAFT_CUDA_TRY( - cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); - RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); - - const T* in_buf = nullptr; - const IdxT* in_idx_buf = nullptr; - T* out_buf = nullptr; - IdxT* out_idx_buf = nullptr; - - constexpr int num_passes = calc_num_passes(); - - for (int pass = 0; pass < num_passes; ++pass) { - if (pass == 0) { - in_buf = in + offset * len; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in + offset * len; - in_idx_buf = in_idx ? in_idx + offset * len : nullptr; - out_buf = buf1.data(); - out_idx_buf = idx_buf1.data(); - } else if (pass % 2 == 0) { - in_buf = buf1.data(); - in_idx_buf = idx_buf1.data(); - out_buf = buf2.data(); - out_idx_buf = idx_buf2.data(); - } else { - in_buf = buf2.data(); - in_idx_buf = idx_buf2.data(); - out_buf = buf1.data(); - out_idx_buf = idx_buf1.data(); - } - - radix_kernel - <<>>(in_buf, - in_idx_buf, - out_buf, - out_idx_buf, - out + offset * k, - out_idx + offset * k, - counters.data(), - histograms.data(), - len, - k, - !select_min, - pass); - RAFT_CUDA_TRY(cudaPeekAtLastError()); + if (len <= BlockSize * items_per_thread) { + impl::radix_topk_one_block( + in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + } else { + unsigned grid_dim = + impl::calc_grid_dim(batch_size, len, sm_cnt); + if (grid_dim == 1) { + impl::radix_topk_one_block( + in, in_idx, batch_size, len, k, out, out_idx, select_min, sm_cnt, stream, mr); + } else { + impl::radix_topk(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + fused_last_filter, + grid_dim, + sm_cnt, + stream, + mr); } } } diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh index 4891cc5f8d..dac1a29c7f 100644 --- a/cpp/include/raft/neighbors/brute_force.cuh +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -122,9 +122,8 @@ inline void knn_merge_parts( * * raft::raft::device_resources handle; * ... - * int k = 10; * auto metric = raft::distance::DistanceType::L2SqrtExpanded; - * brute_force::knn(handle, index, search, indices, distances, k, metric); + * brute_force::knn(handle, index, search, indices, distances, metric); * @endcode * * @param[in] handle: the cuml handle to use @@ -132,28 +131,31 @@ inline void knn_merge_parts( * @param[in] search: matrix (size n*d) to be used for searching the index * @param[out] indices: matrix (size n*k) to store output knn indices * @param[out] distances: matrix (size n*k) to store the output knn distance - * @param[in] k: the number of nearest neighbors to return * @param[in] metric: distance metric to use. Euclidean (L2) is used by default * @param[in] metric_arg: the value of `p` for Minkowski (l-p) distances. This * is ignored if the metric_type is not Minkowski. * @param[in] global_id_offset: optional starting global id mapping for the local partition * (assumes the index contains contiguous ids in the global id space) + * @param[in] distance_epilogue: optional epilogue function to run after computing distances. This + function takes a triple of the (value, rowid, colid) for each + element in the pairwise distances and returns a transformed value + back. */ template + typename search_layout, + typename epilogue_op = raft::identity_op> void knn(raft::device_resources const& handle, std::vector> index, raft::device_matrix_view search, raft::device_matrix_view indices, raft::device_matrix_view distances, - value_int k, distance::DistanceType metric = distance::DistanceType::L2Unexpanded, std::optional metric_arg = std::make_optional(2.0f), - std::optional global_id_offset = std::nullopt) + std::optional global_id_offset = std::nullopt, + epilogue_op distance_epilogue = raft::identity_op()) { RAFT_EXPECTS(index[0].extent(1) == search.extent(1), "Number of dimensions for both index and search matrices must be equal"); @@ -161,15 +163,14 @@ void knn(raft::device_resources const& handle, RAFT_EXPECTS(indices.extent(0) == distances.extent(0) && distances.extent(0) == search.extent(0), "Number of rows in output indices and distances matrices must equal number of rows " "in search matrix."); - RAFT_EXPECTS( - indices.extent(1) == distances.extent(1) && distances.extent(1) == static_cast(k), - "Number of columns in output indices and distances matrices must be equal to k"); + RAFT_EXPECTS(indices.extent(1) == distances.extent(1) && distances.extent(1), + "Number of columns in output indices and distances matrices must the same"); bool rowMajorIndex = std::is_same_v; bool rowMajorQuery = std::is_same_v; std::vector inputs; - std::vector sizes; + std::vector sizes; for (std::size_t i = 0; i < index.size(); ++i) { inputs.push_back(const_cast(index[i].data_handle())); sizes.push_back(index[i].extent(0)); @@ -183,18 +184,19 @@ void knn(raft::device_resources const& handle, raft::neighbors::detail::brute_force_knn_impl(handle, inputs, sizes, - static_cast(index[0].extent(1)), + index[0].extent(1), // TODO: This is unfortunate. Need to fix. const_cast(search.data_handle()), - static_cast(search.extent(0)), + search.extent(0), indices.data_handle(), distances.data_handle(), - k, + indices.extent(1), rowMajorIndex, rowMajorQuery, trans_arg, metric, - metric_arg.value_or(2.0f)); + metric_arg.value_or(2.0f), + distance_epilogue); } /** diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh index f657070df4..e6533eaf51 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh @@ -1065,6 +1065,14 @@ void ivfflat_interleaved_scan(const index& index, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) { + // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan + // function is used in both raft::neighbors::ivf_flat::search and + // raft::neighbors::detail::refine_device. To prevent a duplicate + // instantiation of this function (which defines ~270 kernels) in the refine + // specializations, an extern template definition is provided. Please check + // related function calls after editing this function definition. Search for + // `greppable-id-specializations-ivf-flat-search` to find them. + const int capacity = bound_by_power_of_two(k); select_interleaved_scan_kernel::run(capacity, index.veclen(), diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 875fc3b37c..a776ce2586 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -47,7 +47,9 @@ using namespace raft::spatial::knn; * Calculates brute force knn, using a fixed memory budget * by tiling over both the rows and columns of pairwise_distances */ -template +template void tiled_brute_force_knn(const raft::device_resources& handle, const ElementType* search, // size (m ,d) const ElementType* index, // size (n ,d) @@ -58,9 +60,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle, ElementType* distances, // size (m, k) IndexType* indices, // size (m, k) raft::distance::DistanceType metric, - float metric_arg = 0.0, - size_t max_row_tile_size = 0, - size_t max_col_tile_size = 0) + float metric_arg = 2.0, + size_t max_row_tile_size = 0, + size_t max_col_tile_size = 0, + DistanceEpilogue distance_epilogue = raft::identity_op()) { // Figure out the number of rows/cols to tile for size_t tile_rows = 0; @@ -152,25 +155,41 @@ void tiled_brute_force_knn(const raft::device_resources& handle, metric_arg); if (metric == raft::distance::DistanceType::L2Expanded || metric == raft::distance::DistanceType::L2SqrtExpanded) { - auto row_norms = search_norms.data() + i; - auto col_norms = index_norms.data() + j; + auto row_norms = search_norms.data(); + auto col_norms = index_norms.data(); auto dist = temp_distances.data(); raft::linalg::map_offset( handle, raft::make_device_vector_view(dist, current_query_size * current_centroid_size), - [=] __device__(IndexType i) { - IndexType row = i / current_centroid_size, col = i % current_centroid_size; + [=] __device__(IndexType idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); - auto val = row_norms[row] + col_norms[col] - 2.0 * dist[i]; + auto val = row_norms[row] + col_norms[col] - 2.0 * dist[idx]; // due to numerical instability (especially around self-distance) // the distances here could be slightly negative, which will // cause NaN values in the subsequent sqrt. Clamp to 0 val = val * (val >= 0.0001); if (metric == raft::distance::DistanceType::L2SqrtExpanded) { val = sqrt(val); } + val = distance_epilogue(val, row, col); return val; }); + } else { + // if we're not l2 distance, and we have a distance epilogue - run it now + if constexpr (!std::is_same_v) { + auto distances_ptr = temp_distances.data(); + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(temp_distances.data(), + current_query_size * current_centroid_size), + [=] __device__(size_t idx) { + IndexType row = i + (idx / current_centroid_size); + IndexType col = j + (idx % current_centroid_size); + return distance_epilogue(distances_ptr[idx], row, col); + }); + } } select_k(temp_distances.data(), @@ -250,7 +269,10 @@ void tiled_brute_force_knn(const raft::device_resources& handle, * @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded) * @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm */ -template +template void brute_force_knn_impl( raft::device_resources const& handle, std::vector& input, @@ -265,7 +287,8 @@ void brute_force_knn_impl( bool rowMajorQuery = true, std::vector* translations = nullptr, raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, - float metricArg = 0) + float metricArg = 0, + DistanceEpilogue distance_epilogue = raft::identity_op()) { auto userStream = handle.get_stream(); @@ -355,6 +378,7 @@ void brute_force_knn_impl( auto stream = handle.get_next_usable_stream(i); if (k <= 64 && rowMajorQuery == rowMajorIndex && rowMajorQuery == true && + std::is_same_v && (metric == raft::distance::DistanceType::L2Unexpanded || metric == raft::distance::DistanceType::L2SqrtUnexpanded || metric == raft::distance::DistanceType::L2Expanded || @@ -424,7 +448,10 @@ void brute_force_knn_impl( out_d_ptr, out_i_ptr, tiled_metric, - metricArg); + metricArg, + 0, + 0, + distance_epilogue); break; } } diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index f244d5875c..aedfc42698 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -117,6 +117,14 @@ void refine_device(raft::device_resources const& handle, n_queries, n_candidates); + // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan + // function is used in both raft::neighbors::ivf_flat::search and + // raft::neighbors::detail::refine_device. To prevent a duplicate + // instantiation of this function (which defines ~270 kernels) in the refine + // specializations, an extern template definition is provided. Please check + // and adjust the extern template definition and the instantiation when the + // below function call is edited. Search for + // `greppable-id-specializations-ivf-flat-search` to find them. uint32_t grid_dim_x = 1; raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, diff --git a/cpp/include/raft/neighbors/specializations/brute_force.cuh b/cpp/include/raft/neighbors/specializations/brute_force.cuh index d418d40185..1337beb68a 100644 --- a/cpp/include/raft/neighbors/specializations/brute_force.cuh +++ b/cpp/include/raft/neighbors/specializations/brute_force.cuh @@ -36,7 +36,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(long, float, int); RAFT_INST(long, float, unsigned int); RAFT_INST(uint32_t, float, int); diff --git a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh index 013c7359e5..161f3462c9 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh @@ -20,6 +20,13 @@ namespace raft::neighbors::ivf_flat { +// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan +// function is used in both raft::neighbors::ivf_flat::search and +// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation +// of this function (which defines ~270 kernels) in the refine specializations, +// an extern template definition is provided here. Please check related function +// calls after editing template definition below. Search for +// `greppable-id-specializations-ivf-flat-search` to find them. #define RAFT_INST(T, IdxT) \ extern template auto build(raft::device_resources const& handle, \ const index_params& params, \ @@ -44,7 +51,23 @@ namespace raft::neighbors::ivf_flat { const raft::neighbors::ivf_flat::index&, \ raft::device_matrix_view, \ raft::device_matrix_view, \ - raft::device_matrix_view); + raft::device_matrix_view); \ + \ + extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); RAFT_INST(float, int64_t); RAFT_INST(int8_t, int64_t); diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 99d688e232..c8fc6eefda 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -185,7 +185,6 @@ void k_closest_landmarks(raft::device_resources const& handle, make_device_matrix_view(query_pts, n_query_pts, inputs[0].extent(1)), make_device_matrix_view(R_knn_inds, n_query_pts, k), make_device_matrix_view(R_knn_dists, n_query_pts, k), - k, index.get_metric()); } diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh index 4e18a210d4..4a571c1447 100644 --- a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -22,6 +22,8 @@ #include "processing.cuh" #include #include +#include +#include #include #include @@ -183,13 +185,11 @@ DI void updateSortedWarpQ( } } -template Pair; @@ -222,295 +222,279 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x using namespace raft::neighbors::detail::faiss_select; typedef WarpSelect, NumWarpQ, NumThreadQ, 32> myWarpSelect; - auto rowEpilog_lambda = [m, n, numOfNN, out_dists, out_inds, mutexes] __device__( - IdxT gridStrideY) { - if (gridDim.x == 1) { return; } - - Pair* shDumpKV = nullptr; - if (useNorms) { - shDumpKV = (Pair*)(&smem[Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]); - } else { - shDumpKV = (Pair*)(&smem[Policy::SmemSize]); - } - - const int lid = threadIdx.x % warpSize; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - - // 0 -> consumer done consuming the buffer. - // -1 -> consumer started consuming the buffer - // -2 -> producer done filling the buffer - // 1 -> prod acquired to fill the buffer - if (blockIdx.x == 0) { - auto cta_processed = 0; - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - __syncwarp(); - - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - - while (cta_processed < gridDim.x - 1) { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) - ; - } - __threadfence(); - __syncthreads(); + auto rowEpilog_lambda = + [m, n, &distance_op, numOfNN, out_dists, out_inds, mutexes] __device__(IdxT gridStrideY) { + if (gridDim.x == 1) { return; } + + // Use ::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + const int lid = threadIdx.x % warpSize; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + + // 0 -> consumer done consuming the buffer. + // -1 -> consumer started consuming the buffer + // -2 -> producer done filling the buffer + // 1 -> prod acquired to fill the buffer + if (blockIdx.x == 0) { + auto cta_processed = 0; + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + __syncwarp(); + + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + + while (cta_processed < gridDim.x - 1) { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], -2, -1) != -2) + ; + } + __threadfence(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { #pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - otherKV.value = out_dists[rowId * numOfNN + idx]; - otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + otherKV.value = out_dists[rowId * numOfNN + idx]; + otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx]; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + shDumpKV[shMemRowId * numOfNN + idx] = otherKV; + } } } } - } - __threadfence(); - __syncthreads(); + __threadfence(); + __syncthreads(); - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } - __threadfence(); + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], 0); } + __threadfence(); // Perform merging of otherKV with topk's across warp. #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { #pragma unroll - for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { - Pair otherKV; - otherKV.value = identity; - otherKV.key = keyMax; - const auto idx = j * warpSize + lid; - if (idx < numOfNN) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + for (int j = 0; j < myWarpSelect::kNumWarpQRegisters; ++j) { + Pair otherKV; + otherKV.value = identity; + otherKV.key = keyMax; + const auto idx = j * warpSize + lid; + if (idx < numOfNN) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + otherKV = shDumpKV[shMemRowId * numOfNN + idx]; + } + heapArr[i]->add(otherKV.value, otherKV.key); } - heapArr[i]->add(otherKV.value, otherKV.key); } } + cta_processed++; } - cta_processed++; - } #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(0xffffffff, needSort); - if (needSort) { heapArr[i]->reduce(); } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(0xffffffff, needSort); + if (needSort) { heapArr[i]->reduce(); } + } } - } - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } else { - if (threadIdx.x == 0) { - while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) - ; - } - __threadfence(); - __syncthreads(); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } else { + if (threadIdx.x == 0) { + while (atomicCAS((int*)&mutexes[gridStrideY / Policy::Mblk], 0, 1) != 0) + ; + } + __threadfence(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - if (rowId < m) { - for (int idx = lid; idx < numOfNN; idx += warpSize) { - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; - out_dists[rowId * numOfNN + idx] = KVPair.value; - out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + if (rowId < m) { + for (int idx = lid; idx < numOfNN; idx += warpSize) { + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair KVPair = shDumpKV[shMemRowId * numOfNN + idx]; + out_dists[rowId * numOfNN + idx] = KVPair.value; + out_inds[rowId * numOfNN + idx] = (IdxT)KVPair.key; + } } } - } - __threadfence(); - __syncthreads(); - - if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } - __threadfence(); - } - }; + __threadfence(); + __syncthreads(); - // epilogue operation lambda for final value calculation - auto epilog_lambda = [numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( - AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { - if (useNorms) { -#pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { -#pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; - } + if (threadIdx.x == 0) { atomicExch((int*)&mutexes[gridStrideY / Policy::Mblk], -2); } + __threadfence(); } - } + }; - Pair* shDumpKV = nullptr; - if (useNorms) { - constexpr size_t shmemSize = - Policy::SmemSize + ((Policy::Mblk + Policy::Nblk) * sizeof(DataT)); - shDumpKV = (Pair*)(&smem[shmemSize]); - } else { - shDumpKV = (Pair*)(&smem[Policy::SmemSize]); - } + // epilogue operation lambda for final value calculation + auto epilog_lambda = + [&distance_op, numOfNN, m, n, ldd, out_dists, out_inds, keyMax, identity] __device__( + AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], + DataT * regxn, + DataT * regyn, + IdxT gridStrideX, + IdxT gridStrideY) { + // Use ::template to disambiguate (See: + // https://en.cppreference.com/w/cpp/language/dependent_name) + int smem_offset = OpT::template shared_mem_size(); + Pair* shDumpKV = (Pair*)(&smem[smem_offset]); + + constexpr uint32_t mask = 0xffffffffu; + const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); + const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); + const int lid = raft::laneId(); - constexpr uint32_t mask = 0xffffffffu; - const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols); - const IdxT startx = gridStrideX + (threadIdx.x % Policy::AccThCols); - const int lid = raft::laneId(); - - myWarpSelect heapArr1(identity, keyMax, numOfNN); - myWarpSelect heapArr2(identity, keyMax, numOfNN); - myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; - if (usePrevTopKs) { - if (gridStrideX == blockIdx.x * Policy::Nblk) { - loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); + myWarpSelect heapArr1(identity, keyMax, numOfNN); + myWarpSelect heapArr2(identity, keyMax, numOfNN); + myWarpSelect* heapArr[] = {&heapArr1, &heapArr2}; + if (usePrevTopKs) { + if (gridStrideX == blockIdx.x * Policy::Nblk) { + loadPrevTopKsGmemWarpQ(heapArr, out_dists, out_inds, m, numOfNN, starty); + } } - } - if (gridStrideX > blockIdx.x * Policy::Nblk) { + if (gridStrideX > blockIdx.x * Policy::Nblk) { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; - heapArr[i]->warpKTop = tempKV.value; - } + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + Pair tempKV = shDumpKV[(rowId * numOfNN) + numOfNN - 1]; + heapArr[i]->warpKTop = tempKV.value; + } - // total vals can atmost be 256, (32*8) - int numValsWarpTopK[Policy::AccRowsPerTh]; - int anyWarpTopKs = 0; + // total vals can atmost be 256, (32*8) + int numValsWarpTopK[Policy::AccRowsPerTh]; + int anyWarpTopKs = 0; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto rowId = starty + i * Policy::AccThRows; - numValsWarpTopK[i] = 0; - if (rowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto rowId = starty + i * Policy::AccThRows; + numValsWarpTopK[i] = 0; + if (rowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { numValsWarpTopK[i]++; } + } } + anyWarpTopKs += numValsWarpTopK[i]; } - anyWarpTopKs += numValsWarpTopK[i]; } - } - anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); - if (anyWarpTopKs) { - Pair* allWarpTopKs = (Pair*)(&smem[0]); - uint32_t needScanSort[Policy::AccRowsPerTh]; + anyWarpTopKs = __syncthreads_or(anyWarpTopKs > 0); + if (anyWarpTopKs) { + Pair* allWarpTopKs = (Pair*)(&smem[0]); + uint32_t needScanSort[Policy::AccRowsPerTh]; #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - needScanSort[i] = 0; - if (gmemRowId < m) { - int myVals = numValsWarpTopK[i]; - needScanSort[i] = __ballot_sync(mask, myVals > 0); - if (needScanSort[i]) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + needScanSort[i] = 0; + if (gmemRowId < m) { + int myVals = numValsWarpTopK[i]; + needScanSort[i] = __ballot_sync(mask, myVals > 0); + if (needScanSort[i]) { #pragma unroll - for (unsigned int k = 1; k <= 16; k *= 2) { - const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); - if (lid >= k) { numValsWarpTopK[i] += n; } + for (unsigned int k = 1; k <= 16; k *= 2) { + const unsigned int n = __shfl_up_sync(mask, numValsWarpTopK[i], k); + if (lid >= k) { numValsWarpTopK[i] += n; } + } } + // As each thread will know its total vals to write. + // we only store its starting location. + numValsWarpTopK[i] -= myVals; } - // As each thread will know its total vals to write. - // we only store its starting location. - numValsWarpTopK[i] -= myVals; - } - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { - if (needScanSort[i] & ((uint32_t)1 << lid)) { + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { + if (needScanSort[i] & ((uint32_t)1 << lid)) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - if (colId < ldd) { - if (acc[i][j] < heapArr[i]->warpKTop) { - Pair otherKV = {colId, acc[i][j]}; - allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; - numValsWarpTopK[i]++; + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + if (colId < ldd) { + if (acc[i][j] < heapArr[i]->warpKTop) { + Pair otherKV = {colId, acc[i][j]}; + allWarpTopKs[rowId * (256) + numValsWarpTopK[i]] = otherKV; + numValsWarpTopK[i]++; + } } } } + __syncwarp(); + const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); + loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); + updateSortedWarpQ( + heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } - __syncwarp(); - const int finalNumVals = raft::shfl(numValsWarpTopK[i], 31); - loadWarpQShmem(heapArr[i], &shDumpKV[0], rowId, numOfNN); - updateSortedWarpQ( - heapArr[i], &allWarpTopKs[0], rowId, finalNumVals); } } - } - __syncthreads(); + __syncthreads(); #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - if (needScanSort[i]) { - const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - const auto gmemRowId = starty + i * Policy::AccThRows; - if (gmemRowId < m) { - storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + if (needScanSort[i]) { + const auto rowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + const auto gmemRowId = starty + i * Policy::AccThRows; + if (gmemRowId < m) { + storeWarpQShmem(heapArr[i], shDumpKV, rowId, numOfNN); + } } } } - } - } else { + } else { #pragma unroll - for (int i = 0; i < Policy::AccRowsPerTh; ++i) { - const auto gmemRowId = starty + i * Policy::AccThRows; - const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; - if (gmemRowId < m) { + for (int i = 0; i < Policy::AccRowsPerTh; ++i) { + const auto gmemRowId = starty + i * Policy::AccThRows; + const auto shMemRowId = (threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows; + if (gmemRowId < m) { #pragma unroll - for (int j = 0; j < Policy::AccColsPerTh; ++j) { - const auto colId = startx + j * Policy::AccThCols; - Pair otherKV = {keyMax, identity}; - if (colId < ldd) { - otherKV.value = acc[i][j]; - otherKV.key = colId; + for (int j = 0; j < Policy::AccColsPerTh; ++j) { + const auto colId = startx + j * Policy::AccThCols; + Pair otherKV = {keyMax, identity}; + if (colId < ldd) { + otherKV.value = acc[i][j]; + otherKV.key = colId; + } + heapArr[i]->add(otherKV.value, otherKV.key); } - heapArr[i]->add(otherKV.value, otherKV.key); - } - bool needSort = (heapArr[i]->numVals > 0); - needSort = __any_sync(mask, needSort); - if (needSort) { heapArr[i]->reduce(); } - storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); + bool needSort = (heapArr[i]->numVals > 0); + needSort = __any_sync(mask, needSort); + if (needSort) { heapArr[i]->reduce(); } + storeWarpQShmem(heapArr[i], shDumpKV, shMemRowId, numOfNN); + } } } - } - if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { - // This is last iteration of grid stride X - loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); - storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); - } - }; + if (((gridStrideX + Policy::Nblk * gridDim.x) >= n) && gridDim.x == 1) { + // This is last iteration of grid stride X + loadAllWarpQShmem(heapArr, &shDumpKV[0], m, numOfNN); + storeWarpQGmem(heapArr, out_dists, out_inds, m, numOfNN, starty); + } + }; - raft::distance::detail::PairwiseDistances + write_out> obj(x, y, m, @@ -521,9 +505,9 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(const DataT* x ldd, _xn, _yn, - nullptr, + nullptr, // output ptr, can be null as write_out == false. smem, - core_op, + distance_op, epilog_lambda, fin_op, rowEpilog_lambda); @@ -562,38 +546,32 @@ void fusedL2UnexpKnnImpl(const DataT* x, dim3 blk(KPolicy::Nthreads); // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = x - y; - acc += diff * diff; - }; - typedef cub::KeyValuePair Pair; - if (isRowMajor) { - constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2UnexpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + constexpr auto fusedL2UnexpKnn64RowMajor = fusedL2kNN; + isRowMajor>; auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor; if (numOfNN <= 32) { @@ -604,8 +582,10 @@ void fusedL2UnexpKnnImpl(const DataT* x, ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } - const auto sharedMemSize = KPolicy::SmemSize + (KPolicy::Mblk * numOfNN * sizeof(Pair)); - dim3 grid = raft::distance::detail::launchConfigGenerator( + const auto sharedMemSize = + distance_op.template shared_mem_size() + KPolicy::Mblk * numOfNN * sizeof(Pair); + + dim3 grid = raft::distance::detail::launchConfigGenerator( m, n, sharedMemSize, fusedL2UnexpKnnRowMajor); if (grid.x > 1) { @@ -628,9 +608,8 @@ void fusedL2UnexpKnnImpl(const DataT* x, lda, ldb, ldd, - core_lambda, - raft::identity_op{}, - sqrt, + distance_op, + fin_op, (uint32_t)numOfNN, (int*)workspace, out_dists, @@ -753,36 +732,33 @@ void fusedL2ExpKnnImpl(const DataT* x, ASSERT(workspace != nullptr, "workspace is null"); dim3 blk(KPolicy::Nthreads); - // Accumulation operation lambda - auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { acc += x * y; }; typedef cub::KeyValuePair Pair; - if (isRowMajor) { - constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN distance_op{sqrt}; + raft::identity_op fin_op{}; + + if constexpr (isRowMajor) { + constexpr auto fusedL2ExpKnn32RowMajor = fusedL2kNN; - constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + constexpr auto fusedL2ExpKnn64RowMajor = fusedL2kNN; + isRowMajor>; auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor; if (numOfNN <= 32) { @@ -793,9 +769,8 @@ void fusedL2ExpKnnImpl(const DataT* x, ASSERT(numOfNN <= 64, "fusedL2kNN: num of nearest neighbors must be <= 64"); } - const auto sharedMemSize = KPolicy::SmemSize + - ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)) + - (KPolicy::Mblk * numOfNN * sizeof(Pair)); + const auto sharedMemSize = + distance_op.template shared_mem_size() + (KPolicy::Mblk * numOfNN * sizeof(Pair)); dim3 grid = raft::distance::detail::launchConfigGenerator( m, n, sharedMemSize, fusedL2ExpKnnRowMajor); int32_t* mutexes = nullptr; @@ -835,9 +810,8 @@ void fusedL2ExpKnnImpl(const DataT* x, lda, ldb, ldd, - core_lambda, - raft::identity_op{}, - sqrt, + distance_op, + fin_op, (uint32_t)numOfNN, mutexes, out_dists, diff --git a/cpp/include/raft/spatial/knn/knn.cuh b/cpp/include/raft/spatial/knn/knn.cuh index 692d262043..a7bbfd9500 100644 --- a/cpp/include/raft/spatial/knn/knn.cuh +++ b/cpp/include/raft/spatial/knn/knn.cuh @@ -153,12 +153,12 @@ template case SelectKAlgo::RADIX_8_BITS: matrix::detail::select::radix::select_k( - in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, stream); + in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, true, stream); break; case SelectKAlgo::RADIX_11_BITS: matrix::detail::select::radix::select_k( - in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, stream); + in_keys, in_values, n_inputs, input_len, k, out_keys, out_values, select_min, true, stream); break; case SelectKAlgo::WARP_SORT: diff --git a/cpp/include/raft/util/arch.cuh b/cpp/include/raft/util/arch.cuh index 8c48b87269..dc35b10063 100644 --- a/cpp/include/raft/util/arch.cuh +++ b/cpp/include/raft/util/arch.cuh @@ -15,25 +15,27 @@ */ #pragma once -namespace raft::arch { +#include // RAFT_CUDA_TRY -/* raft::arch provides the following facilities: +namespace raft::util::arch { + +/* raft::util::arch provides the following facilities: * - * - raft::arch::SM_XX : hardcoded compile-time constants for various compute - * architectures. The values raft::arch::SM_min and raft::arch::SM_future + * - raft::util::arch::SM_XX : hardcoded compile-time constants for various compute + * architectures. The values raft::util::arch::SM_min and raft::util::arch::SM_future * represent architectures that are always smaller and larger (respectively) * than any architecture that can be encountered in practice. * - * - raft::arch::SM_compute_arch : a compile-time value for the *current* + * - raft::util::arch::SM_compute_arch : a compile-time value for the *current* * compute architecture that a kernel is compiled with. It can only be used * inside kernels with a template argument. * - * - raft::arch::kernel_runtime_arch : a function that computes at *run-time* + * - raft::util::arch::kernel_runtime_arch : a function that computes at *run-time* * which version of a kernel will launch (i.e., it will return the compute * architecture of the version of the kernel that will be launched by the * driver). * - * - raft::arch::SM_range : a compile-time value to represent an open interval + * - raft::util::arch::SM_range : a compile-time value to represent an open interval * of compute architectures. This can be used to check if the current * compile-time architecture is in a specified compatibility range. */ @@ -46,9 +48,6 @@ struct SM_generic { public: __host__ __device__ constexpr int value() const { return n; } }; - -// A dummy kernel that is used to determine the runtime architecture. -__global__ inline void dummy_runtime_kernel() {} } // namespace detail // A list of architectures that RAPIDS explicitly builds for (SM60, ..., SM90) @@ -119,7 +118,7 @@ struct SM_runtime { inline SM_runtime kernel_runtime_arch(void* kernel) { cudaFuncAttributes attributes; - cudaFuncGetAttributes(&attributes, kernel); + RAFT_CUDA_TRY(cudaFuncGetAttributes(&attributes, kernel)); return SM_runtime(10 * attributes.ptxVersion); } @@ -143,4 +142,4 @@ struct SM_range { } }; -} // namespace raft::arch +} // namespace raft::util::arch diff --git a/cpp/include/raft/util/cuda_dev_essentials.cuh b/cpp/include/raft/util/cuda_dev_essentials.cuh new file mode 100644 index 0000000000..bb9ebbba59 --- /dev/null +++ b/cpp/include/raft/util/cuda_dev_essentials.cuh @@ -0,0 +1,117 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// This file provides a few essential functions for use in __device__ code. The +// scope is necessarily limited to ensure that compilation times are minimized. +// Please make sure not to include large / expensive files from here. + +namespace raft { + +/** helper macro for device inlined functions */ +#define DI inline __device__ +#define HDI inline __host__ __device__ +#define HD __host__ __device__ + +/** + * @brief Provide a ceiling division operation ie. ceil(a / b) + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr HDI IntType ceildiv(IntType a, IntType b) +{ + return (a + b - 1) / b; +} + +/** + * @brief Provide an alignment function ie. ceil(a / b) * b + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr HDI IntType alignTo(IntType a, IntType b) +{ + return ceildiv(a, b) * b; +} + +/** + * @brief Provide an alignment function ie. (a / b) * b + * @tparam IntType supposed to be only integers for now! + */ +template +constexpr HDI IntType alignDown(IntType a, IntType b) +{ + return (a / b) * b; +} + +/** + * @brief Check if the input is a power of 2 + * @tparam IntType data type (checked only for integers) + */ +template +constexpr HDI bool isPo2(IntType num) +{ + return (num && !(num & (num - 1))); +} + +/** + * @brief Give logarithm of the number to base-2 + * @tparam IntType data type (checked only for integers) + */ +template +constexpr HDI IntType log2(IntType num, IntType ret = IntType(0)) +{ + return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret); +} + +/** number of threads per warp */ +static const int WarpSize = 32; + +/** get the laneId of the current thread */ +DI int laneId() +{ + int id; + asm("mov.s32 %0, %%laneid;" : "=r"(id)); + return id; +} + +/** Device function to apply the input lambda across threads in the grid */ +template +DI void forEach(int num, L lambda) +{ + int idx = (blockDim.x * blockIdx.x) + threadIdx.x; + const int numThreads = blockDim.x * gridDim.x; +#pragma unroll + for (int itr = 0; itr < ItemsPerThread; ++itr, idx += numThreads) { + if (idx < num) lambda(idx, itr); + } +} + +/** + * @brief Swap two values + * @tparam T the datatype of the values + * @param a first input + * @param b second input + */ +template +HDI void swapVals(T& a, T& b) +{ + T tmp = a; + a = b; + b = tmp; +} + +} // namespace raft diff --git a/cpp/include/raft/util/cuda_rt_essentials.hpp b/cpp/include/raft/util/cuda_rt_essentials.hpp new file mode 100644 index 0000000000..e5f3af4e61 --- /dev/null +++ b/cpp/include/raft/util/cuda_rt_essentials.hpp @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +// This file provides a few essential functions that wrap the CUDA runtime API. +// The scope is necessarily limited to ensure that compilation times are +// minimized. Please make sure not to include large / expensive files from here. + +#include +#include + +namespace raft { + +/** + * @brief Exception thrown when a CUDA error is encountered. + */ +struct cuda_error : public raft::exception { + explicit cuda_error(char const* const message) : raft::exception(message) {} + explicit cuda_error(std::string const& message) : raft::exception(message) {} +}; + +} // namespace raft + +/** + * @brief Error checking macro for CUDA runtime API functions. + * + * Invokes a CUDA runtime API function call, if the call does not return + * cudaSuccess, invokes cudaGetLastError() to clear the error and throws an + * exception detailing the CUDA error that occurred + * + */ +#define RAFT_CUDA_TRY(call) \ + do { \ + cudaError_t const status = call; \ + if (status != cudaSuccess) { \ + cudaGetLastError(); \ + std::string msg{}; \ + SET_ERROR_MSG(msg, \ + "CUDA error encountered at: ", \ + "call='%s', Reason=%s:%s", \ + #call, \ + cudaGetErrorName(status), \ + cudaGetErrorString(status)); \ + throw raft::cuda_error(msg); \ + } \ + } while (0) diff --git a/cpp/include/raft/util/cuda_utils.cuh b/cpp/include/raft/util/cuda_utils.cuh index 5be9dc999a..687a6b4651 100644 --- a/cpp/include/raft/util/cuda_utils.cuh +++ b/cpp/include/raft/util/cuda_utils.cuh @@ -23,113 +23,10 @@ #include #include #include - -#ifndef ENABLE_MEMCPY_ASYNC -// enable memcpy_async interface by default for newer GPUs -#if __CUDA_ARCH__ >= 800 -#define ENABLE_MEMCPY_ASYNC 1 -#endif -#else // ENABLE_MEMCPY_ASYNC -// disable memcpy_async for all older GPUs -#if __CUDA_ARCH__ < 800 -#define ENABLE_MEMCPY_ASYNC 0 -#endif -#endif // ENABLE_MEMCPY_ASYNC +#include namespace raft { -/** helper macro for device inlined functions */ -#define DI inline __device__ -#define HDI inline __host__ __device__ -#define HD __host__ __device__ - -/** - * @brief Provide a ceiling division operation ie. ceil(a / b) - * @tparam IntType supposed to be only integers for now! - */ -template -constexpr HDI IntType ceildiv(IntType a, IntType b) -{ - return (a + b - 1) / b; -} - -/** - * @brief Provide an alignment function ie. ceil(a / b) * b - * @tparam IntType supposed to be only integers for now! - */ -template -constexpr HDI IntType alignTo(IntType a, IntType b) -{ - return ceildiv(a, b) * b; -} - -/** - * @brief Provide an alignment function ie. (a / b) * b - * @tparam IntType supposed to be only integers for now! - */ -template -constexpr HDI IntType alignDown(IntType a, IntType b) -{ - return (a / b) * b; -} - -/** - * @brief Check if the input is a power of 2 - * @tparam IntType data type (checked only for integers) - */ -template -constexpr HDI bool isPo2(IntType num) -{ - return (num && !(num & (num - 1))); -} - -/** - * @brief Give logarithm of the number to base-2 - * @tparam IntType data type (checked only for integers) - */ -template -constexpr HDI IntType log2(IntType num, IntType ret = IntType(0)) -{ - return num <= IntType(1) ? ret : log2(num >> IntType(1), ++ret); -} - -/** Device function to apply the input lambda across threads in the grid */ -template -DI void forEach(int num, L lambda) -{ - int idx = (blockDim.x * blockIdx.x) + threadIdx.x; - const int numThreads = blockDim.x * gridDim.x; -#pragma unroll - for (int itr = 0; itr < ItemsPerThread; ++itr, idx += numThreads) { - if (idx < num) lambda(idx, itr); - } -} - -/** number of threads per warp */ -static const int WarpSize = 32; - -/** get the laneId of the current thread */ -DI int laneId() -{ - int id; - asm("mov.s32 %0, %%laneid;" : "=r"(id)); - return id; -} - -/** - * @brief Swap two values - * @tparam T the datatype of the values - * @param a first input - * @param b second input - */ -template -HDI void swapVals(T& a, T& b) -{ - T tmp = a; - a = b; - b = tmp; -} - /** Device function to have atomic add support for older archs */ template DI void myAtomicAdd(Type* address, Type val) diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index 0feb188ad8..0a7ca23028 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -25,6 +25,7 @@ #pragma once #include +#include #include #include #include @@ -40,42 +41,7 @@ #include #include #include - -namespace raft { - -/** - * @brief Exception thrown when a CUDA error is encountered. - */ -struct cuda_error : public raft::exception { - explicit cuda_error(char const* const message) : raft::exception(message) {} - explicit cuda_error(std::string const& message) : raft::exception(message) {} -}; - -} // namespace raft - -/** - * @brief Error checking macro for CUDA runtime API functions. - * - * Invokes a CUDA runtime API function call, if the call does not return - * cudaSuccess, invokes cudaGetLastError() to clear the error and throws an - * exception detailing the CUDA error that occurred - * - */ -#define RAFT_CUDA_TRY(call) \ - do { \ - cudaError_t const status = call; \ - if (status != cudaSuccess) { \ - cudaGetLastError(); \ - std::string msg{}; \ - SET_ERROR_MSG(msg, \ - "CUDA error encountered at: ", \ - "call='%s', Reason=%s:%s", \ - #call, \ - cudaGetErrorName(status), \ - cudaGetErrorString(status)); \ - throw raft::cuda_error(msg); \ - } \ - } while (0) +#include // FIXME: Remove after consumers rename #ifndef CUDA_TRY diff --git a/cpp/include/raft/util/device_loads_stores.cuh b/cpp/include/raft/util/device_loads_stores.cuh index 2b87c44d60..c9bda26b81 100644 --- a/cpp/include/raft/util/device_loads_stores.cuh +++ b/cpp/include/raft/util/device_loads_stores.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,8 @@ #pragma once -#include +#include // uintX_t +#include // DI namespace raft { diff --git a/cpp/include/raft_runtime/neighbors/brute_force.hpp b/cpp/include/raft_runtime/neighbors/brute_force.hpp index 19904f4f78..12da6ff101 100644 --- a/cpp/include/raft_runtime/neighbors/brute_force.hpp +++ b/cpp/include/raft_runtime/neighbors/brute_force.hpp @@ -21,18 +21,17 @@ namespace raft::runtime::neighbors::brute_force { -#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ - void knn(raft::device_resources const& handle, \ - std::vector> index, \ - raft::device_matrix_view search, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - int k, \ - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \ - std::optional metric_arg = std::make_optional(2.0f), \ +#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ + void knn(raft::device_resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \ + std::optional metric_arg = std::make_optional(2.0f), \ std::optional global_id_offset = std::nullopt); -RAFT_INST_BFKNN(int64_t, float, uint32_t, raft::row_major, raft::row_major); +RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major); #undef RAFT_INST_BFKNN diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index ede6382c33..188122c9b4 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -33,7 +33,8 @@ struct params { size_t len; int k; bool select_min; - bool use_index_input = true; + bool use_index_input = true; + bool use_same_leading_bits = false; }; inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream& @@ -42,7 +43,8 @@ inline auto operator<<(std::ostream& os, const params& ss) -> std::ostream& os << ", len: " << ss.len; os << ", k: " << ss.k; os << (ss.select_min ? ", asc" : ", dsc"); - os << (ss.use_index_input ? "}" : ", no-input-index}"); + os << (ss.use_index_input ? "" : ", no-input-index"); + os << (ss.use_same_leading_bits ? ", same-leading-bits}" : "}"); return os; } @@ -50,6 +52,7 @@ enum class Algo { kPublicApi, kRadix8bits, kRadix11bits, + kRadix11bitsExtraPass, kWarpAuto, kWarpImmediate, kWarpFiltered, @@ -63,6 +66,7 @@ inline auto operator<<(std::ostream& os, const Algo& algo) -> std::ostream& case Algo::kPublicApi: return os << "kPublicApi"; case Algo::kRadix8bits: return os << "kRadix8bits"; case Algo::kRadix11bits: return os << "kRadix11bits"; + case Algo::kRadix11bitsExtraPass: return os << "kRadix11bitsExtraPass"; case Algo::kWarpAuto: return os << "kWarpAuto"; case Algo::kWarpImmediate: return os << "kWarpImmediate"; case Algo::kWarpFiltered: return os << "kWarpFiltered"; @@ -103,11 +107,38 @@ void select_k_impl(const device_resources& handle, } } case Algo::kRadix8bits: - return detail::select::radix::select_k( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + return detail::select::radix::select_k(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + true, // fused_last_filter + stream); case Algo::kRadix11bits: - return detail::select::radix::select_k( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); + return detail::select::radix::select_k(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + true, // fused_last_filter + stream); + case Algo::kRadix11bitsExtraPass: + return detail::select::radix::select_k(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + false, // fused_last_filter + stream); case Algo::kWarpAuto: return detail::select::warpsort::select_k( in, in_idx, batch_size, len, k, out, out_idx, select_min, stream); diff --git a/cpp/src/distance/specializations/detail/00_write_template.py b/cpp/src/distance/specializations/detail/00_write_template.py new file mode 100644 index 0000000000..3f2f853569 --- /dev/null +++ b/cpp/src/distance/specializations/detail/00_write_template.py @@ -0,0 +1,159 @@ +#!/usr/bin/env python3 + +# NOTE: this template is not perfectly formatted. Use pre-commit to get +# everything in shape again. +template = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +INCLUDE_SM_HEADERS + +namespace raft::distance::detail { + +template void pairwise_matrix_instantiation_point( + OpT, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail +""" + +data_type_instances = [ + dict( + DataT="float", + AccT="float", + OutT="float", + IdxT="int", + ), + dict( + DataT="double", + AccT="double", + OutT="double", + IdxT="int", + ), +] + +op_instances = [ + dict( + path_prefix="canberra", + OpT="ops::canberra_distance_op", + archs = [60], + ), + dict( + path_prefix="correlation", + OpT="ops::correlation_distance_op", + archs = [60], + ), + dict( + path_prefix="cosine", + OpT="ops::cosine_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="hamming_unexpanded", + OpT="ops::hamming_distance_op", + archs = [60], + ), + dict( + path_prefix="hellinger_expanded", + OpT="ops::hellinger_distance_op", + archs = [60], + ), + # inner product is handled by cublas. + dict( + path_prefix="jensen_shannon", + OpT="ops::jensen_shannon_distance_op", + archs = [60], + ), + dict( + path_prefix="kl_divergence", + OpT="ops::kl_divergence_op", + archs = [60], + ), + dict( + path_prefix="l1", + OpT="ops::l1_distance_op", + archs = [60], + ), + dict( + path_prefix="l2_expanded", + OpT="ops::l2_exp_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="l2_unexpanded", + OpT="ops::l2_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="l_inf", + OpT="ops::l_inf_distance_op", + archs = [60], + ), + dict( + path_prefix="lp_unexpanded", + OpT="ops::lp_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="russel_rao", + OpT="ops::russel_rao_distance_op", + archs = [60], + ), +] + +def fill_in(s, template): + for k, v in template.items(): + s = s.replace(k, v) + return s + +def fill_include_sm_headers(op_instance): + include_headers ="\n".join([ + f"#include " + for arch in op_instance["archs"] + ]) + + return { + "path_prefix": op_instance["path_prefix"], + "OpT": op_instance["OpT"], + "INCLUDE_SM_HEADERS": include_headers + } + +for op_instance in op_instances: + op_instance = fill_include_sm_headers(op_instance) + + for data_type_instance in data_type_instances: + op_data_instance = { + k : fill_in(v, data_type_instance) + for k, v in op_instance.items() + } + instance = { + **op_data_instance, + **data_type_instance, + "FinopT": "decltype(raft::identity_op())", + } + + text = fill_in(template, instance) + + path = fill_in("path_prefix_DataT_AccT_OutT_IdxT.cu", instance) + with open(path, "w") as f: + f.write(text) diff --git a/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu b/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu index 4e9e608792..037d218178 100644 --- a/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu @@ -14,24 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu b/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu index 6dfc385e55..0ed8ea7bb0 100644 --- a/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::canberra_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu b/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu index 2df77a4b5d..0c11f0621e 100644 --- a/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu @@ -14,27 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu b/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu index 76ed00afa6..396e158554 100644 --- a/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::correlation_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu b/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu index 3e0bcb92ed..e9afb6f563 100644 --- a/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu @@ -14,26 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu b/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu index 23131ce2c7..1033c491d6 100644 --- a/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu @@ -14,26 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::cosine_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu index b618fd024c..195115914d 100644 --- a/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu @@ -14,27 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu index 18e7aad9e9..a74c6c404e 100644 --- a/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::hamming_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu index 08ab20cfe5..bac1dd7bd0 100644 --- a/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu @@ -14,27 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu index 79eed075fb..77c113b1a9 100644 --- a/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu @@ -14,26 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::hellinger_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu b/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu index ed84ee6dc4..188e52c152 100644 --- a/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu @@ -14,25 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void + pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu b/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu index a241af767c..b0afbf7bb2 100644 --- a/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu @@ -14,25 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void + pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::jensen_shannon_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu b/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu index c4c944d123..f06ae85414 100644 --- a/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu b/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu index aa1db5a837..00d5a5ee5b 100644 --- a/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::kl_divergence_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu index 391a1c2aa4..5c235316da 100644 --- a/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu index 7b45e52ca1..fb293ca83d 100644 --- a/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l1_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu index 8c5f746fa2..2c02f0224f 100644 --- a/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu @@ -14,24 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +namespace raft::distance::detail { + +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu index c266125f98..85e25a25ca 100644 --- a/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu @@ -14,25 +14,21 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l2_exp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu deleted file mode 100644 index 399b120527..0000000000 --- a/cpp/src/distance/specializations/detail/l2_sqrt_expanded_double_double_double_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu deleted file mode 100644 index 66de212b8e..0000000000 --- a/cpp/src/distance/specializations/detail/l2_sqrt_expanded_float_float_float_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu deleted file mode 100644 index 562d93b2de..0000000000 --- a/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_double_double_double_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { - -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu deleted file mode 100644 index 386bbafc5f..0000000000 --- a/cpp/src/distance/specializations/detail/l2_sqrt_unexpanded_float_float_float_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu index 7733c3af48..5b4d995d14 100644 --- a/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu index 4ea18d31de..a63c3f0bb8 100644 --- a/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l2_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu index 74414f8fd6..831167523f 100644 --- a/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu @@ -14,26 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); -} // namespace detail -} // namespace distance -} // namespace raft +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu index e418fc455f..02e667cbe3 100644 --- a/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::l_inf_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu index 402cb51b7e..ebd71065ec 100644 --- a/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { +namespace raft::distance::detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu index 7efe2b3349..b94a81fdce 100644 --- a/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::lp_unexp_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu b/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu index b1e6f5e1f4..6f952fcc37 100644 --- a/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu +++ b/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu @@ -14,26 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + double, + double, + decltype(raft::identity_op())>( + ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu b/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu index 1e12bcd705..3223ce33a7 100644 --- a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu +++ b/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu @@ -14,25 +14,20 @@ * limitations under the License. */ -#include -#include +#include // raft::identity_op +#include // ops::* +#include // pairwise_matrix_instantiation_point +#include -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); +namespace raft::distance::detail { -} // namespace detail -} // namespace distance -} // namespace raft +template void pairwise_matrix_instantiation_point, + int, + float, + float, + decltype(raft::identity_op())>( + ops::russel_rao_distance_op, + pairwise_matrix_params, + cudaStream_t); + +} // namespace raft::distance::detail diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu b/cpp/src/neighbors/brute_force_knn_int64_t_float.cu index b0411a59ce..88545b3607 100644 --- a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu +++ b/cpp/src/neighbors/brute_force_knn_int64_t_float.cu @@ -14,8 +14,6 @@ * limitations under the License. */ -#pragma once - #include #include #include @@ -24,30 +22,27 @@ #include +#include + namespace raft::runtime::neighbors::brute_force { -#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ - void knn(raft::device_resources const& handle, \ - std::vector> index, \ - raft::device_matrix_view search, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, \ - std::optional metric_arg = std::make_optional(2.0f), \ - std::optional global_id_offset = std::nullopt) \ - { \ - raft::neighbors::brute_force::knn(handle, \ - index, \ - search, \ - indices, \ - distances, \ - static_cast(indices.extent(1)), \ - metric, \ - metric_arg, \ - global_id_offset); \ +#define RAFT_INST_BFKNN(IDX_T, DATA_T, MATRIX_IDX_T, INDEX_LAYOUT, SEARCH_LAYOUT) \ + void knn(raft::device_resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset) \ + { \ + std::vector> vec; \ + vec.push_back(index); \ + raft::neighbors::brute_force::knn( \ + handle, vec, search, indices, distances, metric, metric_arg, global_id_offset); \ } -RAFT_INST_BFKNN(int64_t, float, uint32_t, raft::row_major, raft::row_major); +RAFT_INST_BFKNN(int64_t, float, int64_t, raft::row_major, raft::row_major); #undef RAFT_INST_BFKNN diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu index 07810aa576..04aa42c9f1 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(long, float, int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu index 0cb873b40a..a8b9d4299a 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(long, float, unsigned int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu index f8a69b896f..c97e6e936a 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(uint32_t, float, int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu index 3c23d1f3e0..87451c385a 100644 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu +++ b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu @@ -32,7 +32,8 @@ namespace raft::neighbors::detail { bool rowMajorQuery, \ std::vector* translations, \ raft::distance::DistanceType metric, \ - float metricArg); + float metricArg, \ + raft::identity_op); RAFT_INST(uint32_t, float, unsigned int); #undef RAFT_INST } // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu index 6de65546c8..dce7083139 100644 --- a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu +++ b/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu @@ -18,12 +18,37 @@ namespace raft::neighbors::ivf_flat { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ +// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan +// function is used in both raft::neighbors::ivf_flat::search and +// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation +// of this function (which defines ~270 kernels) in the refine specializations, +// an extern template definition is provided. To make sure +// ivfflat_interleaved_scan is actually compiled here, we explicitly instantiate +// it below. Please check related function calls after editing template +// definition below. Search for `greppable-id-specializations-ivf-flat-search` +// to find them. +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); \ + \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ raft::device_matrix_view); RAFT_MAKE_INSTANCE(float, int64_t); diff --git a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu index 8eda240ccd..b03d878bae 100644 --- a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu +++ b/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu @@ -18,12 +18,28 @@ namespace raft::neighbors::ivf_flat { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); \ + \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ raft::device_matrix_view); RAFT_MAKE_INSTANCE(int8_t, int64_t); diff --git a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu index 8ff6533628..2d42bae0d1 100644 --- a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu @@ -18,12 +18,28 @@ namespace raft::neighbors::ivf_flat { -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ +#define RAFT_MAKE_INSTANCE(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ + T, \ + typename raft::spatial::knn::detail::utils::config::value_t, \ + IdxT>(const index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream); \ + \ + template void search(raft::device_resources const&, \ + raft::neighbors::ivf_flat::search_params const&, \ + const raft::neighbors::ivf_flat::index&, \ + raft::device_matrix_view, \ + raft::device_matrix_view, \ raft::device_matrix_view); RAFT_MAKE_INSTANCE(uint8_t, int64_t); diff --git a/cpp/template/CMakeLists.txt b/cpp/template/CMakeLists.txt new file mode 100644 index 0000000000..501a5c9503 --- /dev/null +++ b/cpp/template/CMakeLists.txt @@ -0,0 +1,38 @@ +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) + +# ------------- configure rapids-cmake --------------# + +include(cmake/thirdparty/fetch_rapids.cmake) +include(rapids-cmake) +include(rapids-cpm) +include(rapids-cuda) +include(rapids-export) +include(rapids-find) + +# ------------- configure project --------------# + +rapids_cuda_init_architectures(test_raft) + +project(test_raft LANGUAGES CXX CUDA) + +# ------------- configure raft -----------------# + +rapids_cpm_init() +include(cmake/thirdparty/get_raft.cmake) + +# -------------- compile tasks ----------------- # +add_executable(TEST_RAFT src/test_distance.cu) +target_link_libraries(TEST_RAFT PRIVATE raft::raft raft::compiled) diff --git a/cpp/template/README.md b/cpp/template/README.md new file mode 100644 index 0000000000..348dff270a --- /dev/null +++ b/cpp/template/README.md @@ -0,0 +1,18 @@ +# Example RAFT Project Template + +This template project provides a drop-in sample to either start building a new application with, or using RAFT in an existing CMake project. + +First, please refer to our [installation docs](https://docs.rapids.ai/api/raft/stable/build.html#cuda-gpu-requirements) for the minimum requirements to use RAFT. + +Once the minimum requirements are satisfied, this example template application can be built with the provided `build.sh` script. This is a bash script that calls the appropriate CMake commands, so you can look into it to see the typical CMake based build workflow. + +This directory (`RAFT_SOURCE/cpp/template`) can be copied directly in order to build a new application with RAFT. + +RAFT can be integrated into an existing CMake project by copying the contents in the `configure rapids-cmake` and `configure raft` sections of the provided `CMakeLists.txt` into your project, along with `cmake/thirdparty/get_raft.cmake`. + +Make sure to link against the appropriate Cmake targets. Use `raft::raft`to add make the headers available and `raft::compiled` when utilizing the shared library. + +```cmake +target_link_libraries(your_app_target PRIVATE raft::raft raft::compiled) +``` + diff --git a/cpp/template/build.sh b/cpp/template/build.sh new file mode 100755 index 0000000000..3ac00fc9af --- /dev/null +++ b/cpp/template/build.sh @@ -0,0 +1,41 @@ +#!/bin/bash + +# Copyright (c) 2023, NVIDIA CORPORATION. + +# raft empty project template build script + +# Abort script on first error +set -e + +PARALLEL_LEVEL=${PARALLEL_LEVEL:=`nproc`} + +BUILD_TYPE=Release +BUILD_DIR=build/ + +RAFT_REPO_REL="" +EXTRA_CMAKE_ARGS="" +set -e + + +if [[ ${RAFT_REPO_REL} != "" ]]; then + RAFT_REPO_PATH="`readlink -f \"${RAFT_REPO_REL}\"`" + EXTRA_CMAKE_ARGS="${EXTRA_CMAKE_ARGS} -DCPM_raft_SOURCE=${RAFT_REPO_PATH}" +fi + +if [ "$1" == "clean" ]; then + rm -rf build + exit 0 +fi + +mkdir -p $BUILD_DIR +cd $BUILD_DIR + +cmake \ + -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ + -DRAFT_NVTX=OFF \ + -DCMAKE_CUDA_ARCHITECTURES="NATIVE" \ + -DCMAKE_EXPORT_COMPILE_COMMANDS=ON \ + ${EXTRA_CMAKE_ARGS} \ + ../ + +cmake --build . -j${PARALLEL_LEVEL} diff --git a/cpp/template/cmake/thirdparty/fetch_rapids.cmake b/cpp/template/cmake/thirdparty/fetch_rapids.cmake new file mode 100644 index 0000000000..40ba83be9e --- /dev/null +++ b/cpp/template/cmake/thirdparty/fetch_rapids.cmake @@ -0,0 +1,21 @@ +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +# Use this variable to update RAPIDS and RAFT versions +set(RAPIDS_VERSION "23.04") + +if(NOT EXISTS ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) + file(DOWNLOAD https://raw.githubusercontent.com/rapidsai/rapids-cmake/branch-${RAPIDS_VERSION}/RAPIDS.cmake + ${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) +endif() +include(${CMAKE_CURRENT_BINARY_DIR}/RAFT_RAPIDS.cmake) diff --git a/cpp/template/cmake/thirdparty/get_raft.cmake b/cpp/template/cmake/thirdparty/get_raft.cmake new file mode 100644 index 0000000000..5463942adf --- /dev/null +++ b/cpp/template/cmake/thirdparty/get_raft.cmake @@ -0,0 +1,62 @@ +# ============================================================================= +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except +# in compliance with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software distributed under the License +# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing permissions and limitations under +# the License. + +# Use RAPIDS_VERSION from cmake/thirdparty/fetch_rapids.cmake +set(RAFT_VERSION "${RAPIDS_VERSION}") +set(RAFT_FORK "rapidsai") +set(RAFT_PINNED_TAG "branch-${RAPIDS_VERSION}") + +function(find_and_configure_raft) + set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY ENABLE_NVTX ENABLE_MNMG_DEPENDENCIES) + cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + set(RAFT_COMPONENTS "") + if(PKG_COMPILE_LIBRARY) + string(APPEND RAFT_COMPONENTS " compiled") + endif() + + if(PKG_ENABLE_MNMG_DEPENDENCIES) + string(APPEND RAFT_COMPONENTS " distributed") + endif() + + #----------------------------------------------------- + # Invoke CPM find_package() + #----------------------------------------------------- + rapids_cpm_find(raft ${PKG_VERSION} + GLOBAL_TARGETS raft::raft + BUILD_EXPORT_SET raft-template-exports + INSTALL_EXPORT_SET raft-template-exports + COMPONENTS ${RAFT_COMPONENTS} + CPM_ARGS + GIT_REPOSITORY https://github.com/${PKG_FORK}/raft.git + GIT_TAG ${PKG_PINNED_TAG} + SOURCE_SUBDIR cpp + OPTIONS + "BUILD_TESTS OFF" + "BUILD_BENCH OFF" + "RAFT_NVTX ${ENABLE_NVTX}" + "RAFT_COMPILE_LIBRARY ${PKG_COMPILE_LIBRARY}" + ) +endfunction() + +# Change pinned tag here to test a commit in CI +# To use a different RAFT locally, set the CMake variable +# CPM_raft_SOURCE=/path/to/local/raft +find_and_configure_raft(VERSION ${RAFT_VERSION}.00 + FORK ${RAFT_FORK} + PINNED_TAG ${RAFT_PINNED_TAG} + COMPILE_LIBRARY ON + ENABLE_MNMG_DEPENDENCIES OFF + ENABLE_NVTX OFF +) diff --git a/cpp/template/src/test_distance.cu b/cpp/template/src/test_distance.cu new file mode 100644 index 0000000000..b86dde70e5 --- /dev/null +++ b/cpp/template/src/test_distance.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include + +#ifdef RAFT_COMPILED +#include +#endif + +int main() +{ + raft::device_resources handle; + + int n_samples = 5000; + int n_features = 50; + + auto input = raft::make_device_matrix(handle, n_samples, n_features); + auto labels = raft::make_device_vector(handle, n_samples); + auto output = raft::make_device_matrix(handle, n_samples, n_samples); + + raft::random::make_blobs(handle, input.view(), labels.view()); + + auto metric = raft::distance::DistanceType::L2SqrtExpanded; + raft::distance::pairwise_distance(handle, input.view(), input.view(), output.view(), metric); +} diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 0e084f2ad8..438e212fbd 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -16,16 +16,24 @@ #include "../test_utils.cuh" #include -#include -#include -#include -#include -#include -#include +#include // common::nvtx::range + +#include // make_device_matrix_view +#include // raft::device_resources +#include // raft::sqrt +#include // raft::distance::DistanceType +#include +#include // rmm::device_uvector + +// When the distance library is precompiled, include only the raft_runtime +// headers. This way, a small change in one of the kernel internals does not +// trigger a rebuild of the test files (it of course still triggers a rebuild of +// the raft specializations) #if defined RAFT_COMPILED -#include +#include +#else +#include #endif -#include namespace raft { namespace distance { @@ -409,6 +417,25 @@ template return os; } +// TODO: Remove when mdspan-based raft::runtime::distance::pairwise_distance is +// implemented. +// +// Context: +// https://github.com/rapidsai/raft/issues/1338 +template +constexpr bool layout_to_row_major(); + +template <> +constexpr bool layout_to_row_major() +{ + return true; +} +template <> +constexpr bool layout_to_row_major() +{ + return false; +} + template void distanceLauncher(raft::device_resources const& handle, DataType* x, @@ -422,12 +449,23 @@ void distanceLauncher(raft::device_resources const& handle, DataType threshold, DataType metric_arg = 2.0f) { +#if defined RAFT_COMPILED + // TODO: Implement and use mdspan-based + // raft::runtime::distance::pairwise_distance here. + // + // Context: + // https://github.com/rapidsai/raft/issues/1338 + bool row_major = layout_to_row_major(); + raft::runtime::distance::pairwise_distance( + handle, x, y, dist, m, n, k, distanceType, row_major, metric_arg); +#else auto x_v = make_device_matrix_view(x, m, k); auto y_v = make_device_matrix_view(y, n, k); auto dist_v = make_device_matrix_view(dist, m, n); raft::distance::distance( handle, x_v, y_v, dist_v, metric_arg); +#endif } template @@ -523,9 +561,25 @@ class BigMatrixDistanceTest : public ::testing::Test { auto testInfo = testing::UnitTest::GetInstance()->current_test_info(); common::nvtx::range fun_scope("test::%s/%s", testInfo->test_suite_name(), testInfo->name()); + void pairwise_distance(raft::device_resources const& handle, + float* x, + float* y, + float* dists, + int m, + int n, + int k, + raft::distance::DistanceType metric, + bool isRowMajor, + float metric_arg); + constexpr bool row_major = true; + constexpr float metric_arg = 0.0f; +#if defined RAFT_COMPILED + raft::runtime::distance::pairwise_distance( + handle, x.data(), x.data(), dist.data(), m, n, k, distanceType, row_major, metric_arg); +#else raft::distance::distance( - handle, x.data(), x.data(), dist.data(), m, n, k, true, 0.0f); - + handle, x.data(), x.data(), dist.data(), m, n, k, row_major, metric_arg); +#endif RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 4a74d7f16a..383ad39319 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -182,22 +182,20 @@ class FusedL2NNTest : public ::testing::TestWithParam> { int m = params.m; int n = params.n; int k = params.k; - MinAndDistanceReduceOp redOp; - fusedL2NN, int>( - out, - x.data(), - y.data(), - xn.data(), - yn.data(), - m, - n, - k, - (void*)workspace.data(), - redOp, - raft::distance::KVPMinReduce(), - Sqrt, - true, - stream); + + const bool init_out_buffer = true; + fusedL2NNMinReduce, int>(out, + x.data(), + y.data(), + xn.data(), + yn.data(), + m, + n, + k, + (void*)workspace.data(), + Sqrt, + init_out_buffer, + stream); RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } }; diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index 392464eb27..2a40d70abc 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -332,6 +332,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Values(select::Algo::kPublicApi, select::Algo::kRadix8bits, select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed))); @@ -426,6 +427,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Combine(inputs_random_longlist, testing::Values(select::Algo::kRadix8bits, select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed, @@ -440,6 +442,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT testing::Combine(inputs_random_longlist, testing::Values(select::Algo::kRadix8bits, select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass, select::Algo::kWarpImmediate, select::Algo::kWarpFiltered, select::Algo::kWarpDistributed, @@ -451,7 +454,11 @@ TEST_P(ReferencedRandomDoubleInt, LargeSize) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, ReferencedRandomDoubleInt, - testing::Combine(inputs_random_largesize, testing::Values(select::Algo::kWarpAuto))); + testing::Combine(inputs_random_largesize, + testing::Values(select::Algo::kWarpAuto, + select::Algo::kRadix8bits, + select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass))); using ReferencedRandomFloatSizeT = SelectK::params_random>; @@ -459,6 +466,7 @@ TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT ReferencedRandomFloatSizeT, testing::Combine(inputs_random_largek, - testing::Values(select::Algo::kRadix11bits))); + testing::Values(select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass))); } // namespace raft::matrix diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index 9b51d585de..46ef3a9150 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -121,7 +121,6 @@ void compute_bfknn(const raft::device_resources& handle, make_device_matrix_view(X2, n_query_rows, d), make_device_matrix_view(inds, n_query_rows, k), make_device_matrix_view(dists, n_query_rows, k), - k, metric); } diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu index 4bb977432c..bcd4b9cb0b 100644 --- a/cpp/test/neighbors/knn.cu +++ b/cpp/test/neighbors/knn.cu @@ -96,7 +96,7 @@ class KNNTest : public ::testing::TestWithParam { raft::make_device_matrix_view(distances_.data(), rows_, k_); auto metric = raft::distance::DistanceType::L2Unexpanded; - knn(handle, index, search, indices, distances, k_, metric, std::make_optional(0)); + knn(handle, index, search, indices, distances, metric, std::make_optional(0)); build_actual_output<<>>( actual_labels_.data(), rows_, k_, search_labels_.data(), indices_.data()); diff --git a/dependencies.yaml b/dependencies.yaml index 9fbf26bcd1..dd361a0cdf 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -7,11 +7,14 @@ files: arch: [x86_64] includes: - build + - build_pylibraft - cudatoolkit - develop - docs - - run - - test_python + - run_raft_dask + - run_pylibraft + - test_python_common + - test_pylibraft test_cpp: output: none includes: @@ -21,7 +24,8 @@ files: includes: - cudatoolkit - py_version - - test_python + - test_python_common + - test_pylibraft checks: output: none includes: @@ -33,6 +37,54 @@ files: - cudatoolkit - docs - py_version + py_build_pylibraft: + output: pyproject + pyproject_dir: python/pylibraft + extras: + table: build-system + includes: + - build + - build_pylibraft + - build_wheels + py_run_pylibraft: + output: pyproject + pyproject_dir: python/pylibraft + extras: + table: project + includes: + - run_pylibraft + py_test_pylibraft: + output: pyproject + pyproject_dir: python/pylibraft + extras: + table: project.optional-dependencies + key: test + includes: + - test_python_common + - test_pylibraft + py_build_raft_dask: + output: pyproject + pyproject_dir: python/raft-dask + extras: + table: build-system + includes: + - build + - build_wheels + py_run_raft_dask: + output: pyproject + pyproject_dir: python/raft-dask + extras: + table: project + includes: + - run_raft_dask + py_test_raft_dask: + output: pyproject + pyproject_dir: python/raft-dask + extras: + table: project.optional-dependencies + key: test + includes: + - test_python_common channels: - rapidsai - rapidsai-nightly @@ -42,10 +94,9 @@ channels: dependencies: build: common: - - output_types: [conda, requirements] + - output_types: [conda, requirements, pyproject] packages: - cmake>=3.23.1,!=3.25.0 - - cuda-python >=11.7.1,<12.0 - cython>=0.29,<0.30 - ninja - scikit-build>=0.13.1 @@ -53,6 +104,7 @@ dependencies: packages: - c-compiler - cxx-compiler + - nccl>=2.9.9 specific: - output_types: conda matrices: @@ -66,6 +118,12 @@ dependencies: packages: - gcc_linux-aarch64=11.* - sysroot_linux-aarch64==2.17 + build_pylibraft: + common: + - output_types: [conda, requirements, pyproject] + packages: + - &cuda_python cuda-python >=11.7.1,<12.0 + - &rmm rmm==23.4.* checks: common: - output_types: [conda, requirements] @@ -150,6 +208,12 @@ dependencies: - recommonmark - sphinx-copybutton - sphinx-markdown-tables + build_wheels: + common: + - output_types: pyproject + packages: + - wheel + - setuptools py_version: specific: - output_types: conda @@ -169,23 +233,41 @@ dependencies: - matrix: packages: - python>=3.8,<3.11 - run: + run_pylibraft: common: - - output_types: [conda] + - output_types: [conda, pyproject] + packages: + - &numpy numpy>=1.21 + - *cuda_python + - *rmm + run_raft_dask: + common: + - output_types: [conda, pyproject] packages: - dask>=2023.1.1 + - dask-cuda==23.4.* - distributed>=2023.1.1 + - joblib>=0.11 + - numba>=0.49 + - *numpy + - ucx-py==0.31.* + - output_types: conda + packages: - ucx>=1.13.0 - - ucx-py=0.31.* - ucx-proc=*=gpu - - rmm=23.04 - - dask-cuda=23.04 - test_python: + - output_types: pyproject + packages: + - pylibraft==23.4.* + test_python_common: common: - - output_types: [conda, requirements] + - output_types: [conda, requirements, pyproject] packages: - - cupy - pytest - pytest-cov + test_pylibraft: + common: + - output_types: [conda, requirements, pyproject] + packages: + - cupy - scikit-learn - scipy diff --git a/docs/source/build.md b/docs/source/build.md index bbb454736a..e8e6ac8a14 100644 --- a/docs/source/build.md +++ b/docs/source/build.md @@ -75,7 +75,7 @@ Once installed, `libraft` headers (and dependencies which were downloaded and in ``` -### C++ Shared Libraries (optional) +### C++ Shared Library (optional) A shared library can be built for speeding up compile times. The shared library also contains a runtime API that allows you to invoke RAFT APIs directly from C++ source files (without `nvcc`). The shared library can also significantly improve re-compile times both while developing RAFT and using its APIs to develop applications. Pass the `--compile-lib` flag to `build.sh` to build the library: ```bash @@ -109,7 +109,7 @@ Compile the tests using the `tests` target in `build.sh`. Test compile times can be improved significantly by using the optional shared libraries. If installed, they will be used automatically when building the tests but `--compile-libs` can be used to add additional compilation units and compile them with the tests. ```bash -./build.sh libraft tests --compile-libs +./build.sh libraft tests --compile-lib ``` The tests are broken apart by algorithm category, so you will find several binaries in `cpp/build/` named `*_TEST`. @@ -151,19 +151,17 @@ make -j install RAFT's cmake has the following configurable flags available:. -| Flag | Possible Values | Default Value | Behavior | -| --- | --- | --- | --- | -| BUILD_TESTS | ON, OFF | ON | Compile Googletests | -| BUILD_BENCH | ON, OFF | OFF | Compile benchmarks | -| raft_FIND_COMPONENTS | nn distance | | Configures the optional components as a space-separated list | -| RAFT_COMPILE_LIBRARIES | ON, OFF | ON if either BUILD_TESTS or BUILD_BENCH is ON; otherwise OFF | Compiles all `libraft` shared libraries (these are required for Googletests) | -| RAFT_COMPILE_NN_LIBRARY | ON, OFF | OFF | Compiles the `libraft-nn` shared library | -| RAFT_COMPILE_DIST_LIBRARY | ON, OFF | OFF | Compiles the `libraft-distance` shared library | -| DETECT_CONDA_ENV | ON, OFF | ON | Enable detection of conda environment for dependencies | -| RAFT_NVTX | ON, OFF | OFF | Enable NVTX Markers | -| CUDA_ENABLE_KERNELINFO | ON, OFF | OFF | Enables `kernelinfo` in nvcc. This is useful for `compute-sanitizer` | -| CUDA_ENABLE_LINEINFO | ON, OFF | OFF | Enable the -lineinfo option for nvcc | -| CUDA_STATIC_RUNTIME | ON, OFF | OFF | Statically link the CUDA runtime | +| Flag | Possible Values | Default Value | Behavior | +|---------------------------|----------------------| --- | --- | +| BUILD_TESTS | ON, OFF | ON | Compile Googletests | +| BUILD_BENCH | ON, OFF | OFF | Compile benchmarks | +| raft_FIND_COMPONENTS | compiled distributed | | Configures the optional components as a space-separated list | +| RAFT_COMPILE_LIBRARY | ON, OFF | ON if either BUILD_TESTS or BUILD_BENCH is ON; otherwise OFF | Compiles all `libraft` shared libraries (these are required for Googletests) | +| DETECT_CONDA_ENV | ON, OFF | ON | Enable detection of conda environment for dependencies | +| RAFT_NVTX | ON, OFF | OFF | Enable NVTX Markers | +| CUDA_ENABLE_KERNELINFO | ON, OFF | OFF | Enables `kernelinfo` in nvcc. This is useful for `compute-sanitizer` | +| CUDA_ENABLE_LINEINFO | ON, OFF | OFF | Enable the -lineinfo option for nvcc | +| CUDA_STATIC_RUNTIME | ON, OFF | OFF | Statically link the CUDA runtime | Currently, shared libraries are provided for the `libraft-nn` and `libraft-distance` components. @@ -248,9 +246,9 @@ PROPERTIES CXX_STANDARD 17 ``` -### C++ header-only integration +### C++ header-only integration (without cmake) -When the needed [build dependencies](#build-dependencies) are already satisfied, RAFT can be trivially integrated into downstream projects by cloning the repository and adding `cpp/include` from RAFT to the include path: +While not a highly suggested method for building against RAFT, when all of the needed [build dependencies](#build-dependencies) are already satisfied, RAFT can be integrated into downstream projects by cloning the repository and adding `cpp/include` from RAFT to the include path: ```cmake set(RAFT_GIT_DIR ${CMAKE_CURRENT_BINARY_DIR}/raft CACHE STRING "Path to RAFT repo") ExternalProject_Add(raft @@ -262,8 +260,12 @@ ExternalProject_Add(raft INSTALL_COMMAND "") set(RAFT_INCLUDE_DIR ${RAFT_GIT_DIR}/raft/cpp/include CACHE STRING "RAFT include variable") ``` +### C++ header-only integration (with cmake) -If RAFT has already been installed, such as by using the `build.sh` script, use `find_package(raft)` and the `raft::raft` target. + +When using cmake, you can install RAFT headers into your environment with `./build.sh libraft`. + +If the RAFT headers have already been installed into your environment with cmake or through conda, such as by using the `build.sh` script, use `find_package(raft)` and the `raft::raft` target. ### Using C++ pre-compiled shared libraries @@ -271,17 +273,19 @@ Use `find_package(raft COMPONENTS compiled distributed)` to enable the shared li The pre-compiled libraries contain template specializations for commonly used types, such as single- and double-precision floating-point. In order to use the symbols in the pre-compiled libraries, the compiler needs to be told not to instantiate templates that are already contained in the shared libraries. By convention, these header files are named `specializations.cuh` and located in the base directory for the packages that contain specializations. -The following example tells the compiler to ignore the pre-compiled templates for the `raft::distance` API so any symbols already compiled into the `libraft` shared library will be used instead: +The following example tells the compiler to ignore the pre-compiled templates for the `raft::distance` API so any symbols already compiled into the `libraft` shared library will be used instead. RAFT's cmake creates a variable `RAFT_COMPILED` which can be used to ignore the pre-compiled template specializations only when the shared library has been enabled through cmake (such as by specifying the `compiled` component in `find_package`): ```c++ +#ifdef RAFT_COMPILED #include #include +#endif ``` ### Building RAFT C++ from source in cmake RAFT uses the [RAPIDS-CMake](https://github.com/rapidsai/rapids-cmake) library so it can be more easily included into downstream projects. RAPIDS cmake provides a convenience layer around the [CMake Package Manager (CPM)](https://github.com/cpm-cmake/CPM.cmake). -The following example is similar to invoking `find_package(raft)` but uses `rapids_cpm_find`, which provides a richer and more flexible configuration landscape by using CPM to fetch any dependencies not already available to the build. The `raft::raft` link target will be made available and it's recommended that it be used as a `PRIVATE` link dependency in downstream projects. The `COMPILE_LIBRARIES` option enables the building the shared libraries. +The following example is similar to invoking `find_package(raft)` but uses `rapids_cpm_find`, which provides a richer and more flexible configuration landscape by using CPM to fetch any dependencies not already available to the build. The `raft::raft` link target will be made available and it's recommended that it be used as a `PRIVATE` link dependency in downstream projects. The `COMPILE_LIBRARY` option enables the building the shared libraries. The following `cmake` snippet enables a flexible configuration of RAFT: @@ -292,19 +296,10 @@ set(RAFT_FORK "rapidsai") set(RAFT_PINNED_TAG "branch-${RAFT_VERSION}") function(find_and_configure_raft) - set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY CLONE_ON_PIN) + set(oneValueArgs VERSION FORK PINNED_TAG COMPILE_LIBRARY) cmake_parse_arguments(PKG "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN} ) - - #----------------------------------------------------- - # Clone RAFT locally if PINNED_TAG has been changed - #----------------------------------------------------- - if(PKG_CLONE_ON_PIN AND NOT PKG_PINNED_TAG STREQUAL "branch-${RAFT_VERSION}") - message("Pinned tag found: ${PKG_PINNED_TAG}. Cloning raft locally.") - set(CPM_DOWNLOAD_raft ON) - set(CMAKE_IGNORE_PATH "${CMAKE_INSTALL_PREFIX}/include/raft;${CMAKE_IGNORE_PATH}) - endif() - + #----------------------------------------------------- # Invoke CPM find_package() #----------------------------------------------------- @@ -332,15 +327,12 @@ endfunction() find_and_configure_raft(VERSION ${RAFT_VERSION}.00 FORK ${RAFT_FORK} PINNED_TAG ${RAFT_PINNED_TAG} - - # When PINNED_TAG above doesn't match cuml, - # force local raft clone in build directory - # even if it's already installed. - CLONE_ON_PIN ON COMPILE_LIBRARY NO ) ``` +You can find a fully-functioning [example template project](../../cpp/template/README.md) in the `cpp/template` directory, which provides everything you need to build a new application with RAFT or incorporate RAFT Into your existing libraries. + ## Uninstall Once built and installed, RAFT can be safely uninstalled using `build.sh` by specifying any or all of the installed components. Please note that since `pylibraft` depends on `libraft`, uninstalling `pylibraft` will also uninstall `libraft`: diff --git a/img/arch.png b/img/arch.png new file mode 100644 index 0000000000..ea9cad9204 Binary files /dev/null and b/img/arch.png differ diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index c7b42ecab7..f35a94bb9c 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -159,7 +159,6 @@ cdef device_matrix_view[float, int64_t, row_major] \ return make_device_matrix_view[float, int64_t, row_major]( cai.data, shape[0], shape[1]) - cdef device_matrix_view[uint8_t, int64_t, row_major] \ get_dmv_uint8(cai, check_shape) except *: if cai.dtype != np.uint8: diff --git a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt index 98f0d7f67a..7b9c1591c1 100644 --- a/python/pylibraft/pylibraft/neighbors/CMakeLists.txt +++ b/python/pylibraft/pylibraft/neighbors/CMakeLists.txt @@ -13,7 +13,7 @@ # ============================================================================= # Set the list of Cython files to build -set(cython_sources common.pyx refine.pyx) +set(cython_sources common.pyx refine.pyx brute_force.pyx) set(linked_libraries raft::raft raft::compiled) # Build all of the Cython targets diff --git a/python/pylibraft/pylibraft/neighbors/__init__.py b/python/pylibraft/pylibraft/neighbors/__init__.py index f7510ba2db..a50b6f21a7 100644 --- a/python/pylibraft/pylibraft/neighbors/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/__init__.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. # + +from pylibraft.neighbors import brute_force + from .refine import refine -__all__ = ["common", "refine"] +__all__ = ["common", "refine", "brute_force"] diff --git a/python/pylibraft/pylibraft/neighbors/brute_force.pyx b/python/pylibraft/pylibraft/neighbors/brute_force.pyx new file mode 100644 index 0000000000..dbd888756d --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/brute_force.pyx @@ -0,0 +1,179 @@ +# +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +from cython.operator cimport dereference as deref +from libcpp cimport bool, nullptr +from libcpp.vector cimport vector + +from pylibraft.distance.distance_type cimport DistanceType + +from pylibraft.common import ( + DeviceResources, + auto_convert_output, + cai_wrapper, + device_ndarray, +) + +from libc.stdint cimport int64_t, uintptr_t + +from pylibraft.common.cpp.optional cimport optional +from pylibraft.common.handle cimport device_resources +from pylibraft.common.mdspan cimport get_dmv_float, get_dmv_int64 + +from pylibraft.common.handle import auto_sync_handle +from pylibraft.common.input_validation import is_c_contiguous +from pylibraft.common.interruptible import cuda_interruptible + +from pylibraft.distance.distance_type cimport DistanceType + +# TODO: Centralize this + +from pylibraft.distance.pairwise_distance import DISTANCE_TYPES + +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + host_matrix_view, + make_device_matrix_view, + make_host_matrix_view, + row_major, +) +from pylibraft.neighbors.cpp.brute_force cimport knn as c_knn + + +def _get_array_params(array_interface, check_dtype=None): + dtype = np.dtype(array_interface["typestr"]) + if check_dtype is None and dtype != check_dtype: + raise TypeError("dtype %s not supported" % dtype) + shape = array_interface["shape"] + if len(shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(shape)) + data = array_interface["data"][0] + return (shape, dtype, data) + + +@auto_sync_handle +@auto_convert_output +def knn(dataset, queries, k=None, indices=None, distances=None, + metric="sqeuclidean", metric_arg=2.0, + global_id_offset=0, handle=None): + """ + Perform a brute-force nearest neighbors search. + + Parameters + ---------- + dataset : array interface compliant matrix, row-major layout, + shape (n_samples, dim). Supported dtype [float] + queries : array interface compliant matrix, row-major layout, + shape (n_queries, dim) Supported dtype [float] + k : int + Number of neighbors to search (k <= 2048). Optional if indices or + distances arrays are given (in which case their second dimension + is k). + indices : Optional array interface compliant matrix shape + (n_queries, k), dtype int64_t. If supplied, neighbor + indices will be written here in-place. (default None) + Supported dtype uint64 + distances : Optional array interface compliant matrix shape + (n_queries, k), dtype float. If supplied, neighbor + indices will be written here in-place. (default None) + + {handle_docstring} + + Returns + ------- + indices: array interface compliant object containing resulting indices + shape (n_queries, k) + + distances: array interface compliant object containing resulting distances + shape (n_queries, k) + + Examples + -------- + + >>> import cupy as cp + + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors.brute_force import knn + + >>> n_samples = 50000 + >>> n_features = 50 + >>> n_queries = 1000 + + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + >>> # Search using the built index + >>> queries = cp.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> k = 40 + >>> distances, neighbors = knn(dataset, queries, k) + >>> distances = cp.asarray(distances) + >>> neighbors = cp.asarray(neighbors) + """ + + if handle is None: + handle = DeviceResources() + + dataset_cai = cai_wrapper(dataset) + queries_cai = cai_wrapper(queries) + + if k is None: + if indices is not None: + k = cai_wrapper(indices).shape[1] + elif distances is not None: + k = cai_wrapper(distances).shape[1] + else: + raise ValueError("Argument k must be specified if both indices " + "and distances arg is None") + + n_queries = cai_wrapper(queries).shape[0] + + if indices is None: + indices = device_ndarray.empty((n_queries, k), dtype='int64') + + if distances is None: + distances = device_ndarray.empty((n_queries, k), dtype='float32') + + cdef DistanceType c_metric = DISTANCE_TYPES[metric] + + distances_cai = cai_wrapper(distances) + indices_cai = cai_wrapper(indices) + + cdef optional[float] c_metric_arg = metric_arg + cdef optional[int64_t] c_global_offset = global_id_offset + + cdef device_resources* handle_ = \ + handle.getHandle() + + if dataset_cai.dtype == np.float32: + with cuda_interruptible(): + c_knn(deref(handle_), + get_dmv_float(dataset_cai, check_shape=True), + get_dmv_float(queries_cai, check_shape=True), + get_dmv_int64(indices_cai, check_shape=True), + get_dmv_float(distances_cai, check_shape=True), + c_metric, + c_metric_arg, + c_global_offset) + else: + raise TypeError("dtype %s not supported" % dataset_cai.dtype) + + return (distances, indices) diff --git a/python/pylibraft/pylibraft/neighbors/common.pyx b/python/pylibraft/pylibraft/neighbors/common.pyx index a8380b589b..24c1abcf18 100644 --- a/python/pylibraft/pylibraft/neighbors/common.pyx +++ b/python/pylibraft/pylibraft/neighbors/common.pyx @@ -22,13 +22,15 @@ import warnings from pylibraft.distance.distance_type cimport DistanceType +SUPPORTED_DISTANCES = { + "sqeuclidean": DistanceType.L2Expanded, + "euclidean": DistanceType.L2SqrtExpanded, + "inner_product": DistanceType.InnerProduct, + +} + def _get_metric(metric): - SUPPORTED_DISTANCES = { - "sqeuclidean": DistanceType.L2Expanded, - "euclidean": DistanceType.L2SqrtExpanded, - "inner_product": DistanceType.InnerProduct - } if metric not in SUPPORTED_DISTANCES: if metric == "l2_expanded": warnings.warn("Using l2_expanded as a metric name is deprecated," diff --git a/python/pylibraft/pylibraft/neighbors/cpp/__init__.pxd b/python/pylibraft/pylibraft/neighbors/cpp/__init__.pxd new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/pylibraft/pylibraft/neighbors/cpp/__init__.py b/python/pylibraft/pylibraft/neighbors/cpp/__init__.py new file mode 100644 index 0000000000..a7e7b75096 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cpp/__init__.py @@ -0,0 +1,14 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# diff --git a/python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd b/python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd new file mode 100644 index 0000000000..de5e0af267 --- /dev/null +++ b/python/pylibraft/pylibraft/neighbors/cpp/brute_force.pxd @@ -0,0 +1,55 @@ +# +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# cython: profile=False +# distutils: language = c++ +# cython: embedsignature = True +# cython: language_level = 3 + +import numpy as np + +import pylibraft.common.handle + +from cython.operator cimport dereference as deref +from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t, uintptr_t +from libcpp cimport bool, nullptr +from libcpp.string cimport string +from libcpp.vector cimport vector + +from rmm._lib.memory_resource cimport device_memory_resource + +from pylibraft.common.cpp.mdspan cimport ( + device_matrix_view, + host_matrix_view, + make_device_matrix_view, + make_host_matrix_view, + row_major, +) +from pylibraft.common.cpp.optional cimport optional +from pylibraft.common.handle cimport device_resources +from pylibraft.distance.distance_type cimport DistanceType + + +cdef extern from "raft_runtime/neighbors/brute_force.hpp" \ + namespace "raft::runtime::neighbors::brute_force" nogil: + + cdef void knn(const device_resources & handle, + device_matrix_view[float, int64_t, row_major] index, + device_matrix_view[float, int64_t, row_major] search, + device_matrix_view[int64_t, int64_t, row_major] indices, + device_matrix_view[float, int64_t, row_major] distances, + DistanceType metric, + optional[float] metric_arg, + optional[int64_t] global_id_offset) except + diff --git a/python/pylibraft/pylibraft/test/test_brue_force.py b/python/pylibraft/pylibraft/test/test_brue_force.py new file mode 100644 index 0000000000..f349be892d --- /dev/null +++ b/python/pylibraft/pylibraft/test/test_brue_force.py @@ -0,0 +1,99 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest +from scipy.spatial.distance import cdist + +from pylibraft.common import DeviceResources, Stream, device_ndarray +from pylibraft.neighbors.brute_force import knn + + +@pytest.mark.parametrize("n_index_rows", [32, 100]) +@pytest.mark.parametrize("n_query_rows", [32, 100]) +@pytest.mark.parametrize("n_cols", [40, 100]) +@pytest.mark.parametrize("k", [1, 5, 32]) +@pytest.mark.parametrize( + "metric", + [ + "euclidean", + "cityblock", + "chebyshev", + "canberra", + "correlation", + "russellrao", + "cosine", + "sqeuclidean", + # "inner_product", + ], +) +@pytest.mark.parametrize("inplace", [True, False]) +@pytest.mark.parametrize("order", ["F", "C"]) +@pytest.mark.parametrize("dtype", [np.float32]) +def test_knn( + n_index_rows, n_query_rows, n_cols, k, inplace, metric, order, dtype +): + index = np.random.random_sample((n_index_rows, n_cols)).astype(dtype) + queries = np.random.random_sample((n_query_rows, n_cols)).astype(dtype) + + # RussellRao expects boolean arrays + if metric == "russellrao": + index[index < 0.5] = 0.0 + index[index >= 0.5] = 1.0 + queries[queries < 0.5] = 0.0 + queries[queries >= 0.5] = 1.0 + + indices = np.zeros((n_query_rows, k), dtype="int64") + distances = np.zeros((n_query_rows, k), dtype=dtype) + + index_device = device_ndarray(index) + + queries_device = device_ndarray(queries) + indices_device = device_ndarray(indices) + distances_device = device_ndarray(distances) + + s2 = Stream() + handle = DeviceResources(stream=s2) + ret_distances, ret_indices = knn( + index_device, + queries_device, + k, + indices=indices_device, + distances=distances_device, + metric=metric, + handle=handle, + ) + handle.sync() + + pw_dists = cdist(queries, index, metric=metric) + + distances_device = ret_distances if not inplace else distances_device + + actual_distances = distances_device.copy_to_host() + + actual_distances[actual_distances <= 1e-5] = 0.0 + argsort = np.argsort(pw_dists, axis=1) + + for i in range(pw_dists.shape[0]): + expected_indices = argsort[i] + gpu_dists = actual_distances[i] + + if metric == "correlation" or metric == "cosine": + gpu_dists = gpu_dists[::-1] + + cpu_ordered = pw_dists[i, expected_indices] + np.testing.assert_allclose( + cpu_ordered[:k], gpu_dists, atol=1e-4, rtol=1e-4 + ) diff --git a/python/pylibraft/pylibraft/test/test_doctests.py b/python/pylibraft/pylibraft/test/test_doctests.py index 3276ca115f..34be6c55f5 100644 --- a/python/pylibraft/pylibraft/test/test_doctests.py +++ b/python/pylibraft/pylibraft/test/test_doctests.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -96,6 +96,7 @@ def _find_doctests_in_obj(obj, finder=None, criteria=None): DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.distance)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.ivf_pq)) +DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.neighbors.brute_force)) DOC_STRINGS.extend(_find_doctests_in_obj(pylibraft.random)) diff --git a/python/pylibraft/pyproject.toml b/python/pylibraft/pyproject.toml index 7d92fd0763..fed15bbab0 100644 --- a/python/pylibraft/pyproject.toml +++ b/python/pylibraft/pyproject.toml @@ -15,15 +15,15 @@ [build-system] requires = [ - "wheel", - "setuptools", - "cython>=0.29,<0.30", - "cuda-python>=11.7.1,<12.0", - "scikit-build>=0.13.1", "cmake>=3.23.1,!=3.25.0", + "cuda-python >=11.7.1,<12.0", + "cython>=0.29,<0.30", "ninja", "rmm==23.4.*", -] + "scikit-build>=0.13.1", + "setuptools", + "wheel", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. build-backend = "setuptools.build_meta" [project] @@ -37,10 +37,10 @@ authors = [ license = { text = "Apache 2.0" } requires-python = ">=3.8" dependencies = [ - "numpy", - "cuda-python>=11.7.1,<12.0", + "cuda-python >=11.7.1,<12.0", + "numpy>=1.21", "rmm==23.4.*", -] +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", @@ -50,10 +50,12 @@ classifiers = [ [project.optional-dependencies] test = [ + "cupy", "pytest", - "scipy", + "pytest-cov", "scikit-learn", -] + "scipy", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. [project.urls] Homepage = "https://github.com/rapidsai/raft" diff --git a/python/pylibraft/setup.cfg b/python/pylibraft/setup.cfg new file mode 100644 index 0000000000..7d1a0c9065 --- /dev/null +++ b/python/pylibraft/setup.cfg @@ -0,0 +1,38 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. + +[isort] +line_length=79 +multi_line_output=3 +include_trailing_comma=True +force_grid_wrap=0 +combine_as_imports=True +order_by_type=True +known_dask= + dask + distributed + dask_cuda +known_rapids= + nvtext + cudf + cuml + cugraph + dask_cudf + rmm +known_first_party= + raft + pylibraft +default_section=THIRDPARTY +sections=FUTURE,STDLIB,THIRDPARTY,DASK,RAPIDS,FIRSTPARTY,LOCALFOLDER +skip= + thirdparty + .eggs + .git + .hg + .mypy_cache + .tox + .venv + _build + buck-out + build + dist + __init__.py diff --git a/python/raft-dask/pyproject.toml b/python/raft-dask/pyproject.toml index 2fe6522f57..fe490ea117 100644 --- a/python/raft-dask/pyproject.toml +++ b/python/raft-dask/pyproject.toml @@ -15,13 +15,13 @@ [build-system] requires = [ - "wheel", - "setuptools", - "cython>=0.29,<0.30", - "scikit-build>=0.13.1", "cmake>=3.23.1,!=3.25.0", + "cython>=0.29,<0.30", "ninja", -] + "scikit-build>=0.13.1", + "setuptools", + "wheel", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. [project] name = "raft-dask" @@ -34,15 +34,15 @@ authors = [ license = { text = "Apache 2.0" } requires-python = ">=3.8" dependencies = [ - "numpy", - "numba>=0.49", - "joblib>=0.11", "dask-cuda==23.4.*", "dask>=2023.1.1", - "ucx-py==0.31.*", "distributed>=2023.1.1", + "joblib>=0.11", + "numba>=0.49", + "numpy>=1.21", "pylibraft==23.4.*", -] + "ucx-py==0.31.*", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. classifiers = [ "Intended Audience :: Developers", "Programming Language :: Python", @@ -53,8 +53,8 @@ classifiers = [ [project.optional-dependencies] test = [ "pytest", - "dask[distributed,dataframe]", -] + "pytest-cov", +] # This list was generated by `rapids-dependency-file-generator`. To make changes, edit ../../dependencies.yaml and run `rapids-dependency-file-generator`. [project.urls] Homepage = "https://github.com/rapidsai/raft" diff --git a/python/raft-dask/raft_dask/common/comms.py b/python/raft-dask/raft_dask/common/comms.py index 56e40b98da..ebe9a8dc4f 100644 --- a/python/raft-dask/raft_dask/common/comms.py +++ b/python/raft-dask/raft_dask/common/comms.py @@ -19,7 +19,7 @@ import warnings from collections import OrderedDict -from dask.distributed import default_client, get_worker +from dask.distributed import default_client from pylibraft.common.handle import Handle @@ -242,7 +242,7 @@ def destroy(self): self.ucx_initialized = False -def local_handle(sessionId): +def local_handle(sessionId, dask_worker=None): """ Simple helper function for retrieving the local handle_t instance for a comms session on a worker. @@ -251,16 +251,19 @@ def local_handle(sessionId): ---------- sessionId : str session identifier from an initialized comms instance + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) Returns ------- handle : raft.Handle or None """ - state = get_raft_comm_state(sessionId, get_worker()) + state = get_raft_comm_state(sessionId, dask_worker) return state["handle"] if "handle" in state else None -def get_raft_comm_state(sessionId, state_object=None): +def get_raft_comm_state(sessionId, state_object=None, dask_worker=None): """ Retrieves cuML comms state on the scheduler node, for the given sessionId, creating a new session if it does not exist. If no session id is given, @@ -271,13 +274,16 @@ def get_raft_comm_state(sessionId, state_object=None): sessionId : SessionId value to retrieve from the dask_scheduler instances state_object : Object (either Worker, or Scheduler) on which the raft comm state will retrieved (or created) + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) Returns ------- session state : str session state associated with sessionId """ - state_object = state_object if state_object is not None else get_worker() + state_object = state_object if state_object is not None else dask_worker if not hasattr(state_object, "_raft_comm_state"): state_object._raft_comm_state = {} @@ -308,13 +314,19 @@ def set_nccl_root(sessionId, state_object): return raft_comm_state["nccl_uid"] -def get_ucx(): +def get_ucx(dask_worker=None): """ A simple convenience wrapper to make sure UCP listener and endpoints are only ever assigned once per worker. + + Parameters + ---------- + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ raft_comm_state = get_raft_comm_state( - sessionId="ucp", state_object=get_worker() + sessionId="ucp", state_object=dask_worker ) if "ucx" not in raft_comm_state: raft_comm_state["ucx"] = UCX.get() @@ -371,7 +383,7 @@ def _func_set_scheduler_as_nccl_root(sessionId, verbose, dask_scheduler): return nccl_uid -def _func_set_worker_as_nccl_root(sessionId, verbose): +def _func_set_worker_as_nccl_root(sessionId, verbose, dask_worker=None): """ Creates a persistent nccl uniqueId on the scheduler node. @@ -380,63 +392,74 @@ def _func_set_worker_as_nccl_root(sessionId, verbose): ---------- sessionId : Associated session to attach the unique ID to. verbose : Indicates whether or not to emit additional information + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) Return ------ uniqueId : byte str NCCL uniqueId, associating this DASK worker as its root node. """ - worker = get_worker() if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg=f"Setting worker as NCCL root for session, '{sessionId}'", ) - nccl_uid = set_nccl_root(sessionId=sessionId, state_object=worker) + nccl_uid = set_nccl_root(sessionId=sessionId, state_object=dask_worker) if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg="Done setting scheduler as NCCL root." ) return nccl_uid -def _func_ucp_listener_port(): - return get_ucx().listener_port() +def _func_ucp_listener_port(dask_worker=None): + return get_ucx(dask_worker=dask_worker).listener_port() async def _func_init_all( - sessionId, uniqueId, comms_p2p, worker_info, verbose, streams_per_handle + sessionId, + uniqueId, + comms_p2p, + worker_info, + verbose, + streams_per_handle, + dask_worker=None, ): - worker = get_worker() raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) raft_comm_state["nccl_uid"] = uniqueId - raft_comm_state["wid"] = worker_info[get_worker().address]["rank"] + raft_comm_state["wid"] = worker_info[dask_worker.address]["rank"] raft_comm_state["nworkers"] = len(worker_info) if verbose: - worker.log_event(topic="info", msg="Initializing NCCL.") + dask_worker.log_event(topic="info", msg="Initializing NCCL.") start = time.time() - _func_init_nccl(sessionId, uniqueId) + _func_init_nccl(sessionId, uniqueId, dask_worker=dask_worker) if verbose: elapsed = time.time() - start - worker.log_event( + dask_worker.log_event( topic="info", msg=f"NCCL Initialization took: {elapsed} seconds." ) if comms_p2p: if verbose: - worker.log_event(topic="info", msg="Initializing UCX Endpoints") + dask_worker.log_event( + topic="info", msg="Initializing UCX Endpoints" + ) if verbose: start = time.time() - await _func_ucp_create_endpoints(sessionId, worker_info) + await _func_ucp_create_endpoints( + sessionId, worker_info, dask_worker=dask_worker + ) if verbose: elapsed = time.time() - start @@ -444,18 +467,22 @@ async def _func_init_all( f"Done initializing UCX endpoints." f"Took: {elapsed} seconds.\nBuilding handle." ) - worker.log_event(topic="info", msg=msg) + dask_worker.log_event(topic="info", msg=msg) - _func_build_handle_p2p(sessionId, streams_per_handle, verbose) + _func_build_handle_p2p( + sessionId, streams_per_handle, verbose, dask_worker=dask_worker + ) if verbose: - worker.log_event(topic="info", msg="Done building handle.") + dask_worker.log_event(topic="info", msg="Done building handle.") else: - _func_build_handle(sessionId, streams_per_handle, verbose) + _func_build_handle( + sessionId, streams_per_handle, verbose, dask_worker=dask_worker + ) -def _func_init_nccl(sessionId, uniqueId): +def _func_init_nccl(sessionId, uniqueId, dask_worker=None): """ Initialize ncclComm_t on worker @@ -466,11 +493,13 @@ def _func_init_nccl(sessionId, uniqueId): uniqueId : array[byte] The NCCL unique Id generated from the client. + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ - worker = get_worker() raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=get_worker() + sessionId=sessionId, state_object=dask_worker, dask_worker=dask_worker ) wid = raft_comm_state["wid"] nWorkers = raft_comm_state["nworkers"] @@ -480,13 +509,15 @@ def _func_init_nccl(sessionId, uniqueId): n.init(nWorkers, uniqueId, wid) raft_comm_state["nccl"] = n except Exception as e: - worker.log_event( + dask_worker.log_event( topic="error", msg=f"An error occurred initializing NCCL: {e}." ) raise -def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): +def _func_build_handle_p2p( + sessionId, streams_per_handle, verbose, dask_worker=None +): """ Builds a handle_t on the current worker given the initialized comms @@ -495,14 +526,16 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): sessionId : str id to reference state for current comms instance. streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ - worker = get_worker() if verbose: - worker.log_event(topic="info", msg="Building p2p handle.") + dask_worker.log_event(topic="info", msg="Building p2p handle.") - ucp_worker = get_ucx().get_worker() + ucp_worker = get_ucx(dask_worker).get_worker() raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) handle = Handle(n_streams=streams_per_handle) @@ -512,21 +545,23 @@ def _func_build_handle_p2p(sessionId, streams_per_handle, verbose): workerId = raft_comm_state["wid"] if verbose: - worker.log_event(topic="info", msg="Injecting comms on handle.") + dask_worker.log_event(topic="info", msg="Injecting comms on handle.") inject_comms_on_handle( handle, nccl_comm, ucp_worker, eps, nWorkers, workerId, verbose ) if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg="Finished injecting comms on handle." ) raft_comm_state["handle"] = handle -def _func_build_handle(sessionId, streams_per_handle, verbose): +def _func_build_handle( + sessionId, streams_per_handle, verbose, dask_worker=None +): """ Builds a handle_t on the current worker given the initialized comms @@ -535,17 +570,19 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): sessionId : str id to reference state for current comms instance. streams_per_handle : int number of internal streams to create verbose : bool print verbose logging output + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ - worker = get_worker() if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg="Finished injecting comms on handle." ) handle = Handle(n_streams=streams_per_handle) raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) workerId = raft_comm_state["wid"] @@ -558,16 +595,18 @@ def _func_build_handle(sessionId, streams_per_handle, verbose): raft_comm_state["handle"] = handle -def _func_store_initial_state(nworkers, sessionId, uniqueId, wid): +def _func_store_initial_state( + nworkers, sessionId, uniqueId, wid, dask_worker=None +): raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=get_worker() + sessionId=sessionId, state_object=dask_worker ) raft_comm_state["nccl_uid"] = uniqueId raft_comm_state["wid"] = wid raft_comm_state["nworkers"] = nworkers -async def _func_ucp_create_endpoints(sessionId, worker_info): +async def _func_ucp_create_endpoints(sessionId, worker_info, dask_worker): """ Runs on each worker to create ucp endpoints to all other workers @@ -577,6 +616,9 @@ async def _func_ucp_create_endpoints(sessionId, worker_info): uuid unique id for this instance worker_info : dict Maps worker addresses to NCCL ranks & UCX ports + dask_worker : dask_worker object + (Note: if called by client.run(), this is supplied by Dask + and not the client) """ eps = [None] * len(worker_info) count = 1 @@ -584,40 +626,47 @@ async def _func_ucp_create_endpoints(sessionId, worker_info): for k in worker_info: ip, port = parse_host_port(k) - ep = await get_ucx().get_endpoint(ip, worker_info[k]["port"]) + ep = await get_ucx(dask_worker=dask_worker).get_endpoint( + ip, worker_info[k]["port"] + ) eps[worker_info[k]["rank"]] = ep count += 1 raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=get_worker() + sessionId=sessionId, state_object=dask_worker ) raft_comm_state["ucp_eps"] = eps -async def _func_destroy_all(sessionId, comms_p2p, verbose=False): - worker = get_worker() +async def _func_destroy_all( + sessionId, comms_p2p, verbose=False, dask_worker=None +): if verbose: - worker.log_event(topic="info", msg="Destroying NCCL session state.") + dask_worker.log_event( + topic="info", msg="Destroying NCCL session state." + ) raft_comm_state = get_raft_comm_state( - sessionId=sessionId, state_object=worker + sessionId=sessionId, state_object=dask_worker ) if "nccl" in raft_comm_state: raft_comm_state["nccl"].destroy() del raft_comm_state["nccl"] if verbose: - worker.log_event(topic="info", msg="NCCL session state destroyed.") + dask_worker.log_event( + topic="info", msg="NCCL session state destroyed." + ) else: if verbose: - worker.log_event( + dask_worker.log_event( topic="warning", msg=f"Session state for, '{sessionId}', " f"does not contain expected 'nccl' element", ) if verbose: - worker.log_event( + dask_worker.log_event( topic="info", msg=f"Destroying CUDA handle for sessionId, '{sessionId}.'", ) @@ -626,7 +675,7 @@ async def _func_destroy_all(sessionId, comms_p2p, verbose=False): del raft_comm_state["handle"] else: if verbose: - worker.log_event( + dask_worker.log_event( topic="warning", msg=f"Session state for, '{sessionId}', " f"does not contain expected 'handle' element", diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 74ec446e94..3a430f9270 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -1,4 +1,4 @@ -# Copyright (c) 2019-2022, NVIDIA CORPORATION. +# Copyright (c) 2019-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ import pytest -from dask.distributed import Client, wait +from dask.distributed import Client, get_worker, wait try: from raft_dask.common import ( @@ -60,32 +60,32 @@ def test_comms_init_no_p2p(cluster): def func_test_collective(func, sessionId, root): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return func(handle, root) def func_test_send_recv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_send_recv(handle, n_trials) def func_test_device_send_or_recv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_device_send_or_recv(handle, n_trials) def func_test_device_sendrecv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_device_sendrecv(handle, n_trials) def func_test_device_multicast_sendrecv(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comms_device_multicast_sendrecv(handle, n_trials) def func_test_comm_split(sessionId, n_trials): - handle = local_handle(sessionId) + handle = local_handle(sessionId, dask_worker=get_worker()) return perform_test_comm_split(handle, n_trials) @@ -114,11 +114,9 @@ def func_check_uid_on_scheduler(sessionId, uniqueId, dask_scheduler): ) -def func_check_uid_on_worker(sessionId, uniqueId): - from dask.distributed import get_worker - +def func_check_uid_on_worker(sessionId, uniqueId, dask_worker=None): return func_check_uid( - sessionId=sessionId, uniqueId=uniqueId, state_object=get_worker() + sessionId=sessionId, uniqueId=uniqueId, state_object=dask_worker ) @@ -127,7 +125,7 @@ def test_handles(cluster): client = Client(cluster) def _has_handle(sessionId): - return local_handle(sessionId) is not None + return local_handle(sessionId, dask_worker=get_worker()) is not None try: cb = Comms(verbose=True) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000..e64641d05b --- /dev/null +++ b/setup.cfg @@ -0,0 +1,55 @@ +# Copyright (c) 2022-2023, NVIDIA CORPORATION. + +[flake8] +filename = *.py, *.pyx, *.pxd, *.pxi +exclude = __init__.py, *.egg, build, docs, .git +force-check = True +ignore = + # line break before binary operator + W503, + # whitespace before : + E203 +per-file-ignores = + # Rules ignored only in Cython: + # E211: whitespace before '(' (used in multi-line imports) + # E225: Missing whitespace around operators (breaks cython casting syntax like ) + # E226: Missing whitespace around arithmetic operators (breaks cython pointer syntax like int*) + # E227: Missing whitespace around bitwise or shift operator (Can also break casting syntax) + # E275: Missing whitespace after keyword (Doesn't work with Cython except?) + # E402: invalid syntax (works for Python, not Cython) + # E999: invalid syntax (works for Python, not Cython) + # W504: line break after binary operator (breaks lines that end with a pointer) + *.pyx: E211, E225, E226, E227, E275, E402, E999, W504 + *.pxd: E211, E225, E226, E227, E275, E402, E999, W504 + *.pxi: E211, E225, E226, E227, E275, E402, E999, W504 + +[pydocstyle] +# Due to https://github.com/PyCQA/pydocstyle/issues/363, we must exclude rather +# than include using match-dir. Note that as discussed in +# https://stackoverflow.com/questions/65478393/how-to-filter-directories-using-the-match-dir-flag-for-pydocstyle, +# unlike the match option above this match-dir will have no effect when +# pydocstyle is invoked from pre-commit. Therefore this exclusion list must +# also be maintained in the pre-commit config file. +match-dir = ^(?!(ci|cpp|conda|docs|java|notebooks)).*$ +# Allow missing docstrings for docutils +ignore-decorators = .*(docutils|doc_apply|copy_docstring).* +select = + D201, D204, D206, D207, D208, D209, D210, D211, D214, D215, D300, D301, D302, D403, D405, D406, D407, D408, D409, D410, D411, D412, D414, D418 + # Would like to enable the following rules in the future: + # D200, D202, D205, D400 + +[mypy] +ignore_missing_imports = True +# If we don't specify this, then mypy will check excluded files if +# they are imported by a checked file. +follow_imports = skip + +[codespell] +# note: pre-commit passes explicit lists of files here, which this skip file list doesn't override - +# this is only to allow you to run codespell interactively +skip = ./.git,./.github,./cpp/build,.*egg-info.*,./.mypy_cache,.*_skbuild +# ignore short words, and typename parameters like OffsetT +ignore-regex = \b(.{1,4}|[A-Z]\w*T)\b +ignore-words-list = inout,unparseable,numer +builtin = clear +quiet-level = 3