diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 620a13fe17..8f745848e0 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -28,7 +28,7 @@ concurrency: jobs: cpp-build: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-matrix-build.yaml@main + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@cuda-118 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -37,7 +37,7 @@ jobs: python-build: needs: [cpp-build] secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-matrix-build.yaml@main + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@cuda-118 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -46,9 +46,52 @@ jobs: upload-conda: needs: [cpp-build, python-build] secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-upload-packages.yaml@main + uses: rapidsai/shared-action-workflows/.github/workflows/conda-upload-packages.yaml@cuda-118 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} date: ${{ inputs.date }} sha: ${{ inputs.sha }} + wheel-build-pylibraft: + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@cuda-118 + with: + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + sha: ${{ inputs.sha }} + date: ${{ inputs.date }} + package-name: pylibraft + package-dir: python/pylibraft + skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + wheel-publish-pylibraft: + needs: wheel-build-pylibraft + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@cuda-118 + with: + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + sha: ${{ inputs.sha }} + date: ${{ inputs.date }} + package-name: pylibraft + wheel-build-raft-dask: + needs: wheel-publish-pylibraft + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@cuda-118 + with: + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + sha: ${{ inputs.sha }} + date: ${{ inputs.date }} + package-name: raft_dask + package-dir: python/raft-dask + skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + wheel-publish-raft-dask: + needs: wheel-build-raft-dask + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@cuda-118 + with: + build_type: ${{ inputs.build_type || 'branch' }} + branch: ${{ inputs.branch }} + sha: ${{ inputs.sha }} + date: ${{ inputs.date }} + package-name: raft_dask diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index ca2e2356c0..b705557795 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -17,33 +17,80 @@ jobs: - conda-cpp-tests - conda-python-build - conda-python-tests + - wheel-build-pylibraft + - wheel-tests-pylibraft + - wheel-build-raft-dask + - wheel-tests-raft-dask secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/pr-builder.yaml@main + uses: rapidsai/shared-action-workflows/.github/workflows/pr-builder.yaml@cuda-118 checks: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/checks.yaml@main + uses: 
rapidsai/shared-action-workflows/.github/workflows/checks.yaml@cuda-118 conda-cpp-build: needs: checks secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-matrix-build.yaml@main + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-build.yaml@cuda-118 with: build_type: pull-request node_type: cpu16 conda-cpp-tests: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@main + uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@cuda-118 with: build_type: pull-request conda-python-build: needs: conda-cpp-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-matrix-build.yaml@main + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-build.yaml@cuda-118 with: build_type: pull-request conda-python-tests: needs: conda-python-build secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@main + uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@cuda-118 with: build_type: pull-request + wheel-build-pylibraft: + needs: checks + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@cuda-118 + with: + build_type: pull-request + package-name: pylibraft + package-dir: python/pylibraft + skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + wheel-tests-pylibraft: + needs: wheel-build-pylibraft + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@cuda-118 + with: + build_type: pull-request + package-name: pylibraft + test-before-amd64: "pip install cupy-cuda11x" + # On arm also need to install cupy from the specific webpage. + test-before-arm64: "pip install cupy-cuda11x -f https://pip.cupy.dev/aarch64" + test-unittest: "python -m pytest -v ./python/pylibraft/pylibraft/test" + test-smoketest: "python ./ci/wheel_smoke_test_pylibraft.py" + wheel-build-raft-dask: + needs: wheel-tests-pylibraft + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@cuda-118 + with: + build_type: pull-request + package-name: raft_dask + package-dir: python/raft-dask + before-wheel: "RAPIDS_PY_WHEEL_NAME=pylibraft_cu11 rapids-download-wheels-from-s3 ./local-wheelhouse" + skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + wheel-tests-raft-dask: + needs: wheel-build-raft-dask + secrets: inherit + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@cuda-118 + with: + build_type: pull-request + package-name: raft_dask + # Always want to test against latest dask/distributed. 
+      test-before-amd64: "RAPIDS_PY_WHEEL_NAME=pylibraft_cu11 rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.02"
+      test-before-arm64: "RAPIDS_PY_WHEEL_NAME=pylibraft_cu11 rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.02"
+      test-unittest: "python -m pytest -v ./python/raft-dask/raft_dask/test"
+      test-smoketest: "python ./ci/wheel_smoke_test_raft_dask.py"
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index bd201e987f..d41a660c6d 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -16,7 +16,7 @@ on:
 jobs:
   conda-cpp-tests:
     secrets: inherit
-    uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@main
+    uses: rapidsai/shared-action-workflows/.github/workflows/conda-cpp-tests.yaml@cuda-118
     with:
       build_type: nightly
       branch: ${{ inputs.branch }}
@@ -24,9 +24,33 @@
       sha: ${{ inputs.sha }}
   conda-python-tests:
     secrets: inherit
-    uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@main
+    uses: rapidsai/shared-action-workflows/.github/workflows/conda-python-tests.yaml@cuda-118
     with:
       build_type: nightly
       branch: ${{ inputs.branch }}
       date: ${{ inputs.date }}
       sha: ${{ inputs.sha }}
+  wheel-tests-pylibraft:
+    secrets: inherit
+    uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@cuda-118
+    with:
+      build_type: nightly
+      branch: ${{ inputs.branch }}
+      date: ${{ inputs.date }}
+      sha: ${{ inputs.sha }}
+      package-name: pylibraft
+      test-before-amd64: "pip install cupy-cuda11x"
+      test-before-arm64: "pip install cupy-cuda11x -f https://pip.cupy.dev/aarch64"
+      test-unittest: "python -m pytest -v ./python/pylibraft/pylibraft/test"
+  wheel-tests-raft-dask:
+    secrets: inherit
+    uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@cuda-118
+    with:
+      build_type: nightly
+      branch: ${{ inputs.branch }}
+      date: ${{ inputs.date }}
+      sha: ${{ inputs.sha }}
+      package-name: raft_dask
+      test-before-amd64: "pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.02"
+      test-before-arm64: "pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.02"
+      test-unittest: "python -m pytest -v ./python/raft-dask/raft_dask/test"
diff --git a/.github/workflows/wheels.yml b/.github/workflows/wheels.yml
deleted file mode 100644
index 0a681b864b..0000000000
--- a/.github/workflows/wheels.yml
+++ /dev/null
@@ -1,72 +0,0 @@
-name: RAFT wheels
-
-on:
-  workflow_call:
-    inputs:
-      versioneer-override:
-        type: string
-        default: ''
-      build-tag:
-        type: string
-        default: ''
-      branch:
-        required: true
-        type: string
-      date:
-        required: true
-        type: string
-      sha:
-        required: true
-        type: string
-      build-type:
-        type: string
-        default: nightly
-
-concurrency:
-  group: "raft-${{ github.workflow }}-${{ github.ref }}"
-  cancel-in-progress: true
-
-jobs:
-  pylibraft-wheel:
-    uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux.yml@main
-    with:
-      repo: rapidsai/raft
-
-      build-type: ${{ inputs.build-type }}
-      branch: ${{ inputs.branch }}
-      sha: ${{ inputs.sha }}
-      date: ${{ inputs.date }}
-
-      package-dir: python/pylibraft
-      package-name: pylibraft
-
-      python-package-versioneer-override: ${{ inputs.versioneer-override }}
-      python-package-build-tag: ${{ inputs.build-tag }}
-
-      skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF"
-
-      test-extras: test
-      test-unittest: "python -m pytest -v ./python/pylibraft/pylibraft/test"
-    secrets: inherit
-  raft-dask-wheel:
-    needs: pylibraft-wheel
-    uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux.yml@main
-    with:
-      repo: rapidsai/raft
-
-      build-type: ${{ inputs.build-type }}
-      branch: ${{ inputs.branch }}
-      sha: ${{ inputs.sha }}
-      date: ${{ inputs.date }}
-
-      package-dir: python/raft-dask
-      package-name: raft_dask
-
-      python-package-versioneer-override: ${{ inputs.versioneer-override }}
-      python-package-build-tag: ${{ inputs.build-tag }}
-
-      skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF"
-
-      test-extras: test
-      test-unittest: "python -m pytest -v ./python/raft-dask/raft_dask/test"
-    secrets: inherit
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index f48dff11cd..b766bfc066 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -2,7 +2,7 @@
 repos:
   - repo: https://github.com/PyCQA/isort
-    rev: 5.10.1
+    rev: 5.12.0
     hooks:
       - id: isort
         # Use the config file specific to each subproject so that each
diff --git a/README.md b/README.md
index 8e0da6cd6d..ccd0df4926 100755
--- a/README.md
+++ b/README.md
@@ -25,8 +25,8 @@ While not exhaustive, the following general categories help summarize the accele
 | Category | Examples |
 | --- | --- |
 | **Data Formats** | sparse & dense, conversions, data generation |
-| **Dense Operations** | linear algebra, matrix and vector operations, slicing, norms, factorization, least squares, svd & eigenvalue problems |
-| **Sparse Operations** | linear algebra, eigenvalue problems, slicing, symmetrization, components & labeling |
+| **Dense Operations** | linear algebra, matrix and vector operations, reductions, slicing, norms, factorization, least squares, svd & eigenvalue problems |
+| **Sparse Operations** | linear algebra, eigenvalue problems, slicing, norms, reductions, factorization, symmetrization, components & labeling |
 | **Spatial** | pairwise distances, nearest neighbors, neighborhood graph construction |
 | **Basic Clustering** | spectral clustering, hierarchical clustering, k-means |
 | **Solvers** | combinatorial optimization, iterative solvers |
@@ -65,17 +65,17 @@ auto matrix = raft::make_device_matrix<float>(handle, n_rows, n_cols);
 
 ### C++ Example
 
-Most of the primitives in RAFT accept a `raft::handle_t` object for the management of resources which are expensive to create, such CUDA streams, stream pools, and handles to other CUDA libraries like `cublas` and `cusolver`.
+Most of the primitives in RAFT accept a `raft::device_resources` object for the management of resources which are expensive to create, such as CUDA streams, stream pools, and handles to other CUDA libraries like `cublas` and `cusolver`.
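Before the full example, a minimal sketch of the handle lifecycle on its own may help. It assumes the `raft/core/device_resources.hpp` header that this change migrates to, and uses only accessors that appear elsewhere in this diff (`get_stream()`, `sync_stream()`):

```c++
#include <raft/core/device_resources.hpp>

#include <rmm/cuda_stream_view.hpp>

int main()
{
  // Construct once and reuse: the handle lazily creates and caches the
  // expensive resources (streams, cublas/cusolver handles) on first use.
  raft::device_resources handle;

  // RAFT primitives that accept `handle` enqueue their work on this stream.
  rmm::cuda_stream_view stream = handle.get_stream();

  // ... invoke RAFT primitives with `handle` here ...

  // Block the host until work submitted on the handle's stream completes.
  handle.sync_stream(stream);
  return 0;
}
```

A single `device_resources` instance is intended to be reused across many primitive calls on the same thread, which is how the benchmark fixtures later in this diff use it.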
 
 The example below demonstrates creating a RAFT handle and using it with `device_matrix` and `device_vector` to allocate memory, generating random clusters, and computing pairwise Euclidean distances:
 
 ```c++
-#include <raft/core/handle.hpp>
+#include <raft/core/device_resources.hpp>
 #include <raft/core/device_mdarray.hpp>
 #include <raft/random/make_blobs.cuh>
 #include <raft/distance/distance.cuh>
 
-raft::handle_t handle;
+raft::device_resources handle;
 
 int n_samples = 5000;
 int n_features = 50;
@@ -93,12 +93,12 @@ raft::distance::pairwise_distance(handle, input.view(), input.view(), output.vie
 It's also possible to create `raft::device_mdspan` views to invoke the same API with raw pointers and shape information:
 
 ```c++
-#include <raft/core/handle.hpp>
+#include <raft/core/device_resources.hpp>
 #include <raft/core/device_mdspan.hpp>
 #include <raft/random/make_blobs.cuh>
 #include <raft/distance/distance.cuh>
 
-raft::handle_t handle;
+raft::device_resources handle;
 
 int n_samples = 5000;
 int n_features = 50;
@@ -277,7 +277,7 @@ Several CMake targets can be made available by adding components in the table be
 The easiest way to build RAFT from source is to use the `build.sh` script at the root of the repository:
 1. Create an environment with the needed dependencies:
 ```
-mamba env create --name raft_dev_env -f conda/environments/all_cuda-115_arch-x86_64.yaml
+mamba env create --name raft_dev_env -f conda/environments/all_cuda-118_arch-x86_64.yaml
 mamba activate raft_dev_env
 ```
 ```
@@ -315,6 +315,7 @@ The folder structure mirrors other RAPIDS repos, with the following folders:
   - `solver`: Sparse solvers for optimization and approximation
   - `stats`: Moments, summary statistics, model performance measures
   - `util`: Various reusable tools and utilities for accelerated algorithm development
+  - `internal`: A private header-only component that hosts the code shared between benchmarks and tests.
   - `scripts`: Helpful scripts for development
   - `src`: Compiled APIs and template specializations for the shared libraries
   - `test`: Googletests source code
diff --git a/build.sh b/build.sh
index 34dcd3a2db..b47e1ed862 100755
--- a/build.sh
+++ b/build.sh
@@ -1,6 +1,6 @@
 #!/bin/bash
-# Copyright (c) 2020-2022, NVIDIA CORPORATION.
+# Copyright (c) 2020-2023, NVIDIA CORPORATION.
 
 # raft build script
 
@@ -153,6 +153,7 @@ function limitTests {
       # Remove the full LIMIT_TEST_TARGETS argument from list of args so that it passes validArgs function
       ARGS=${ARGS//--limit-tests=$LIMIT_TEST_TARGETS/}
       TEST_TARGETS=${LIMIT_TEST_TARGETS}
+      echo "Limiting tests to $TEST_TARGETS"
     fi
   fi
 }
@@ -387,7 +388,7 @@ if (( ${NUMARGS} == 0 )) || hasArg libraft || hasArg docs || hasArg tests || has
     RAFT_CMAKE_CUDA_ARCHITECTURES="NATIVE"
     echo "Building for the architecture of the GPU in the system..."
   else
-    RAFT_CMAKE_CUDA_ARCHITECTURES="ALL"
+    RAFT_CMAKE_CUDA_ARCHITECTURES="RAPIDS"
     echo "Building for *ALL* supported GPU architectures..."
   fi
diff --git a/ci/checks/copyright.py b/ci/checks/copyright.py
index bfef5392f5..43a4a186f8 100644
--- a/ci/checks/copyright.py
+++ b/ci/checks/copyright.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2022, NVIDIA CORPORATION.
+# Copyright (c) 2020-2023, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -37,7 +37,7 @@
     re.compile(r"setup[.]cfg$"),
     re.compile(r"meta[.]yaml$")
 ]
-ExemptFiles = ["cpp/include/raft/spatial/knn/detail/warp_select_faiss.cuh"]
+ExemptFiles = ["cpp/include/raft/spatial/knn/detail/faiss_select/"]
 
 # this will break starting at year 10000, which is probably OK :)
 CheckSimple = re.compile(
diff --git a/ci/cpu/build.sh b/ci/cpu/build.sh
index 2f0e2b94ca..657126fdf0 100755
--- a/ci/cpu/build.sh
+++ b/ci/cpu/build.sh
@@ -43,7 +43,7 @@ export CMAKE_GENERATOR="Ninja"
 export CONDA_BLD_DIR="${WORKSPACE}/.conda-bld"
 
 # ucx-py version
-export UCX_PY_VERSION='0.30.*'
+export UCX_PY_VERSION='0.31.*'
 
 ################################################################################
 # SETUP - Check environment
diff --git a/ci/gpu/build.sh b/ci/gpu/build.sh
index 1808480d37..84026203fa 100644
--- a/ci/gpu/build.sh
+++ b/ci/gpu/build.sh
@@ -38,7 +38,7 @@ export MINOR_VERSION=`echo $GIT_DESCRIBE_TAG | grep -o -E '([0-9]+\.[0-9]+)'`
 unset GIT_DESCRIBE_TAG
 
 # ucx-py version
-export UCX_PY_VERSION='0.30.*'
+export UCX_PY_VERSION='0.31.*'
 
 # Whether to install dask nightly or stable packages.
 export INSTALL_DASK_MAIN=1
diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh
index 0b6410f9c9..00f6905032 100755
--- a/ci/release/update-version.sh
+++ b/ci/release/update-version.sh
@@ -1,5 +1,5 @@
 #!/bin/bash
-# Copyright (c) 2020-2022, NVIDIA CORPORATION.
+# Copyright (c) 2020-2023, NVIDIA CORPORATION.
 ########################
 # RAFT Version Updater #
 ########################
@@ -17,12 +17,14 @@
 CURRENT_MAJOR=$(echo $CURRENT_TAG | awk '{split($0, a, "."); print a[1]}')
 CURRENT_MINOR=$(echo $CURRENT_TAG | awk '{split($0, a, "."); print a[2]}')
 CURRENT_PATCH=$(echo $CURRENT_TAG | awk '{split($0, a, "."); print a[3]}')
 CURRENT_SHORT_TAG=${CURRENT_MAJOR}.${CURRENT_MINOR}
+CURRENT_UCX_PY_VERSION="$(curl -sL https://version.gpuci.io/rapids/${CURRENT_SHORT_TAG}).*"
 
 #Get <major>.<minor>
for next version NEXT_MAJOR=$(echo $NEXT_FULL_TAG | awk '{split($0, a, "."); print a[1]}') NEXT_MINOR=$(echo $NEXT_FULL_TAG | awk '{split($0, a, "."); print a[2]}') NEXT_SHORT_TAG=${NEXT_MAJOR}.${NEXT_MINOR} -NEXT_UCX_PY_VERSION="$(curl -sL https://version.gpuci.io/rapids/${NEXT_SHORT_TAG}).*" +NEXT_UCX_PY_SHORT_TAG="$(curl -sL https://version.gpuci.io/rapids/${NEXT_SHORT_TAG})" +NEXT_UCX_PY_VERSION="${NEXT_UCX_PY_SHORT_TAG}.*" echo "Preparing release $CURRENT_TAG => $NEXT_FULL_TAG" @@ -53,3 +55,17 @@ done sed_runner "s/export UCX_PY_VERSION=.*/export UCX_PY_VERSION='${NEXT_UCX_PY_VERSION}'/g" ci/gpu/build.sh sed_runner "s/export UCX_PY_VERSION=.*/export UCX_PY_VERSION='${NEXT_UCX_PY_VERSION}'/g" ci/cpu/build.sh sed_runner "/^ucx_py_version:$/ {n;s/.*/ - \"${NEXT_UCX_PY_VERSION}\"/}" conda/recipes/raft-dask/conda_build_config.yaml + +# Wheel builds install dask-cuda from source, update its branch +sed_runner "s/dask-cuda.git@branch-[^\"\s]\+/dask-cuda.git@branch-${NEXT_SHORT_TAG}/g" .github/workflows/*.yaml + +# Need to distutils-normalize the original version +NEXT_SHORT_TAG_PEP440=$(python -c "from setuptools.extern import packaging; print(packaging.version.Version('${NEXT_SHORT_TAG}'))") +NEXT_UCX_PY_SHORT_TAG_PEP440=$(python -c "from setuptools.extern import packaging; print(packaging.version.Version('${NEXT_UCX_PY_SHORT_TAG}'))") + +# Wheel builds install intra-RAPIDS dependencies from same release +sed_runner "s/{cuda_suffix}[^\"].*\",/{cuda_suffix}==${NEXT_SHORT_TAG_PEP440}.*\",/g" python/pylibraft/setup.py +sed_runner "s/{cuda_suffix}.*\"\]/{cuda_suffix}==${NEXT_SHORT_TAG_PEP440}.*\"\]/g" python/pylibraft/_custom_build/backend.py +sed_runner "s/dask-cuda==.*\",/dask-cuda==${NEXT_SHORT_TAG_PEP440}.*\",/g" python/raft-dask/setup.py +sed_runner "s/pylibraft{cuda_suffix}.*\",/pylibraft{cuda_suffix}==${NEXT_SHORT_TAG_PEP440}.*\",/g" python/raft-dask/setup.py +sed_runner "s/ucx-py{cuda_suffix}.*\",/ucx-py{cuda_suffix}==${NEXT_UCX_PY_SHORT_TAG_PEP440}.*\",/g" python/raft-dask/setup.py diff --git a/ci/wheel_smoke_test_pylibraft.py b/ci/wheel_smoke_test_pylibraft.py new file mode 100644 index 0000000000..7fee674691 --- /dev/null +++ b/ci/wheel_smoke_test_pylibraft.py @@ -0,0 +1,38 @@ +import numpy as np +from scipy.spatial.distance import cdist + +from pylibraft.common import Handle, Stream, device_ndarray +from pylibraft.distance import pairwise_distance + + +if __name__ == "__main__": + metric = "euclidean" + n_rows = 1337 + n_cols = 1337 + + input1 = np.random.random_sample((n_rows, n_cols)) + input1 = np.asarray(input1, order="C").astype(np.float64) + + output = np.zeros((n_rows, n_rows), dtype=np.float64) + + expected = cdist(input1, input1, metric) + + expected[expected <= 1e-5] = 0.0 + + input1_device = device_ndarray(input1) + output_device = None + + s2 = Stream() + handle = Handle(stream=s2) + ret_output = pairwise_distance( + input1_device, input1_device, output_device, metric, handle=handle + ) + handle.sync() + + output_device = ret_output + + actual = output_device.copy_to_host() + + actual[actual <= 1e-5] = 0.0 + + assert np.allclose(expected, actual, rtol=1e-4) diff --git a/ci/wheel_smoke_test_raft_dask.py b/ci/wheel_smoke_test_raft_dask.py new file mode 100644 index 0000000000..32c13e61ca --- /dev/null +++ b/ci/wheel_smoke_test_raft_dask.py @@ -0,0 +1,92 @@ +from dask.distributed import Client, wait +from dask_cuda import LocalCUDACluster, initialize + +from raft_dask.common import ( + Comms, + local_handle, + perform_test_comm_split, + perform_test_comms_allgather, + 
perform_test_comms_allreduce, + perform_test_comms_bcast, + perform_test_comms_device_multicast_sendrecv, + perform_test_comms_device_send_or_recv, + perform_test_comms_device_sendrecv, + perform_test_comms_gather, + perform_test_comms_gatherv, + perform_test_comms_reduce, + perform_test_comms_reducescatter, + perform_test_comms_send_recv, +) + +import os +os.environ["UCX_LOG_LEVEL"] = "error" + + +def func_test_send_recv(sessionId, n_trials): + handle = local_handle(sessionId) + return perform_test_comms_send_recv(handle, n_trials) + + +def func_test_collective(func, sessionId, root): + handle = local_handle(sessionId) + return func(handle, root) + + +if __name__ == "__main__": + # initial setup + cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0) + client = Client(cluster) + + n_trials = 5 + root_location = "client" + + # p2p test for ucx + cb = Comms(comms_p2p=True, verbose=True) + cb.init() + + dfs = [ + client.submit( + func_test_send_recv, + cb.sessionId, + n_trials, + pure=False, + workers=[w], + ) + for w in cb.worker_addresses + ] + + wait(dfs, timeout=5) + + assert list(map(lambda x: x.result(), dfs)) + + cb.destroy() + + # collectives test for nccl + + cb = Comms( + verbose=True, client=client, nccl_root_location=root_location + ) + cb.init() + + for k, v in cb.worker_info(cb.worker_addresses).items(): + + dfs = [ + client.submit( + func_test_collective, + perform_test_comms_allgather, + cb.sessionId, + v["rank"], + pure=False, + workers=[w], + ) + for w in cb.worker_addresses + ] + wait(dfs, timeout=5) + + assert all([x.result() for x in dfs]) + + cb.destroy() + + # final client and cluster teardown + client.close() + cluster.close() diff --git a/conda/environments/all_cuda-115_arch-x86_64.yaml b/conda/environments/all_cuda-118_arch-x86_64.yaml similarity index 56% rename from conda/environments/all_cuda-115_arch-x86_64.yaml rename to conda/environments/all_cuda-118_arch-x86_64.yaml index 18e0a8187f..f194b152a6 100644 --- a/conda/environments/all_cuda-115_arch-x86_64.yaml +++ b/conda/environments/all_cuda-118_arch-x86_64.yaml @@ -12,37 +12,37 @@ dependencies: - clang-tools=11.1.0 - clang=11.1.0 - cmake>=3.23.1,!=3.25.0 -- cuda-profiler-api>=11.4.240,<=11.8.86 +- cuda-profiler-api=11.8.86 - cuda-python >=11.7.1,<12.0 -- cudatoolkit=11.5 +- cudatoolkit=11.8 - cupy - cxx-compiler - cython>=0.29,<0.30 -- dask-cuda=23.02.* +- dask-cuda=23.04 - dask>=2022.12.0 - distributed>=2022.12.0 - doxygen>=1.8.20 - faiss-proc=*=cuda -- gcc_linux-64=9.* -- libcublas-dev>=11.7.3.1,<=11.7.4.6 -- libcublas>=11.7.3.1,<=11.7.4.6 -- libcurand-dev>=10.2.6.48,<=10.2.7.107 -- libcurand>=10.2.6.48,<=10.2.7.107 -- libcusolver-dev>=11.2.1.48,<=11.3.2.107 -- libcusolver>=11.2.1.48,<=11.3.2.107 -- libcusparse-dev>=11.7.0.31,<=11.7.0.107 -- libcusparse>=11.7.0.31,<=11.7.0.107 -- libfaiss>=1.7.0=cuda* +- gcc_linux-64=9 +- libcublas-dev=11.11.3.6 +- libcublas=11.11.3.6 +- libcurand-dev=10.3.0.86 +- libcurand=10.3.0.86 +- libcusolver-dev=11.4.1.48 +- libcusolver=11.4.1.48 +- libcusparse-dev=11.7.5.86 +- libcusparse=11.7.5.86 +- libfaiss>=1.7.1=cuda* - ninja - pytest - pytest-cov -- rmm=23.02.* +- rmm=23.04 - scikit-build>=0.13.1 - scikit-learn - scipy - sphinx-markdown-tables - sysroot_linux-64==2.17 - ucx-proc=*=gpu -- ucx-py=0.30.* +- ucx-py=0.31.* - ucx>=1.13.0 -name: all_cuda-115_arch-x86_64 +name: all_cuda-118_arch-x86_64 diff --git a/conda/recipes/libraft/conda_build_config.yaml b/conda/recipes/libraft/conda_build_config.yaml index 399dd198eb..1012bddb40 100644 --- 
a/conda/recipes/libraft/conda_build_config.yaml +++ b/conda/recipes/libraft/conda_build_config.yaml @@ -20,42 +20,42 @@ gtest_version: - "=1.10.0" libfaiss_version: - - "1.7.0 *_cuda" + - "1.7.2 *_cuda" # The CTK libraries below are missing from the conda-forge::cudatoolkit -# package. The "*_host_*" version specifiers correspond to `11.5` packages and the +# package. The "*_host_*" version specifiers correspond to `11.8` packages and the # "*_run_*" version specifiers correspond to `11.x` packages. libcublas_host_version: - - ">=11.7.3.1,<=11.7.4.6" + - "=11.11.3.6" libcublas_run_version: - - ">=11.5.2.43,<=11.11.3.6" + - ">=11.5.2.43,<12.0.0" libcurand_host_version: - - ">=10.2.6.48,<=10.2.7.107" + - "=10.3.0.86" libcurand_run_version: - - ">=10.2.5.43,<=10.3.0.86" + - ">=10.2.5.43,<10.3.1" libcusolver_host_version: - - ">=11.2.1.48,<=11.3.2.107" + - "=11.4.1.48" libcusolver_run_version: - - ">=11.2.0.43,<=11.4.1.48" + - ">=11.2.0.43,<11.4.2" libcusparse_host_version: - - ">=11.7.0.31,<=11.7.0.107" + - "=11.7.5.86" libcusparse_run_version: - - ">=11.6.0.43,<=11.7.5.86" + - ">=11.6.0.43,<12.0.0" # `cuda-profiler-api` only has `11.8.0` and `12.0.0` packages for all # architectures. The "*_host_*" version specifiers correspond to `11.8` packages and the # "*_run_*" version specifiers correspond to `11.x` packages. cuda_profiler_api_host_version: - - ">=11.8.86,<12" + - "=11.8.86" cuda_profiler_api_run_version: - ">=11.4.240,<12" diff --git a/conda/recipes/raft-dask/conda_build_config.yaml b/conda/recipes/raft-dask/conda_build_config.yaml index 42d7e3a900..153fd2129e 100644 --- a/conda/recipes/raft-dask/conda_build_config.yaml +++ b/conda/recipes/raft-dask/conda_build_config.yaml @@ -14,7 +14,7 @@ ucx_version: - "1.13.0" ucx_py_version: - - "0.30.*" + - "0.31.*" cmake_version: - ">=3.23.1,!=3.25.0" diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 784bbbb935..5a89c735bb 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. You may obtain a copy of the License at @@ -10,8 +10,8 @@ # is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express # or implied. See the License for the specific language governing permissions and limitations under # the License. 
-set(RAPIDS_VERSION "23.02") -set(RAFT_VERSION "23.02.00") +set(RAPIDS_VERSION "23.04") +set(RAFT_VERSION "23.04.00") cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) include(../fetch_rapids.cmake) @@ -284,7 +284,18 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/cluster/update_centroids_double.cu src/distance/cluster/cluster_cost_float.cu src/distance/cluster/cluster_cost_double.cu - src/distance/neighbors/refine.cu + src/distance/neighbors/refine_d_uint64_t_float.cu + src/distance/neighbors/refine_d_uint64_t_int8_t.cu + src/distance/neighbors/refine_d_uint64_t_uint8_t.cu + src/distance/neighbors/refine_h_uint64_t_float.cu + src/distance/neighbors/refine_h_uint64_t_int8_t.cu + src/distance/neighbors/refine_h_uint64_t_uint8_t.cu + src/distance/neighbors/specializations/refine_d_uint64_t_float.cu + src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu + src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu + src/distance/neighbors/specializations/refine_h_uint64_t_float.cu + src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu + src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu src/distance/neighbors/ivfpq_search.cu src/distance/cluster/kmeans_fit_float.cu src/distance/cluster/kmeans_fit_double.cu @@ -665,6 +676,13 @@ raft_export( distance distributed nn DOCUMENTATION doc_string NAMESPACE raft:: FINAL_CODE_BLOCK code_string ) +# ################################################################################################## +# * shared test/bench headers ------------------------------------------------ + +if(BUILD_TESTS OR BUILD_BENCH) + include(internal/CMakeLists.txt) +endif() + # ################################################################################################## # * build test executable ---------------------------------------------------- diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index 99606dd2e9..1bc2c86243 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -1,5 +1,5 @@ # ============================================================================= -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except # in compliance with the License. 
You may obtain a copy of the License at @@ -30,6 +30,7 @@ function(ConfigureBench) target_link_libraries( ${BENCH_NAME} PRIVATE raft::raft + raft_internal $<$:raft::distance> $<$:raft::nn> benchmark::benchmark @@ -81,6 +82,7 @@ if(BUILD_BENCH) bench/distance/distance_l1.cu bench/distance/distance_unexp_l2.cu bench/distance/fused_l2_nn.cu + bench/distance/masked_nn.cu bench/distance/kernels.cu bench/main.cpp OPTIONAL @@ -102,7 +104,10 @@ if(BUILD_BENCH) bench/main.cpp ) - ConfigureBench(NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/main.cpp) + ConfigureBench( + NAME MATRIX_BENCH PATH bench/matrix/argmin.cu bench/matrix/gather.cu bench/matrix/select_k.cu + bench/main.cpp + ) ConfigureBench( NAME RANDOM_BENCH PATH bench/random/make_blobs.cu bench/random/permute.cu bench/random/rng.cu @@ -126,7 +131,6 @@ if(BUILD_BENCH) bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu bench/neighbors/knn/ivf_pq_uint8_t_uint32_t.cu bench/neighbors/refine.cu - bench/neighbors/selection.cu bench/main.cpp OPTIONAL DIST diff --git a/cpp/bench/cluster/kmeans_balanced.cu b/cpp/bench/cluster/kmeans_balanced.cu index 210b40ced8..9c53e86d8c 100644 --- a/cpp/bench/cluster/kmeans_balanced.cu +++ b/cpp/bench/cluster/kmeans_balanced.cu @@ -15,20 +15,19 @@ */ #include +#include #include -#include -#if defined RAFT_DISTANCE_COMPILED && defined RAFT_NN_COMPILED -#include +#if defined RAFT_DISTANCE_COMPILED +#include #endif namespace raft::bench::cluster { struct KMeansBalancedBenchParams { DatasetParams data; - uint32_t max_iter; uint32_t n_lists; - raft::distance::DistanceType metric; + raft::cluster::kmeans_balanced_params kb_params; }; template @@ -38,15 +37,10 @@ struct KMeansBalanced : public fixture { void run_benchmark(::benchmark::State& state) override { this->loop_on_state(state, [this]() { - raft::spatial::knn::detail::kmeans::build_hierarchical(this->handle, - this->params.max_iter, - (uint32_t)this->params.data.cols, - this->X.data_handle(), - this->params.data.rows, - this->centroids.data_handle(), - this->params.n_lists, - this->params.metric, - this->handle.get_stream()); + raft::device_matrix_view X_view = this->X.view(); + raft::device_matrix_view centroids_view = this->centroids.view(); + raft::cluster::kmeans_balanced::fit( + this->handle, this->params.kb_params, X_view, centroids_view); }); } @@ -84,8 +78,8 @@ std::vector getKMeansBalancedInputs() std::vector out; KMeansBalancedBenchParams p; p.data.row_major = true; - p.max_iter = 20; - p.metric = raft::distance::DistanceType::L2Expanded; + p.kb_params.n_iters = 20; + p.kb_params.metric = raft::distance::DistanceType::L2Expanded; std::vector> row_cols = { {100000, 128}, {1000000, 128}, {10000000, 128}, // The following dataset sizes are too large for most GPUs. @@ -104,7 +98,5 @@ std::vector getKMeansBalancedInputs() // Note: the datasets sizes are too large for 32-bit index types. RAFT_BENCH_REGISTER((KMeansBalanced), "", getKMeansBalancedInputs()); -RAFT_BENCH_REGISTER((KMeansBalanced), "", getKMeansBalancedInputs()); -RAFT_BENCH_REGISTER((KMeansBalanced), "", getKMeansBalancedInputs()); } // namespace raft::bench::cluster diff --git a/cpp/bench/common/benchmark.hpp b/cpp/bench/common/benchmark.hpp index 13ca40a033..85d5381e2c 100644 --- a/cpp/bench/common/benchmark.hpp +++ b/cpp/bench/common/benchmark.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -20,7 +20,7 @@ #include #include -#include +#include #include #include #include @@ -110,7 +110,7 @@ class fixture { rmm::device_buffer scratch_buf_; public: - raft::handle_t handle; + raft::device_resources handle; rmm::cuda_stream_view stream; fixture() : stream{handle.get_stream()} diff --git a/cpp/bench/distance/distance_common.cuh b/cpp/bench/distance/distance_common.cuh index 73faacce37..1be00ec0c7 100644 --- a/cpp/bench/distance/distance_common.cuh +++ b/cpp/bench/distance/distance_common.cuh @@ -24,14 +24,14 @@ namespace raft::bench::distance { -struct distance_inputs { +struct distance_params { int m, n, k; bool isRowMajor; -}; // struct distance_inputs +}; // struct distance_params template struct distance : public fixture { - distance(const distance_inputs& p) + distance(const distance_params& p) : params(p), x(p.m * p.k, stream), y(p.n * p.k, stream), @@ -63,13 +63,13 @@ struct distance : public fixture { } private: - distance_inputs params; + distance_params params; rmm::device_uvector x, y, out; rmm::device_uvector workspace; size_t worksize; }; // struct Distance -const std::vector dist_input_vecs{ +const std::vector dist_input_vecs{ {32, 16384, 16384, true}, {64, 16384, 16384, true}, {128, 16384, 16384, true}, {256, 16384, 16384, true}, {512, 16384, 16384, true}, {1024, 16384, 16384, true}, {16384, 32, 16384, true}, {16384, 64, 16384, true}, {16384, 128, 16384, true}, diff --git a/cpp/bench/distance/kernels.cu b/cpp/bench/distance/kernels.cu index 5c9c2cc2ed..027f93171e 100644 --- a/cpp/bench/distance/kernels.cu +++ b/cpp/bench/distance/kernels.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ #include #include -#include +#include #include #include #include @@ -77,7 +77,7 @@ struct GramMatrix : public fixture { } private: - const raft::handle_t handle; + const raft::device_resources handle; std::unique_ptr> kernel; GramTestParams params; diff --git a/cpp/bench/distance/masked_nn.cu b/cpp/bench/distance/masked_nn.cu new file mode 100644 index 0000000000..3677d44864 --- /dev/null +++ b/cpp/bench/distance/masked_nn.cu @@ -0,0 +1,267 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined RAFT_NN_COMPILED +#include +#endif + +namespace raft::bench::distance::masked_nn { + +// Introduce various sparsity patterns +enum AdjacencyPattern { + checkerboard = 0, + checkerboard_4 = 1, + checkerboard_64 = 2, + all_true = 3, + all_false = 4 +}; + +struct Params { + int m, n, k, num_groups; + AdjacencyPattern pattern; +}; // struct Params + +__global__ void init_adj(AdjacencyPattern pattern, + int n, + raft::device_matrix_view adj, + raft::device_vector_view group_idxs) +{ + int m = adj.extent(0); + int num_groups = adj.extent(1); + + for (int idx_m = blockIdx.y * blockDim.y + threadIdx.y; idx_m < m; + idx_m += blockDim.y * gridDim.y) { + for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; + idx_g += blockDim.x * gridDim.x) { + switch (pattern) { + case checkerboard: adj(idx_m, idx_g) = (idx_m + idx_g) % 2; break; + case checkerboard_4: adj(idx_m, idx_g) = (idx_m / 4 + idx_g) % 2; break; + case checkerboard_64: adj(idx_m, idx_g) = (idx_m / 64 + idx_g) % 2; break; + case all_true: adj(idx_m, idx_g) = true; break; + case all_false: adj(idx_m, idx_g) = false; break; + default: assert(false && "unknown pattern"); + } + } + } + // Each group is of size n / num_groups. + // + // - group_idxs[j] indicates the start of group j + 1 (i.e. is the inclusive + // scan of the group lengths) + // + // - The first group always starts at index zero, so we do not store it. + // + // - The group_idxs[num_groups - 1] should always equal n. + + if (blockIdx.y == 0 && threadIdx.y == 0) { + const int g_stride = blockDim.x * gridDim.x; + for (int idx_g = blockIdx.x * blockDim.x + threadIdx.x; idx_g < num_groups; idx_g += g_stride) { + group_idxs(idx_g) = (idx_g + 1) * (n / num_groups); + } + group_idxs(num_groups - 1) = n; + } +} + +template +struct masked_l2_nn : public fixture { + using DataT = T; + using IdxT = int; + using OutT = raft::KeyValuePair; + using RedOpT = raft::distance::MinAndDistanceReduceOp; + using PairRedOpT = raft::distance::KVPMinReduce; + using ParamT = raft::distance::MaskedL2NNParams; + + // Parameters + Params params; + // Data + raft::device_vector out; + raft::device_matrix x, y; + raft::device_vector xn, yn; + raft::device_matrix adj; + raft::device_vector group_idxs; + + masked_l2_nn(const Params& p) + : params(p), + out{raft::make_device_vector(handle, p.m)}, + x{raft::make_device_matrix(handle, p.m, p.k)}, + y{raft::make_device_matrix(handle, p.n, p.k)}, + xn{raft::make_device_vector(handle, p.m)}, + yn{raft::make_device_vector(handle, p.n)}, + adj{raft::make_device_matrix(handle, p.m, p.num_groups)}, + group_idxs{raft::make_device_vector(handle, p.num_groups)} + { + raft::random::RngState r(123456ULL); + + uniform(handle, r, x.data_handle(), p.m * p.k, T(-1.0), T(1.0)); + uniform(handle, r, y.data_handle(), p.n * p.k, T(-1.0), T(1.0)); + raft::linalg::rowNorm( + xn.data_handle(), x.data_handle(), p.k, p.m, raft::linalg::L2Norm, true, stream); + raft::linalg::rowNorm( + yn.data_handle(), y.data_handle(), p.k, p.n, raft::linalg::L2Norm, true, stream); + raft::distance::initialize, int>( + handle, out.data_handle(), p.m, std::numeric_limits::max(), RedOpT{}); + + dim3 block(32, 32); + dim3 grid(10, 10); + init_adj<<>>(p.pattern, p.n, adj.view(), group_idxs.view()); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + void run_benchmark(::benchmark::State& state) override + { + bool init_out = true; + 
bool sqrt = false; + ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; + + loop_on_state(state, [this, masked_l2_params]() { + // It is sufficient to only benchmark the L2-squared metric + raft::distance::maskedL2NN(handle, + masked_l2_params, + x.view(), + y.view(), + xn.view(), + yn.view(), + adj.view(), + group_idxs.view(), + out.view()); + }); + + // Virtual flop count if no skipping had occurred. + size_t virtual_flops = size_t(2) * size_t(params.m) * size_t(params.n) * size_t(params.k); + + int64_t read_elts = params.n * params.k + params.m * params.k; + int64_t write_elts = params.m; + + // Virtual min flops is the number of flops that would have been executed if + // the algorithm had actually skipped each computation that it could have + // skipped. + size_t virtual_min_flops = 0; + switch (params.pattern) { + case checkerboard: + case checkerboard_4: + case checkerboard_64: virtual_min_flops = virtual_flops / 2; break; + case all_true: virtual_min_flops = virtual_flops; break; + case all_false: virtual_min_flops = 0; break; + default: assert(false && "unknown pattern"); + } + + // VFLOP/s is the "virtual" flop count that would have executed if there was + // no adjacency pattern. This is useful for comparing to fusedL2NN + state.counters["VFLOP/s"] = benchmark::Counter(virtual_flops, + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + // Virtual min flops is the number of flops that would have been executed if + // the algorithm had actually skipped each computation that it could have + // skipped. + state.counters["VminFLOP/s"] = benchmark::Counter(virtual_min_flops, + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + + state.counters["BW Wr"] = benchmark::Counter(write_elts * sizeof(OutT), + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + state.counters["BW Rd"] = benchmark::Counter(read_elts * sizeof(DataT), + benchmark::Counter::kIsIterationInvariantRate, + benchmark::Counter::OneK::kIs1000); + + state.counters["m"] = benchmark::Counter(params.m); + state.counters["n"] = benchmark::Counter(params.n); + state.counters["k"] = benchmark::Counter(params.k); + state.counters["num_groups"] = benchmark::Counter(params.num_groups); + state.counters["group size"] = benchmark::Counter(params.n / params.num_groups); + state.counters["Pat"] = benchmark::Counter(static_cast(params.pattern)); + + state.counters["SM count"] = raft::getMultiProcessorCount(); + } +}; // struct MaskedL2NN + +const std::vector masked_l2_nn_input_vecs = { + // Very fat matrices... + {32, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {64, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {128, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {256, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {512, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {1024, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 32, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 64, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 128, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 256, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 512, 16384, 32, AdjacencyPattern::checkerboard}, + {16384, 1024, 16384, 32, AdjacencyPattern::checkerboard}, + + // Representative matrices... 
+ {16384, 16384, 32, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 64, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 128, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 256, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 512, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard}, + {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard}, + + {16384, 16384, 32, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 64, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 128, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 256, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 512, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_4}, + {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_4}, + + {16384, 16384, 32, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 64, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 128, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 256, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 512, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 1024, 32, AdjacencyPattern::checkerboard_64}, + {16384, 16384, 16384, 32, AdjacencyPattern::checkerboard_64}, + + {16384, 16384, 32, 32, AdjacencyPattern::all_true}, + {16384, 16384, 64, 32, AdjacencyPattern::all_true}, + {16384, 16384, 128, 32, AdjacencyPattern::all_true}, + {16384, 16384, 256, 32, AdjacencyPattern::all_true}, + {16384, 16384, 512, 32, AdjacencyPattern::all_true}, + {16384, 16384, 1024, 32, AdjacencyPattern::all_true}, + {16384, 16384, 16384, 32, AdjacencyPattern::all_true}, + + {16384, 16384, 32, 32, AdjacencyPattern::all_false}, + {16384, 16384, 64, 32, AdjacencyPattern::all_false}, + {16384, 16384, 128, 32, AdjacencyPattern::all_false}, + {16384, 16384, 256, 32, AdjacencyPattern::all_false}, + {16384, 16384, 512, 32, AdjacencyPattern::all_false}, + {16384, 16384, 1024, 32, AdjacencyPattern::all_false}, + {16384, 16384, 16384, 32, AdjacencyPattern::all_false}, +}; + +RAFT_BENCH_REGISTER(masked_l2_nn, "", masked_l2_nn_input_vecs); +// We don't benchmark double to keep compile times in check when not using the +// distance library. + +} // namespace raft::bench::distance::masked_nn diff --git a/cpp/bench/matrix/argmin.cu b/cpp/bench/matrix/argmin.cu index 0d0dea0fdb..3869f0c5e1 100644 --- a/cpp/bench/matrix/argmin.cu +++ b/cpp/bench/matrix/argmin.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -17,10 +17,11 @@ #include #include #include +#include #include -namespace raft::bench::linalg { +namespace raft::bench::matrix { template struct ArgminParams { @@ -45,9 +46,7 @@ struct Argmin : public fixture { void run_benchmark(::benchmark::State& state) override { loop_on_state(state, [this]() { - auto matrix_const_view = raft::make_device_matrix_view( - matrix.data_handle(), matrix.extent(0), matrix.extent(1)); - raft::matrix::argmin(handle, matrix_const_view, indices.view()); + raft::matrix::argmin(handle, raft::make_const_mdspan(matrix.view()), indices.view()); }); } @@ -57,15 +56,11 @@ struct Argmin : public fixture { raft::device_vector indices; }; // struct Argmin -const std::vector> argmin_inputs_i64{ - {1000, 64}, {1000, 128}, {1000, 256}, {1000, 512}, {1000, 1024}, - {10000, 64}, {10000, 128}, {10000, 256}, {10000, 512}, {10000, 1024}, - {100000, 64}, {100000, 128}, {100000, 256}, {100000, 512}, {100000, 1024}, - {1000000, 64}, {1000000, 128}, {1000000, 256}, {1000000, 512}, {1000000, 1024}, - {10000000, 64}, {10000000, 128}, {10000000, 256}, {10000000, 512}, {10000000, 1024}, -}; +const std::vector> argmin_inputs_i64 = + raft::util::itertools::product>({1000, 10000, 100000, 1000000, 10000000}, + {64, 128, 256, 512, 1024}); RAFT_BENCH_REGISTER((Argmin), "", argmin_inputs_i64); RAFT_BENCH_REGISTER((Argmin), "", argmin_inputs_i64); -} // namespace raft::bench::linalg +} // namespace raft::bench::matrix diff --git a/cpp/bench/matrix/gather.cu b/cpp/bench/matrix/gather.cu new file mode 100644 index 0000000000..c5d80744cd --- /dev/null +++ b/cpp/bench/matrix/gather.cu @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include + +#include + +namespace raft::bench::matrix { + +template +struct GatherParams { + IdxT rows, cols, map_length; +}; + +template +inline auto operator<<(std::ostream& os, const GatherParams& p) -> std::ostream& +{ + os << p.rows << "#" << p.cols << "#" << p.map_length; + return os; +} + +template +struct Gather : public fixture { + Gather(const GatherParams& p) : params(p) {} + + void allocate_data(const ::benchmark::State& state) override + { + matrix = raft::make_device_matrix(handle, params.rows, params.cols); + map = raft::make_device_vector(handle, params.map_length); + out = raft::make_device_matrix(handle, params.map_length, params.cols); + stencil = raft::make_device_vector(handle, Conditional ? 
params.map_length : IdxT(0)); + + raft::random::RngState rng{1234}; + raft::random::uniform( + rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream); + raft::random::uniformInt( + handle, rng, map.data_handle(), params.map_length, (MapT)0, (MapT)params.rows); + if constexpr (Conditional) { + raft::random::uniform(rng, stencil.data_handle(), params.map_length, T(-1), T(1), stream); + } + handle.sync_stream(stream); + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + loop_on_state(state, [this]() { + auto matrix_const_view = raft::make_const_mdspan(matrix.view()); + auto map_const_view = raft::make_const_mdspan(map.view()); + if constexpr (Conditional) { + auto stencil_const_view = raft::make_const_mdspan(stencil.view()); + auto pred_op = raft::plug_const_op(T(0.0), raft::greater_op()); + raft::matrix::gather_if( + handle, matrix_const_view, out.view(), map_const_view, stencil_const_view, pred_op); + } else { + raft::matrix::gather(handle, matrix_const_view, map_const_view, out.view()); + } + }); + } + + private: + GatherParams params; + raft::device_matrix matrix, out; + raft::device_vector stencil; + raft::device_vector map; +}; // struct Gather + +template +using GatherIf = Gather; + +const std::vector> gather_inputs_i64 = + raft::util::itertools::product>( + {1000000}, {10, 20, 50, 100, 200, 500}, {1000, 10000, 100000, 1000000}); + +RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); +RAFT_BENCH_REGISTER((Gather), "", gather_inputs_i64); +RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); +RAFT_BENCH_REGISTER((GatherIf), "", gather_inputs_i64); +} // namespace raft::bench::matrix diff --git a/cpp/bench/matrix/select_k.cu b/cpp/bench/matrix/select_k.cu new file mode 100644 index 0000000000..2c8b8bb67b --- /dev/null +++ b/cpp/bench/matrix/select_k.cu @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include + +#include + +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include + +namespace raft::matrix { + +using namespace raft::bench; // NOLINT + +template +struct selection : public fixture { + explicit selection(const select::params& p) + : params_(p), + in_dists_(p.batch_size * p.len, stream), + in_ids_(p.batch_size * p.len, stream), + out_dists_(p.batch_size * p.k, stream), + out_ids_(p.batch_size * p.k, stream) + { + raft::sparse::iota_fill(in_ids_.data(), IdxT(p.batch_size), IdxT(p.len), stream); + raft::random::RngState state{42}; + raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0)); + } + + void run_benchmark(::benchmark::State& state) override // NOLINT + { + device_resources handle{stream}; + using_pool_memory_res res; + try { + std::ostringstream label_stream; + label_stream << params_.batch_size << "#" << params_.len << "#" << params_.k; + state.SetLabel(label_stream.str()); + loop_on_state(state, [this, &handle]() { + select::select_k_impl(handle, + Algo, + in_dists_.data(), + in_ids_.data(), + params_.batch_size, + params_.len, + params_.k, + out_dists_.data(), + out_ids_.data(), + params_.select_min); + }); + } catch (raft::exception& e) { + state.SkipWithError(e.what()); + } + } + + private: + const select::params params_; + rmm::device_uvector in_dists_, out_dists_; + rmm::device_uvector in_ids_, out_ids_; +}; + +const std::vector kInputs{ + {20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true}, + {20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true}, + {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true}, + + {1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true}, + {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, + {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, + + {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, + {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, + {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, + + {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, + {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, + {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true}, +}; + +#define SELECTION_REGISTER(KeyT, IdxT, A) \ + namespace BENCHMARK_PRIVATE_NAME(selection) \ + { \ + using SelectK = selection; \ + RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #A, kInputs); \ + } + +SELECTION_REGISTER(float, int, kPublicApi); // NOLINT +SELECTION_REGISTER(float, int, kRadix8bits); // NOLINT +SELECTION_REGISTER(float, int, kRadix11bits); // NOLINT +SELECTION_REGISTER(float, int, kWarpAuto); // NOLINT +SELECTION_REGISTER(float, int, kWarpImmediate); // NOLINT +SELECTION_REGISTER(float, int, kWarpFiltered); // NOLINT +SELECTION_REGISTER(float, int, kWarpDistributed); // NOLINT +SELECTION_REGISTER(float, int, kWarpDistributedShm); // NOLINT + +SELECTION_REGISTER(double, int, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, int, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, int, kWarpAuto); // NOLINT + +SELECTION_REGISTER(double, size_t, kRadix8bits); // NOLINT +SELECTION_REGISTER(double, size_t, kRadix11bits); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpImmediate); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpFiltered); // NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributed); // 
NOLINT +SELECTION_REGISTER(double, size_t, kWarpDistributedShm); // NOLINT + +} // namespace raft::matrix diff --git a/cpp/bench/neighbors/knn.cuh b/cpp/bench/neighbors/knn.cuh index d38631b289..eec1cba99e 100644 --- a/cpp/bench/neighbors/knn.cuh +++ b/cpp/bench/neighbors/knn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -32,6 +32,9 @@ #include #if defined RAFT_DISTANCE_COMPILED #include +#include +#else +#pragma message("NN / Distance specializations are not enabled; expect very long building times.") #endif #endif @@ -148,7 +151,7 @@ struct ivf_flat_knn { raft::neighbors::ivf_flat::search_params search_params; params ps; - ivf_flat_knn(const raft::handle_t& handle, const params& ps, const ValT* data) : ps(ps) + ivf_flat_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps) { index_params.n_lists = 4096; index_params.metric = raft::distance::DistanceType::L2Expanded; @@ -156,7 +159,7 @@ struct ivf_flat_knn { handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); } - void search(const raft::handle_t& handle, + void search(const raft::device_resources& handle, const ValT* search_items, dist_t* out_dists, IdxT* out_idxs) @@ -176,7 +179,7 @@ struct ivf_pq_knn { raft::neighbors::ivf_pq::search_params search_params; params ps; - ivf_pq_knn(const raft::handle_t& handle, const params& ps, const ValT* data) : ps(ps) + ivf_pq_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : ps(ps) { index_params.n_lists = 4096; index_params.metric = raft::distance::DistanceType::L2Expanded; @@ -184,7 +187,7 @@ struct ivf_pq_knn { handle, index_params, data, IdxT(ps.n_samples), uint32_t(ps.n_dims))); } - void search(const raft::handle_t& handle, + void search(const raft::device_resources& handle, const ValT* search_items, dist_t* out_dists, IdxT* out_idxs) @@ -202,12 +205,12 @@ struct brute_force_knn { ValT* index; params ps; - brute_force_knn(const raft::handle_t& handle, const params& ps, const ValT* data) + brute_force_knn(const raft::device_resources& handle, const params& ps, const ValT* data) : index(const_cast(data)), ps(ps) { } - void search(const raft::handle_t& handle, + void search(const raft::device_resources& handle, const ValT* search_items, dist_t* out_dists, IdxT* out_idxs) @@ -287,7 +290,7 @@ struct knn : public fixture { std::ostringstream label_stream; label_stream << params_ << "#" << strategy_ << "#" << scope_; state.SetLabel(label_stream.str()); - raft::handle_t handle(stream); + raft::device_resources handle(stream); std::optional index; if (scope_ == Scope::SEARCH) { // also implies TransferStrategy::NO_COPY diff --git a/cpp/bench/neighbors/refine.cu b/cpp/bench/neighbors/refine.cu index a038905ace..f32af3a57e 100644 --- a/cpp/bench/neighbors/refine.cu +++ b/cpp/bench/neighbors/refine.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,18 +14,20 @@ * limitations under the License. 
*/ -#include +#include -#include +#include #include -#include +#include #include #include #include +#include #if defined RAFT_DISTANCE_COMPILED #include +#include #endif #if defined RAFT_NN_COMPILED @@ -36,12 +38,10 @@ #include #include -#include "../../test/neighbors/refine_helper.cuh" - #include #include -using namespace raft::neighbors::detail; +using namespace raft::neighbors; namespace raft::bench::neighbors { @@ -53,7 +53,7 @@ inline auto operator<<(std::ostream& os, const RefineInputs& p) -> std::os return os; } -RefineInputs p; +RefineInputs p; template class RefineAnn : public fixture { @@ -95,28 +95,28 @@ class RefineAnn : public fixture { } private: - raft::handle_t handle_; + raft::device_resources handle_; RefineHelper data; }; -std::vector> getInputs() +std::vector> getInputs() { - std::vector> out; + std::vector> out; raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; for (bool host_data : {true, false}) { - for (int64_t n_queries : {1000, 10000}) { - for (int64_t dim : {128, 512}) { - out.push_back(RefineInputs{n_queries, 2000000, dim, 32, 128, metric, host_data}); - out.push_back(RefineInputs{n_queries, 2000000, dim, 10, 40, metric, host_data}); + for (uint64_t n_queries : {1000, 10000}) { + for (uint64_t dim : {128, 512}) { + out.push_back(RefineInputs{n_queries, 2000000, dim, 32, 128, metric, host_data}); + out.push_back(RefineInputs{n_queries, 2000000, dim, 10, 40, metric, host_data}); } } } return out; } -using refine_float_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs()); +using refine_float_uint64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_float_uint64, "", getInputs()); -using refine_uint8_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs()); +using refine_uint8_uint64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_uint8_uint64, "", getInputs()); } // namespace raft::bench::neighbors diff --git a/cpp/bench/neighbors/selection.cu b/cpp/bench/neighbors/selection.cu deleted file mode 100644 index 1f116c199f..0000000000 --- a/cpp/bench/neighbors/selection.cu +++ /dev/null @@ -1,123 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include -#include - -#if defined RAFT_NN_COMPILED -#include -#endif - -#include -#include - -#include -#include - -namespace raft::bench::spatial { - -struct params { - int n_inputs; - int input_len; - int k; - int select_min; -}; - -template -struct selection : public fixture { - explicit selection(const params& p) - : params_(p), - in_dists_(p.n_inputs * p.input_len, stream), - in_ids_(p.n_inputs * p.input_len, stream), - out_dists_(p.n_inputs * p.k, stream), - out_ids_(p.n_inputs * p.k, stream) - { - raft::sparse::iota_fill(in_ids_.data(), IdxT(p.n_inputs), IdxT(p.input_len), stream); - raft::random::RngState state{42}; - raft::random::uniform(handle, state, in_dists_.data(), in_dists_.size(), KeyT(-1.0), KeyT(1.0)); - } - - void run_benchmark(::benchmark::State& state) override - { - using_pool_memory_res res; - try { - std::ostringstream label_stream; - label_stream << params_.n_inputs << "#" << params_.input_len << "#" << params_.k; - state.SetLabel(label_stream.str()); - loop_on_state(state, [this]() { - raft::spatial::knn::select_k(in_dists_.data(), - in_ids_.data(), - params_.n_inputs, - params_.input_len, - out_dists_.data(), - out_ids_.data(), - params_.select_min, - params_.k, - stream, - Algo); - }); - } catch (raft::exception& e) { - state.SkipWithError(e.what()); - } - } - - private: - const params params_; - rmm::device_uvector in_dists_, out_dists_; - rmm::device_uvector in_ids_, out_ids_; -}; - -const std::vector kInputs{ - {20000, 500, 1, true}, {20000, 500, 2, true}, {20000, 500, 4, true}, - {20000, 500, 8, true}, {20000, 500, 16, true}, {20000, 500, 32, true}, - {20000, 500, 64, true}, {20000, 500, 128, true}, {20000, 500, 256, true}, - - {1000, 10000, 1, true}, {1000, 10000, 2, true}, {1000, 10000, 4, true}, - {1000, 10000, 8, true}, {1000, 10000, 16, true}, {1000, 10000, 32, true}, - {1000, 10000, 64, true}, {1000, 10000, 128, true}, {1000, 10000, 256, true}, - - {100, 100000, 1, true}, {100, 100000, 2, true}, {100, 100000, 4, true}, - {100, 100000, 8, true}, {100, 100000, 16, true}, {100, 100000, 32, true}, - {100, 100000, 64, true}, {100, 100000, 128, true}, {100, 100000, 256, true}, - - {10, 1000000, 1, true}, {10, 1000000, 2, true}, {10, 1000000, 4, true}, - {10, 1000000, 8, true}, {10, 1000000, 16, true}, {10, 1000000, 32, true}, - {10, 1000000, 64, true}, {10, 1000000, 128, true}, {10, 1000000, 256, true}, -}; - -#define SELECTION_REGISTER(KeyT, IdxT, Algo) \ - namespace BENCHMARK_PRIVATE_NAME(selection) \ - { \ - using SelectK = selection; \ - RAFT_BENCH_REGISTER(SelectK, #KeyT "/" #IdxT "/" #Algo, kInputs); \ - } - -SELECTION_REGISTER(float, int, FAISS); -SELECTION_REGISTER(float, int, RADIX_8_BITS); -SELECTION_REGISTER(float, int, RADIX_11_BITS); -SELECTION_REGISTER(float, int, WARP_SORT); - -SELECTION_REGISTER(double, int, FAISS); -SELECTION_REGISTER(double, int, RADIX_8_BITS); -SELECTION_REGISTER(double, int, RADIX_11_BITS); -SELECTION_REGISTER(double, int, WARP_SORT); - -SELECTION_REGISTER(double, size_t, FAISS); -SELECTION_REGISTER(double, size_t, RADIX_8_BITS); -SELECTION_REGISTER(double, size_t, RADIX_11_BITS); -SELECTION_REGISTER(double, size_t, WARP_SORT); - -} // namespace raft::bench::spatial diff --git a/cpp/bench/random/permute.cu b/cpp/bench/random/permute.cu index 5364bb44e3..cb9e21868b 100644 --- a/cpp/bench/random/permute.cu +++ b/cpp/bench/random/permute.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ struct permute : public fixture { } private: - raft::handle_t handle; + raft::device_resources handle; permute_inputs params; rmm::device_uvector out, in; rmm::device_uvector perms; diff --git a/cpp/bench/sparse/convert_csr.cu b/cpp/bench/sparse/convert_csr.cu index 830fab13cc..c9dcae6985 100644 --- a/cpp/bench/sparse/convert_csr.cu +++ b/cpp/bench/sparse/convert_csr.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -107,7 +107,7 @@ struct bench_base : public fixture { } protected: - raft::handle_t handle; + raft::device_resources handle; bench_param params; rmm::device_uvector adj; rmm::device_uvector row_ind; diff --git a/cpp/cmake/thirdparty/get_cutlass.cmake b/cpp/cmake/thirdparty/get_cutlass.cmake index 811a5466c3..3e02ce064e 100644 --- a/cpp/cmake/thirdparty/get_cutlass.cmake +++ b/cpp/cmake/thirdparty/get_cutlass.cmake @@ -30,6 +30,10 @@ function(find_and_configure_cutlass) CACHE BOOL "Disable CUTLASS to build with cuBLAS library." ) + if (CUDA_STATIC_RUNTIME) + set(CUDART_LIBRARY "${CUDA_cudart_static_LIBRARY}" CACHE FILEPATH "fixing cutlass cmake code" FORCE) + endif() + rapids_cpm_find( NvidiaCutlass ${PKG_VERSION} GLOBAL_TARGETS nvidia::cutlass::cutlass diff --git a/cpp/include/raft/cluster/detail/agglomerative.cuh b/cpp/include/raft/cluster/detail/agglomerative.cuh index 618f852bba..f4b2ecf051 100644 --- a/cpp/include/raft/cluster/detail/agglomerative.cuh +++ b/cpp/include/raft/cluster/detail/agglomerative.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include @@ -100,7 +100,7 @@ class UnionFind { * @param[out] out_size cluster sizes of output */ template -void build_dendrogram_host(const handle_t& handle, +void build_dendrogram_host(raft::device_resources const& handle, const value_idx* rows, const value_idx* cols, const value_t* data, @@ -236,7 +236,7 @@ struct init_label_roots { * @param n_leaves */ template -void extract_flattened_clusters(const raft::handle_t& handle, +void extract_flattened_clusters(raft::device_resources const& handle, value_idx* labels, const value_idx* children, size_t n_clusters, diff --git a/cpp/include/raft/cluster/detail/connectivities.cuh b/cpp/include/raft/cluster/detail/connectivities.cuh index a07045f0d2..163670f29a 100644 --- a/cpp/include/raft/cluster/detail/connectivities.cuh +++ b/cpp/include/raft/cluster/detail/connectivities.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include @@ -24,6 +24,7 @@ #include #include +#include #include #include #include @@ -39,7 +40,7 @@ namespace raft::cluster::detail { template struct distance_graph_impl { - void run(const raft::handle_t& handle, + void run(raft::device_resources const& handle, const value_t* X, size_t m, size_t n, @@ -57,7 +58,7 @@ struct distance_graph_impl { */ template struct distance_graph_impl { - void run(const raft::handle_t& handle, + void run(raft::device_resources const& handle, const value_t* X, size_t m, size_t n, @@ -103,6 +104,98 @@ struct distance_graph_impl +__global__ void fill_indices2(value_idx* indices, size_t m, size_t nnz) +{ + value_idx tid = (blockIdx.x * blockDim.x) + threadIdx.x; + if (tid >= nnz) return; + value_idx v = tid % m; + indices[tid] = v; +} + +/** + * Compute connected CSR of pairwise distances + * @tparam value_idx + * @tparam value_t + * @param handle + * @param X + * @param m + * @param n + * @param metric + * @param[out] indptr + * @param[out] indices + * @param[out] data + */ +template +void pairwise_distances(const raft::device_resources& handle, + const value_t* X, + size_t m, + size_t n, + raft::distance::DistanceType metric, + value_idx* indptr, + value_idx* indices, + value_t* data) +{ + auto stream = handle.get_stream(); + auto exec_policy = handle.get_thrust_policy(); + + value_idx nnz = m * m; + + value_idx blocks = raft::ceildiv(nnz, (value_idx)256); + fill_indices2<<>>(indices, m, nnz); + + thrust::sequence(exec_policy, indptr, indptr + m, 0, (int)m); + + raft::update_device(indptr + m, &nnz, 1, stream); + + // TODO: It would ultimately be nice if the MST could accept + // dense inputs directly so we don't need to double the memory + // usage to hand it a sparse array here. + distance::pairwise_distance(handle, X, X, data, m, m, n, metric); + // self-loops get max distance + auto transform_in = + thrust::make_zip_iterator(thrust::make_tuple(thrust::make_counting_iterator(0), data)); + + thrust::transform(exec_policy, + transform_in, + transform_in + nnz, + data, + [=] __device__(const thrust::tuple& tup) { + value_idx idx = thrust::get<0>(tup); + bool self_loop = idx % m == idx / m; + return (self_loop * std::numeric_limits::max()) + + (!self_loop * thrust::get<1>(tup)); + }); +} + +/** + * Connectivities specialization for pairwise distances + * @tparam value_idx + * @tparam value_t + */ +template +struct distance_graph_impl { + void run(const raft::device_resources& handle, + const value_t* X, + size_t m, + size_t n, + raft::distance::DistanceType metric, + rmm::device_uvector& indptr, + rmm::device_uvector& indices, + rmm::device_uvector& data, + int c) + { + auto stream = handle.get_stream(); + + size_t nnz = m * m; + + indices.resize(nnz, stream); + data.resize(nnz, stream); + + pairwise_distances(handle, X, m, n, metric, indptr.data(), indices.data(), data.data()); + } +}; + /** * Returns a CSR connectivities graph based on the given linkage distance. * @tparam value_idx @@ -120,7 +213,7 @@ struct distance_graph_impl -void get_distance_graph(const raft::handle_t& handle, +void get_distance_graph(raft::device_resources const& handle, const value_t* X, size_t m, size_t n, diff --git a/cpp/include/raft/cluster/detail/kmeans.cuh b/cpp/include/raft/cluster/detail/kmeans.cuh index e575849536..9632fedb9d 100644 --- a/cpp/include/raft/cluster/detail/kmeans.cuh +++ b/cpp/include/raft/cluster/detail/kmeans.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. 
+ * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,7 +31,7 @@ #include #include #include -#include +#include #include #include #include @@ -59,7 +59,7 @@ namespace detail { // Selects 'n_clusters' samples randomly from X template -void initRandom(const raft::handle_t& handle, +void initRandom(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids) @@ -85,7 +85,7 @@ void initRandom(const raft::handle_t& handle, * 5: end for */ template -void kmeansPlusPlus(const raft::handle_t& handle, +void kmeansPlusPlus(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroidsRawData, @@ -282,7 +282,7 @@ void kmeansPlusPlus(const raft::handle_t& handle, * @param[inout] workspace */ template -void update_centroids(const raft::handle_t& handle, +void update_centroids(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_vector_view sample_weights, raft::device_matrix_view centroids, @@ -356,7 +356,7 @@ void update_centroids(const raft::handle_t& handle, // TODO: Resizing is needed to use mdarray instead of rmm::device_uvector template -void kmeans_fit_main(const raft::handle_t& handle, +void kmeans_fit_main(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view weight, @@ -573,7 +573,7 @@ void kmeans_fit_main(const raft::handle_t& handle, */ template -void initScalableKMeansPlusPlus(const raft::handle_t& handle, +void initScalableKMeansPlusPlus(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroidsRawData, @@ -816,7 +816,7 @@ void initScalableKMeansPlusPlus(const raft::handle_t& handle, * @param[out] n_iter Number of iterations run. 
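The change that dominates this file and most of the diff is mechanical: every public and detail-level signature swaps `raft::handle_t` for `raft::device_resources`. A minimal sketch of the pattern, assuming only that the caller previously took a handle (the function name here is hypothetical, not part of the diff):

```cpp
#include <raft/core/device_resources.hpp>

// Before: void run_clustering(const raft::handle_t& handle);
// After:  device_resources is a drop-in replacement exposing the same accessors.
void run_clustering(raft::device_resources const& handle)
{
  cudaStream_t stream = handle.get_stream();         // stream accessor unchanged
  auto policy         = handle.get_thrust_policy();  // thrust policy unchanged
  (void)stream;
  (void)policy;
}
```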
*/ template -void kmeans_fit(handle_t const& handle, +void kmeans_fit(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -955,7 +955,7 @@ void kmeans_fit(handle_t const& handle, } template -void kmeans_fit(handle_t const& handle, +void kmeans_fit(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -980,7 +980,7 @@ void kmeans_fit(handle_t const& handle, } template -void kmeans_predict(handle_t const& handle, +void kmeans_predict(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -1088,7 +1088,7 @@ void kmeans_predict(handle_t const& handle, } template -void kmeans_predict(handle_t const& handle, +void kmeans_predict(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -1120,7 +1120,7 @@ void kmeans_predict(handle_t const& handle, } template -void kmeans_fit_predict(handle_t const& handle, +void kmeans_fit_predict(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -1147,7 +1147,7 @@ void kmeans_fit_predict(handle_t const& handle, } template -void kmeans_fit_predict(handle_t const& handle, +void kmeans_fit_predict(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -1187,7 +1187,7 @@ void kmeans_fit_predict(handle_t const& handle, * @param[out] X_new X transformed in the new space.. */ template -void kmeans_transform(const raft::handle_t& handle, +void kmeans_transform(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -1228,7 +1228,7 @@ void kmeans_transform(const raft::handle_t& handle, } template -void kmeans_transform(const raft::handle_t& handle, +void kmeans_transform(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* centroids, diff --git a/cpp/include/raft/cluster/detail/kmeans_balanced.cuh b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh new file mode 100644 index 0000000000..3d23c809c3 --- /dev/null +++ b/cpp/include/raft/cluster/detail/kmeans_balanced.cuh @@ -0,0 +1,1095 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +namespace raft::cluster::detail { + +constexpr static inline float kAdjustCentersWeight = 7.0f; + +/** + * @brief Predict labels for the dataset; floating-point types only. 
+ * + * NB: no minibatch splitting is done here, it may require large amount of temporary memory (n_rows + * * n_cluster * sizeof(MathT)). + * + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * @tparam LabelT label type + * + * @param[in] handle The raft handle. + * @param[in] params Structure containing the hyper-parameters + * @param[in] centers Pointer to the row-major matrix of cluster centers [n_clusters, dim] + * @param[in] n_clusters Number of clusters/centers + * @param[in] dim Dimensionality of the data + * @param[in] dataset Pointer to the data [n_rows, dim] + * @param[in] dataset_norm Pointer to the precomputed norm (for L2 metrics only) [n_rows] + * @param[in] n_rows Number samples in the `dataset` + * @param[out] labels Output predictions [n_rows] + * @param[inout] mr (optional) Memory resource to use for temporary allocations + */ +template +inline std::enable_if_t> predict_core( + const raft::device_resources& handle, + const kmeans_balanced_params& params, + const MathT* centers, + IdxT n_clusters, + IdxT dim, + const MathT* dataset, + const MathT* dataset_norm, + IdxT n_rows, + LabelT* labels, + rmm::mr::device_memory_resource* mr) +{ + auto stream = handle.get_stream(); + switch (params.metric) { + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2SqrtExpanded: { + auto workspace = raft::make_device_mdarray( + handle, mr, make_extents((sizeof(int)) * n_rows)); + + auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( + handle, mr, make_extents(n_rows)); + raft::KeyValuePair initial_value(0, std::numeric_limits::max()); + thrust::fill(handle.get_thrust_policy(), + minClusterAndDistance.data_handle(), + minClusterAndDistance.data_handle() + minClusterAndDistance.size(), + initial_value); + + auto centroidsNorm = + raft::make_device_mdarray(handle, mr, make_extents(n_clusters)); + raft::linalg::rowNorm( + centroidsNorm.data_handle(), centers, dim, n_clusters, raft::linalg::L2Norm, true, stream); + + raft::distance::fusedL2NNMinReduce, IdxT>( + minClusterAndDistance.data_handle(), + dataset, + centers, + dataset_norm, + centroidsNorm.data_handle(), + n_rows, + n_clusters, + dim, + (void*)workspace.data_handle(), + (params.metric == raft::distance::DistanceType::L2Expanded) ? false : true, + false, + stream); + + // todo(lsugy): use KVP + iterator in caller. + // Copy keys to output labels + thrust::transform(handle.get_thrust_policy(), + minClusterAndDistance.data_handle(), + minClusterAndDistance.data_handle() + n_rows, + labels, + raft::compose_op, raft::key_op>()); + break; + } + case raft::distance::DistanceType::InnerProduct: { + // TODO: pass buffer + rmm::device_uvector distances(n_rows * n_clusters, stream, mr); + + MathT alpha = -1.0; + MathT beta = 0.0; + + linalg::gemm(handle, + true, + false, + n_clusters, + n_rows, + dim, + &alpha, + centers, + dim, + dataset, + dim, + &beta, + distances.data(), + n_clusters, + stream); + + auto distances_const_view = raft::make_device_matrix_view( + distances.data(), n_rows, n_clusters); + auto labels_view = raft::make_device_vector_view(labels, n_rows); + raft::matrix::argmin(handle, distances_const_view, labels_view); + break; + } + default: { + RAFT_FAIL("The chosen distance metric is not supported (%d)", int(params.metric)); + } + } +} + +/** + * @brief Suggest a minibatch size for kmeans prediction. + * + * This function is used as a heuristic to split the work over a large dataset + * to reduce the size of temporary memory allocations. 
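To make the batching heuristic concrete, here is a standalone host-side sketch of the calculation described above; `suggest_minibatch_size` is a hypothetical helper, while the ~1 GiB budget and the rounding to a multiple of 64 mirror the implementation below:

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>

// Cap temporary memory at ~1 GiB: batch = budget / bytes-per-row,
// rounded up to a multiple of 64 and clamped to the dataset size.
inline int64_t suggest_minibatch_size(int64_t n_rows, std::size_t mem_per_row)
{
  auto batch = static_cast<int64_t>((std::size_t{1} << 30) / mem_per_row);
  batch      = 64 * ((batch + 63) / 64);  // same effect as div_rounding_up_safe
  return std::min(batch, n_rows);
}
```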
+ * + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * + * @param[in] n_clusters number of clusters in kmeans clustering + * @param[in] n_rows Number of samples in the dataset + * @param[in] dim Number of features in the dataset + * @param[in] metric Distance metric + * @param[in] needs_conversion Whether the data needs to be converted to MathT + * @return A suggested minibatch size and the expected memory cost per-row (in bytes) + */ +template +constexpr auto calc_minibatch_size(IdxT n_clusters, + IdxT n_rows, + IdxT dim, + raft::distance::DistanceType metric, + bool needs_conversion) -> std::tuple +{ + n_clusters = std::max(1, n_clusters); + + // Estimate memory needs per row (i.e element of the batch). + size_t mem_per_row = 0; + switch (metric) { + // fusedL2NN needs a mutex and a key-value pair for each row. + case distance::DistanceType::L2Expanded: + case distance::DistanceType::L2SqrtExpanded: { + mem_per_row += sizeof(int); + mem_per_row += sizeof(raft::KeyValuePair); + } break; + // Other metrics require storing a distance matrix. + default: { + mem_per_row += sizeof(MathT) * n_clusters; + } + } + + // If we need to convert to MathT, space required for the converted batch. + if (!needs_conversion) { mem_per_row += sizeof(MathT) * dim; } + + // Heuristic: calculate the minibatch size in order to use at most 1GB of memory. + IdxT minibatch_size = (1 << 30) / mem_per_row; + minibatch_size = 64 * div_rounding_up_safe(minibatch_size, IdxT{64}); + minibatch_size = std::min(minibatch_size, n_rows); + return std::make_tuple(minibatch_size, mem_per_row); +} + +/** + * @brief Given the data and labels, calculate cluster centers and sizes in one sweep. + * + * @note all pointers must be accessible on the device. + * + * @tparam T element type + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * @tparam LabelT label type + * @tparam CounterT counter type supported by CUDA's native atomicAdd + * @tparam MappingOpT type of the mapping operation + * + * @param[in] handle The raft handle. + * @param[inout] centers Pointer to the output [n_clusters, dim] + * @param[inout] cluster_sizes Number of rows in each cluster [n_clusters] + * @param[in] n_clusters Number of clusters/centers + * @param[in] dim Dimensionality of the data + * @param[in] dataset Pointer to the data [n_rows, dim] + * @param[in] n_rows Number of samples in the `dataset` + * @param[in] labels Output predictions [n_rows] + * @param[in] reset_counters Whether to clear the output arrays before calculating. + * When set to `false`, this function may be used to update existing centers and sizes using + * the weighted average principle. 
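The "weighted average principle" mentioned above amounts to un-averaging the running centers before folding in a new batch. A scalar sketch for one coordinate of one center (`update_center_coord` is illustrative, not part of the diff):

```cpp
// Centers are stored as means, so scale by the old size first, add the
// new coordinate sum, then divide by the combined count. The zero guard
// plays the role of raft::div_checkzero_op for empty clusters.
inline double update_center_coord(double old_mean, long long old_size,
                                  double batch_sum, long long batch_count)
{
  const long long total = old_size + batch_count;
  return total == 0 ? 0.0 : (old_mean * old_size + batch_sum) / total;
}
```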
+ * @param[in] mapping_op Mapping operation from T to MathT + * @param[inout] mr (optional) Memory resource to use for temporary allocations on the device + */ +template +void calc_centers_and_sizes(const raft::device_resources& handle, + MathT* centers, + CounterT* cluster_sizes, + IdxT n_clusters, + IdxT dim, + const T* dataset, + IdxT n_rows, + const LabelT* labels, + bool reset_counters, + MappingOpT mapping_op, + rmm::mr::device_memory_resource* mr = nullptr) +{ + auto stream = handle.get_stream(); + if (mr == nullptr) { mr = handle.get_workspace_resource(); } + + if (!reset_counters) { + raft::linalg::matrixVectorOp( + centers, centers, cluster_sizes, dim, n_clusters, true, false, raft::mul_op(), stream); + } + + rmm::device_uvector workspace(0, stream, mr); + + // If we reset the counters, we can compute directly the new sizes in cluster_sizes. + // If we don't reset, we compute in a temporary buffer and add in a separate step. + rmm::device_uvector temp_cluster_sizes(0, stream, mr); + CounterT* temp_sizes = cluster_sizes; + if (!reset_counters) { + temp_cluster_sizes.resize(n_clusters, stream); + temp_sizes = temp_cluster_sizes.data(); + } + + // Apply mapping only when the data and math types are different. + if constexpr (std::is_same_v) { + raft::linalg::reduce_rows_by_key( + dataset, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); + } else { + // todo(lsugy): use iterator from KV output of fusedL2NN + cub::TransformInputIterator mapping_itr(dataset, mapping_op); + raft::linalg::reduce_rows_by_key( + mapping_itr, dim, labels, nullptr, n_rows, dim, n_clusters, centers, stream, reset_counters); + } + + // Compute weight of each cluster + raft::cluster::detail::countLabels(handle, labels, temp_sizes, n_rows, n_clusters, workspace); + + // Add previous sizes if necessary + if (!reset_counters) { + raft::linalg::add(cluster_sizes, cluster_sizes, temp_sizes, n_clusters, stream); + } + + raft::linalg::matrixVectorOp(centers, + centers, + cluster_sizes, + dim, + n_clusters, + true, + false, + raft::div_checkzero_op(), + stream); +} + +/** Computes the L2 norm of the dataset, converting to MathT if necessary */ +template +void compute_norm(const raft::device_resources& handle, + MathT* dataset_norm, + const T* dataset, + IdxT dim, + IdxT n_rows, + MappingOpT mapping_op, + rmm::mr::device_memory_resource* mr = nullptr) +{ + common::nvtx::range fun_scope("compute_norm"); + auto stream = handle.get_stream(); + if (mr == nullptr) { mr = handle.get_workspace_resource(); } + rmm::device_uvector mapped_dataset(0, stream, mr); + + const MathT* dataset_ptr = nullptr; + + if (std::is_same_v) { + dataset_ptr = reinterpret_cast(dataset); + } else { + mapped_dataset.resize(n_rows * dim, stream); + + linalg::unaryOp(mapped_dataset.data(), dataset, n_rows * dim, mapping_op, stream); + + dataset_ptr = (const MathT*)mapped_dataset.data(); + } + + raft::linalg::rowNorm( + dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream); +} + +/** + * @brief Predict labels for the dataset. 
+ * + * @tparam T element type + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * @tparam LabelT label type + * @tparam MappingOpT type of the mapping operation + * + * @param[in] handle The raft handle + * @param[in] params Structure containing the hyper-parameters + * @param[in] centers Pointer to the row-major matrix of cluster centers [n_clusters, dim] + * @param[in] n_clusters Number of clusters/centers + * @param[in] dim Dimensionality of the data + * @param[in] dataset Pointer to the data [n_rows, dim] + * @param[in] n_rows Number samples in the `dataset` + * @param[out] labels Output predictions [n_rows] + * @param[in] mapping_op Mapping operation from T to MathT + * @param[inout] mr (optional) memory resource to use for temporary allocations + * @param[in] dataset_norm (optional) Pre-computed norms of each row in the dataset [n_rows] + */ +template +void predict(const raft::device_resources& handle, + const kmeans_balanced_params& params, + const MathT* centers, + IdxT n_clusters, + IdxT dim, + const T* dataset, + IdxT n_rows, + LabelT* labels, + MappingOpT mapping_op, + rmm::mr::device_memory_resource* mr = nullptr, + const MathT* dataset_norm = nullptr) +{ + auto stream = handle.get_stream(); + common::nvtx::range fun_scope( + "predict(%zu, %u)", static_cast(n_rows), n_clusters); + if (mr == nullptr) { mr = handle.get_workspace_resource(); } + auto [max_minibatch_size, _mem_per_row] = + calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); + rmm::device_uvector cur_dataset( + std::is_same_v ? 0 : max_minibatch_size * dim, stream, mr); + bool need_compute_norm = + dataset_norm == nullptr && (params.metric == raft::distance::DistanceType::L2Expanded || + params.metric == raft::distance::DistanceType::L2SqrtExpanded); + rmm::device_uvector cur_dataset_norm( + need_compute_norm ? max_minibatch_size : 0, stream, mr); + const MathT* dataset_norm_ptr = nullptr; + auto cur_dataset_ptr = cur_dataset.data(); + for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { + IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); + + if constexpr (std::is_same_v) { + cur_dataset_ptr = const_cast(dataset + offset * dim); + } else { + linalg::unaryOp( + cur_dataset_ptr, dataset + offset * dim, minibatch_size * dim, mapping_op, stream); + } + + // Compute the norm now if it hasn't been pre-computed. 
+ if (need_compute_norm) { + compute_norm( + handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mr); + dataset_norm_ptr = cur_dataset_norm.data(); + } else if (dataset_norm != nullptr) { + dataset_norm_ptr = dataset_norm + offset; + } + + predict_core(handle, + params, + centers, + n_clusters, + dim, + cur_dataset_ptr, + dataset_norm_ptr, + minibatch_size, + labels + offset, + mr); + } +} + +template +__global__ void __launch_bounds__((WarpSize * BlockDimY)) + adjust_centers_kernel(MathT* centers, // [n_clusters, dim] + IdxT n_clusters, + IdxT dim, + const T* dataset, // [n_rows, dim] + IdxT n_rows, + const LabelT* labels, // [n_rows] + const CounterT* cluster_sizes, // [n_clusters] + MathT threshold, + IdxT average, + IdxT seed, + IdxT* count, + MappingOpT mapping_op) +{ + IdxT l = threadIdx.y + BlockDimY * static_cast(blockIdx.y); + if (l >= n_clusters) return; + auto csize = static_cast(cluster_sizes[l]); + // skip big clusters + if (csize > static_cast(average * threshold)) return; + + // choose a "random" i that belongs to a rather large cluster + IdxT i; + IdxT j = laneId(); + if (j == 0) { + do { + auto old = atomicAdd(count, IdxT{1}); + i = (seed * (old + 1)) % n_rows; + } while (static_cast(cluster_sizes[labels[i]]) < average); + } + i = raft::shfl(i, 0); + + // Adjust the center of the selected smaller cluster to gravitate towards + // a sample from the selected larger cluster. + const IdxT li = static_cast(labels[i]); + // Weight of the current center for the weighted average. + // We dump it for anomalously small clusters, but keep constant otherwise. + const MathT wc = min(static_cast(csize), static_cast(kAdjustCentersWeight)); + // Weight for the datapoint used to shift the center. + const MathT wd = 1.0; + for (; j < dim; j += WarpSize) { + MathT val = 0; + val += wc * centers[j + dim * li]; + val += wd * mapping_op(dataset[j + dim * i]); + val /= wc + wd; + centers[j + dim * l] = val; + } +} + +/** + * @brief Adjust centers for clusters that have small number of entries. + * + * For each cluster, where the cluster size is not bigger than a threshold, the center is moved + * towards a data point that belongs to a large cluster. + * + * NB: if this function returns `true`, you should update the labels. + * + * NB: all pointers must be on the device side. + * + * @tparam T element type + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * @tparam LabelT label type + * @tparam CounterT counter type supported by CUDA's native atomicAdd + * @tparam MappingOpT type of the mapping operation + * + * @param[inout] centers cluster centers [n_clusters, dim] + * @param[in] n_clusters number of rows in `centers` + * @param[in] dim number of columns in `centers` and `dataset` + * @param[in] dataset a host pointer to the row-major data matrix [n_rows, dim] + * @param[in] n_rows number of rows in `dataset` + * @param[in] labels a host pointer to the cluster indices [n_rows] + * @param[in] cluster_sizes number of rows in each cluster [n_clusters] + * @param[in] threshold defines a criterion for adjusting a cluster + * (cluster_sizes <= average_size * threshold) + * 0 <= threshold < 1 + * @param[in] mapping_op Mapping operation from T to MathT + * @param[in] stream CUDA stream + * @param[inout] device_memory memory resource to use for temporary allocations + * + * @return whether any of the centers has been updated (and thus, `labels` need to be recalculated). 
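The kernel above moves a small cluster's center with a two-point weighted average. The same update for a single coordinate, extracted as a plain function for readability (the function name is hypothetical; the weights come straight from the kernel):

```cpp
#include <algorithm>

// wc: weight of the current center, capped at kAdjustCentersWeight (7.0f);
// wd: weight of the data point pulled in from a large cluster.
inline float adjust_center_coord(float center, float sample, float cluster_size)
{
  const float wc = std::min(cluster_size, 7.0f);  // kAdjustCentersWeight
  const float wd = 1.0f;
  return (wc * center + wd * sample) / (wc + wd);
}
```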
 */ +template +auto adjust_centers(MathT* centers, + IdxT n_clusters, + IdxT dim, + const T* dataset, + IdxT n_rows, + const LabelT* labels, + const CounterT* cluster_sizes, + MathT threshold, + MappingOpT mapping_op, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* device_memory) -> bool +{ + common::nvtx::range fun_scope( + "adjust_centers(%zu, %u)", static_cast(n_rows), n_clusters); + if (n_clusters == 0) { return false; } + constexpr static std::array kPrimes{29, 71, 113, 173, 229, 281, 349, 409, 463, 541, + 601, 659, 733, 809, 863, 941, 1013, 1069, 1151, 1223, + 1291, 1373, 1451, 1511, 1583, 1657, 1733, 1811, 1889, 1987, + 2053, 2129, 2213, 2287, 2357, 2423, 2531, 2617, 2687, 2741}; + static IdxT i = 0; + static IdxT i_primes = 0; + + bool adjusted = false; + IdxT average = n_rows / n_clusters; + IdxT ofst; + do { + i_primes = (i_primes + 1) % kPrimes.size(); + ofst = kPrimes[i_primes]; + } while (n_rows % ofst == 0); + + constexpr uint32_t kBlockDimY = 4; + const dim3 block_dim(WarpSize, kBlockDimY, 1); + const dim3 grid_dim(1, raft::ceildiv(n_clusters, static_cast(kBlockDimY)), 1); + rmm::device_scalar update_count(0, stream, device_memory); + adjust_centers_kernel<<>>(centers, + n_clusters, + dim, + dataset, + n_rows, + labels, + cluster_sizes, + threshold, + average, + ofst, + update_count.data(), + mapping_op); + adjusted = update_count.value(stream) > 0; // NB: rmm scalar performs the sync + + return adjusted; +} + +/** + * @brief Expectation-maximization-balancing combined in an iterative process. + * + * Note: `cluster_centers` is assumed to be already initialized here. + * Thus, this function can be used for fine-tuning existing clusters; + * to train from scratch, use the `build_clusters` function below. + * + * @tparam T element type + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * @tparam LabelT label type + * @tparam CounterT counter type supported by CUDA's native atomicAdd + * @tparam MappingOpT type of the mapping operation + * + * @param[in] handle The raft handle + * @param[in] params Structure containing the hyper-parameters + * @param[in] n_iters Requested number of iterations (can differ from params.n_iter!) + * @param[in] dim Dimensionality of the dataset + * @param[in] dataset Pointer to a managed row-major array [n_rows, dim] + * @param[in] dataset_norm Pointer to the precomputed norm (for L2 metrics only) [n_rows] + * @param[in] n_rows Number of rows in the dataset + * @param[in] n_clusters Requested number of clusters + * @param[inout] cluster_centers Pointer to a managed row-major array [n_clusters, dim] + * @param[out] cluster_labels Pointer to a managed row-major array [n_rows] + * @param[out] cluster_sizes Pointer to a managed row-major array [n_clusters] + * @param[in] balancing_pullback + * if the cluster centers are rebalanced on this number of iterations, + * one extra iteration is performed (this could happen several times) (default should be `2`). + * In other words, the first and then every `balancing_pullback`-th rebalancing operation adds + * one more iteration to the main cycle. + * @param[in] balancing_threshold + * the rebalancing takes place if any cluster is smaller than `avg_size * balancing_threshold` + * on a given iteration (default should be `~ 0.25`). 
+ * @param[in] mapping_op Mapping operation from T to MathT + * @param[inout] device_memory + * A memory resource for device allocations (makes sense to provide a memory pool here) + */ +template +void balancing_em_iters(const raft::device_resources& handle, + const kmeans_balanced_params& params, + uint32_t n_iters, + IdxT dim, + const T* dataset, + const MathT* dataset_norm, + IdxT n_rows, + IdxT n_clusters, + MathT* cluster_centers, + LabelT* cluster_labels, + CounterT* cluster_sizes, + uint32_t balancing_pullback, + MathT balancing_threshold, + MappingOpT mapping_op, + rmm::mr::device_memory_resource* device_memory) +{ + auto stream = handle.get_stream(); + uint32_t balancing_counter = balancing_pullback; + for (uint32_t iter = 0; iter < n_iters; iter++) { + // Balancing step - move the centers around to equalize cluster sizes + // (but not on the first iteration) + if (iter > 0 && adjust_centers(cluster_centers, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + cluster_sizes, + balancing_threshold, + mapping_op, + stream, + device_memory)) { + if (balancing_counter++ >= balancing_pullback) { + balancing_counter -= balancing_pullback; + n_iters++; + } + } + switch (params.metric) { + // For some metrics, cluster calculation and adjustment tends to favor zero center vectors. + // To avoid converging to zero, we normalize the center vectors on every iteration. + case raft::distance::DistanceType::InnerProduct: + case raft::distance::DistanceType::CosineExpanded: + case raft::distance::DistanceType::CorrelationExpanded: { + auto clusters_in_view = raft::make_device_matrix_view( + cluster_centers, n_clusters, dim); + auto clusters_out_view = raft::make_device_matrix_view( + cluster_centers, n_clusters, dim); + raft::linalg::row_normalize( + handle, clusters_in_view, clusters_out_view, raft::linalg::L2Norm); + break; + } + default: break; + } + // E: Expectation step - predict labels + predict(handle, + params, + cluster_centers, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + mapping_op, + device_memory, + dataset_norm); + // M: Maximization step - calculate optimal cluster centers + calc_centers_and_sizes(handle, + cluster_centers, + cluster_sizes, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + true, + mapping_op, + device_memory); + } +} + +/** Randomly initialize cluster centers and then call `balancing_em_iters`. */ +template +void build_clusters(const raft::device_resources& handle, + const kmeans_balanced_params& params, + IdxT dim, + const T* dataset, + IdxT n_rows, + IdxT n_clusters, + MathT* cluster_centers, + LabelT* cluster_labels, + CounterT* cluster_sizes, + MappingOpT mapping_op, + rmm::mr::device_memory_resource* device_memory, + const MathT* dataset_norm = nullptr) +{ + auto stream = handle.get_stream(); + + // "randomly" initialize labels + auto labels_view = raft::make_device_vector_view(cluster_labels, n_rows); + linalg::map_offset( + handle, + labels_view, + raft::compose_op(raft::cast_op(), raft::mod_const_op(n_clusters))); + + // update centers to match the initialized labels. 
+ calc_centers_and_sizes(handle, + cluster_centers, + cluster_sizes, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + true, + mapping_op, + device_memory); + + // run EM + balancing_em_iters(handle, + params, + params.n_iters, + dim, + dataset, + dataset_norm, + n_rows, + n_clusters, + cluster_centers, + cluster_labels, + cluster_sizes, + 2, + MathT{0.25}, + mapping_op, + device_memory); +} + +/** Calculate how many fine clusters should belong to each mesocluster. */ +template +inline auto arrange_fine_clusters(IdxT n_clusters, + IdxT n_mesoclusters, + IdxT n_rows, + const CounterT* mesocluster_sizes) +{ + std::vector fine_clusters_nums(n_mesoclusters); + std::vector fine_clusters_csum(n_mesoclusters + 1); + fine_clusters_csum[0] = 0; + + IdxT n_lists_rem = n_clusters; + IdxT n_nonempty_ms_rem = 0; + for (IdxT i = 0; i < n_mesoclusters; i++) { + n_nonempty_ms_rem += mesocluster_sizes[i] > CounterT{0} ? 1 : 0; + } + IdxT n_rows_rem = n_rows; + CounterT mesocluster_size_sum = 0; + CounterT mesocluster_size_max = 0; + IdxT fine_clusters_nums_max = 0; + for (IdxT i = 0; i < n_mesoclusters; i++) { + if (i < n_mesoclusters - 1) { + // Although the algorithm is meant to produce balanced clusters, when something + // goes wrong, we may get empty clusters (e.g. during development/debugging). + // The code below ensures a proportional arrangement of fine cluster numbers + // per mesocluster, even if some clusters are empty. + if (mesocluster_sizes[i] == 0) { + fine_clusters_nums[i] = 0; + } else { + n_nonempty_ms_rem--; + auto s = static_cast( + static_cast(n_lists_rem * mesocluster_sizes[i]) / n_rows_rem + .5); + s = std::min(s, n_lists_rem - n_nonempty_ms_rem); + fine_clusters_nums[i] = std::max(s, IdxT{1}); + } + } else { + fine_clusters_nums[i] = n_lists_rem; + } + n_lists_rem -= fine_clusters_nums[i]; + n_rows_rem -= mesocluster_sizes[i]; + mesocluster_size_max = max(mesocluster_size_max, mesocluster_sizes[i]); + mesocluster_size_sum += mesocluster_sizes[i]; + fine_clusters_nums_max = max(fine_clusters_nums_max, fine_clusters_nums[i]); + fine_clusters_csum[i + 1] = fine_clusters_csum[i] + fine_clusters_nums[i]; + } + + RAFT_EXPECTS(static_cast(mesocluster_size_sum) == n_rows, + "mesocluster sizes do not add up (%zu) to the total trainset size (%zu)", + static_cast(mesocluster_size_sum), + static_cast(n_rows)); + RAFT_EXPECTS(fine_clusters_csum[n_mesoclusters] == n_clusters, + "fine cluster numbers do not add up (%zu) to the total number of clusters (%zu)", + static_cast(fine_clusters_csum[n_mesoclusters]), + static_cast(n_clusters)); + + return std::make_tuple(static_cast(mesocluster_size_max), + fine_clusters_nums_max, + std::move(fine_clusters_nums), + std::move(fine_clusters_csum)); +} + +/** + * Given the (coarse) mesoclusters and the distribution of fine clusters within them, + * build the fine clusters. + * + * Processing one mesocluster at a time: + * 1. Copy mesocluster data into a separate buffer + * 2. Predict fine cluster labels + * 3. Refine the fine cluster centers + * + * As a result, the fine clusters are what is returned by `build_hierarchical`; + * this function returns the total number of fine clusters, which can be checked to be + * the same as the requested number of clusters. + * + * Note: this function uses at most `fine_clusters_nums_max` points per mesocluster for training; + * if one of the clusters is larger than that (as given by `mesocluster_sizes`), the extra data + * is ignored and a warning is reported. 
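`arrange_fine_clusters` above hands each mesocluster a share of the `n_clusters` fine clusters proportional to its share of the rows, with the last mesocluster absorbing the remainder so the totals match exactly. A trimmed, host-only sketch of that split (it omits the clamp that reserves one list for every remaining non-empty mesocluster):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <vector>

inline std::vector<int64_t> split_fine_clusters(int64_t n_clusters,
                                                const std::vector<int64_t>& meso_sizes)
{
  std::vector<int64_t> nums(meso_sizes.size(), 0);
  int64_t lists_rem = n_clusters;
  int64_t rows_rem  = std::accumulate(meso_sizes.begin(), meso_sizes.end(), int64_t{0});
  for (std::size_t i = 0; i + 1 < meso_sizes.size(); ++i) {
    if (meso_sizes[i] == 0) { continue; }  // empty mesoclusters get no fine clusters
    auto share = static_cast<int64_t>(
      static_cast<double>(lists_rem) * meso_sizes[i] / rows_rem + 0.5);
    nums[i] = std::max<int64_t>(share, 1);
    lists_rem -= nums[i];
    rows_rem -= meso_sizes[i];
  }
  if (!nums.empty()) { nums.back() = lists_rem; }  // remainder goes to the last one
  return nums;
}
```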
+ */ +template +auto build_fine_clusters(const raft::device_resources& handle, + const kmeans_balanced_params& params, + IdxT dim, + const T* dataset_mptr, + const MathT* dataset_norm_mptr, + const LabelT* labels_mptr, + IdxT n_rows, + const IdxT* fine_clusters_nums, + const IdxT* fine_clusters_csum, + const CounterT* mesocluster_sizes, + IdxT n_mesoclusters, + IdxT mesocluster_size_max, + IdxT fine_clusters_nums_max, + MathT* cluster_centers, + MappingOpT mapping_op, + rmm::mr::device_memory_resource* managed_memory, + rmm::mr::device_memory_resource* device_memory) -> IdxT +{ + auto stream = handle.get_stream(); + rmm::device_uvector mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory); + rmm::device_uvector mc_trainset_buf(mesocluster_size_max * dim, stream, device_memory); + rmm::device_uvector mc_trainset_norm_buf(mesocluster_size_max, stream, device_memory); + auto mc_trainset_ids = mc_trainset_ids_buf.data(); + auto mc_trainset = mc_trainset_buf.data(); + auto mc_trainset_norm = mc_trainset_norm_buf.data(); + + // label (cluster ID) of each vector + rmm::device_uvector mc_trainset_labels(mesocluster_size_max, stream, device_memory); + + rmm::device_uvector mc_trainset_ccenters( + fine_clusters_nums_max * dim, stream, device_memory); + // number of vectors in each cluster + rmm::device_uvector mc_trainset_csizes_tmp( + fine_clusters_nums_max, stream, device_memory); + + // Training clusters in each meso-cluster + IdxT n_clusters_done = 0; + for (IdxT i = 0; i < n_mesoclusters; i++) { + IdxT k = 0; + for (IdxT j = 0; j < n_rows && k < mesocluster_size_max; j++) { + if (labels_mptr[j] == LabelT(i)) { mc_trainset_ids[k++] = j; } + } + if (k != static_cast(mesocluster_sizes[i])) + RAFT_LOG_WARN("Incorrect mesocluster size at %d. %zu vs %zu", + static_cast(i), + static_cast(k), + static_cast(mesocluster_sizes[i])); + if (k == 0) { + RAFT_LOG_DEBUG("Empty cluster %d", i); + RAFT_EXPECTS(fine_clusters_nums[i] == 0, + "Number of fine clusters must be zero for the empty mesocluster (got %d)", + static_cast(fine_clusters_nums[i])); + continue; + } else { + RAFT_EXPECTS(fine_clusters_nums[i] > 0, + "Number of fine clusters must be non-zero for a non-empty mesocluster"); + } + + cub::TransformInputIterator mapping_itr(dataset_mptr, mapping_op); + raft::matrix::gather(mapping_itr, dim, n_rows, mc_trainset_ids, k, mc_trainset, stream); + if (params.metric == raft::distance::DistanceType::L2Expanded || + params.metric == raft::distance::DistanceType::L2SqrtExpanded) { + thrust::gather(handle.get_thrust_policy(), + mc_trainset_ids, + mc_trainset_ids + k, + dataset_norm_mptr, + mc_trainset_norm); + } + + build_clusters(handle, + params, + dim, + mc_trainset, + k, + fine_clusters_nums[i], + mc_trainset_ccenters.data(), + mc_trainset_labels.data(), + mc_trainset_csizes_tmp.data(), + mapping_op, + device_memory, + mc_trainset_norm); + + raft::copy(cluster_centers + (dim * fine_clusters_csum[i]), + mc_trainset_ccenters.data(), + fine_clusters_nums[i] * dim, + stream); + handle.sync_stream(stream); + n_clusters_done += fine_clusters_nums[i]; + } + return n_clusters_done; +} + +/** + * @brief Hierarchical balanced k-means + * + * @tparam T element type + * @tparam MathT type of the centroids and mapped data + * @tparam IdxT index type + * @tparam LabelT label type + * @tparam MappingOpT type of the mapping operation + * + * @param[in] handle The raft handle. 
+ * @param[in] params Structure containing the hyper-parameters + * @param dim number of columns in `centers` and `dataset` + * @param[in] dataset a device pointer to the source dataset [n_rows, dim] + * @param n_rows number of rows in the input + * @param[out] cluster_centers a device pointer to the found cluster centers [n_cluster, dim] + * @param n_cluster + * @param metric the distance type + * @param mapping_op Mapping operation from T to MathT + * @param stream + */ +template +void build_hierarchical(const raft::device_resources& handle, + const kmeans_balanced_params& params, + IdxT dim, + const T* dataset, + IdxT n_rows, + MathT* cluster_centers, + IdxT n_clusters, + MappingOpT mapping_op) +{ + auto stream = handle.get_stream(); + using LabelT = uint32_t; + + common::nvtx::range fun_scope( + "build_hierarchical(%zu, %u)", static_cast(n_rows), n_clusters); + + IdxT n_mesoclusters = std::min(n_clusters, static_cast(std::sqrt(n_clusters) + 0.5)); + RAFT_LOG_DEBUG("build_hierarchical: n_mesoclusters: %u", n_mesoclusters); + + rmm::mr::managed_memory_resource managed_memory; + rmm::mr::device_memory_resource* device_memory = handle.get_workspace_resource(); + auto [max_minibatch_size, mem_per_row] = + calc_minibatch_size(n_clusters, n_rows, dim, params.metric, std::is_same_v); + auto pool_guard = + raft::get_pool_memory_resource(device_memory, mem_per_row * size_t(max_minibatch_size)); + if (pool_guard) { + RAFT_LOG_DEBUG("build_hierarchical: using pool memory resource with initial size %zu bytes", + pool_guard->pool_size()); + } + + // Precompute the L2 norm of the dataset if relevant. + const MathT* dataset_norm = nullptr; + rmm::device_uvector dataset_norm_buf(0, stream, device_memory); + if (params.metric == raft::distance::DistanceType::L2Expanded || + params.metric == raft::distance::DistanceType::L2SqrtExpanded) { + dataset_norm_buf.resize(n_rows, stream); + for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { + IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); + compute_norm(handle, + dataset_norm_buf.data() + offset, + dataset + dim * offset, + dim, + minibatch_size, + mapping_op, + device_memory); + } + dataset_norm = (const MathT*)dataset_norm_buf.data(); + } + + /* Temporary workaround to cub::DeviceHistogram not supporting any type that isn't natively + * supported by atomicAdd: find a supported CounterT based on the IdxT. 
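The `typedef` that follows lost its angle-bracket contents in extraction; its purpose, per the comment above, is to pick a histogram counter type that `atomicAdd` handles natively. A plausible reconstruction of the intent (the exact condition is an assumption, not the verbatim source):

```cpp
#include <cstdint>
#include <type_traits>

// Counters for cub::DeviceHistogram must be natively atomicAdd-able:
// count 32-bit indices in uint32_t, anything wider in unsigned long long.
template <typename IdxT>
using CounterT =
  std::conditional_t<(sizeof(IdxT) <= sizeof(uint32_t)), uint32_t, unsigned long long>;
```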
*/ + typedef typename std::conditional_t + CounterT; + + // build coarse clusters (mesoclusters) + rmm::device_uvector mesocluster_labels_buf(n_rows, stream, &managed_memory); + rmm::device_uvector mesocluster_sizes_buf(n_mesoclusters, stream, &managed_memory); + { + rmm::device_uvector mesocluster_centers_buf(n_mesoclusters * dim, stream, device_memory); + build_clusters(handle, + params, + dim, + dataset, + n_rows, + n_mesoclusters, + mesocluster_centers_buf.data(), + mesocluster_labels_buf.data(), + mesocluster_sizes_buf.data(), + mapping_op, + device_memory, + dataset_norm); + } + + auto mesocluster_sizes = mesocluster_sizes_buf.data(); + auto mesocluster_labels = mesocluster_labels_buf.data(); + + handle.sync_stream(stream); + + // build fine clusters + auto [mesocluster_size_max, fine_clusters_nums_max, fine_clusters_nums, fine_clusters_csum] = + arrange_fine_clusters(n_clusters, n_mesoclusters, n_rows, mesocluster_sizes); + + const IdxT mesocluster_size_max_balanced = div_rounding_up_safe( + 2lu * size_t(n_rows), std::max(size_t(n_mesoclusters), 1lu)); + if (mesocluster_size_max > mesocluster_size_max_balanced) { + RAFT_LOG_WARN( + "build_hierarchical: built unbalanced mesoclusters (max_mesocluster_size == %u > %u). " + "At most %u points will be used for training within each mesocluster. " + "Consider increasing the number of training iterations `n_iters`.", + mesocluster_size_max, + mesocluster_size_max_balanced, + mesocluster_size_max_balanced); + RAFT_LOG_TRACE_VEC(mesocluster_sizes, n_mesoclusters); + RAFT_LOG_TRACE_VEC(fine_clusters_nums.data(), n_mesoclusters); + mesocluster_size_max = mesocluster_size_max_balanced; + } + + auto n_clusters_done = build_fine_clusters(handle, + params, + dim, + dataset, + dataset_norm, + mesocluster_labels, + n_rows, + fine_clusters_nums.data(), + fine_clusters_csum.data(), + mesocluster_sizes, + n_mesoclusters, + mesocluster_size_max, + fine_clusters_nums_max, + cluster_centers, + mapping_op, + &managed_memory, + device_memory); + RAFT_EXPECTS(n_clusters_done == n_clusters, "Didn't process all clusters."); + + rmm::device_uvector cluster_sizes(n_clusters, stream, device_memory); + rmm::device_uvector labels(n_rows, stream, device_memory); + + // Fine-tuning k-means for all clusters + // + // (*) Since the likely cluster centroids have been calculated hierarchically already, the number + // of iterations for fine-tuning kmeans for whole clusters should be reduced. However, there is a + // possibility that the clusters could be unbalanced here, in which case the actual number of + // iterations would be increased. + // + balancing_em_iters(handle, + params, + std::max(params.n_iters / 10, 2), + dim, + dataset, + dataset_norm, + n_rows, + n_clusters, + cluster_centers, + labels.data(), + cluster_sizes.data(), + 5, + MathT{0.2}, + mapping_op, + device_memory); +} + +} // namespace raft::cluster::detail diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 2fd33ac759..76fc22e99e 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -30,7 +30,7 @@ #include #include #include -#include +#include #include #include #include @@ -88,7 +88,7 @@ struct KeyValueIndexOp { // Computes the intensity histogram from a sequence of labels template -void countLabels(const raft::handle_t& handle, +void countLabels(raft::device_resources const& handle, SampleIteratorT labels, CounterT* count, IndexT n_samples, @@ -96,9 +96,13 @@ void countLabels(const raft::handle_t& handle, rmm::device_uvector& workspace) { cudaStream_t stream = handle.get_stream(); - IndexT num_levels = n_clusters + 1; - IndexT lower_level = 0; - IndexT upper_level = n_clusters; + + // CUB::DeviceHistogram requires a signed index type + typedef typename std::make_signed_t CubIndexT; + + CubIndexT num_levels = n_clusters + 1; + CubIndexT lower_level = 0; + CubIndexT upper_level = n_clusters; size_t temp_storage_bytes = 0; RAFT_CUDA_TRY(cub::DeviceHistogram::HistogramEven(nullptr, @@ -108,7 +112,7 @@ void countLabels(const raft::handle_t& handle, num_levels, lower_level, upper_level, - n_samples, + static_cast(n_samples), stream)); workspace.resize(temp_storage_bytes, stream); @@ -120,12 +124,12 @@ void countLabels(const raft::handle_t& handle, num_levels, lower_level, upper_level, - n_samples, + static_cast(n_samples), stream)); } template -void checkWeight(const raft::handle_t& handle, +void checkWeight(raft::device_resources const& handle, raft::device_vector_view weight, rmm::device_uvector& workspace) { @@ -183,7 +187,7 @@ template -void computeClusterCost(const raft::handle_t& handle, +void computeClusterCost(raft::device_resources const& handle, raft::device_vector_view minClusterDistance, rmm::device_uvector& workspace, raft::device_scalar_view clusterCost, @@ -218,7 +222,7 @@ void computeClusterCost(const raft::handle_t& handle, } template -void sampleCentroids(const raft::handle_t& handle, +void sampleCentroids(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_vector_view minClusterDistance, raft::device_vector_view isSampleCentroid, @@ -282,7 +286,7 @@ void sampleCentroids(const raft::handle_t& handle, // calculate pairwise distance between 'dataset[n x d]' and 'centroids[k x d]', // result will be stored in 'pairwiseDistance[n x k]' template -void pairwise_distance_kmeans(const raft::handle_t& handle, +void pairwise_distance_kmeans(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_matrix_view pairwiseDistance, @@ -310,7 +314,7 @@ void pairwise_distance_kmeans(const raft::handle_t& handle, // shuffle and randomly select 'n_samples_to_gather' from input 'in' and stores // in 'out' does not modify the input template -void shuffleAndGather(const raft::handle_t& handle, +void shuffleAndGather(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, uint32_t n_samples_to_gather, @@ -335,7 +339,7 @@ void shuffleAndGather(const raft::handle_t& handle, in.extent(1), in.extent(0), indices.data_handle(), - n_samples_to_gather, + static_cast(n_samples_to_gather), out.data_handle(), stream); } @@ -345,7 +349,7 @@ void shuffleAndGather(const raft::handle_t& handle, // is the distance between the sample and the 'centroid[key]' template void minClusterAndDistanceCompute( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view, IndexT> minClusterAndDistance, @@ -478,7 +482,7 @@ void minClusterAndDistanceCompute( } template -void 
minClusterDistanceCompute(const raft::handle_t& handle, +void minClusterDistanceCompute(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view minClusterDistance, @@ -596,7 +600,7 @@ void minClusterDistanceCompute(const raft::handle_t& handle, } template -void countSamplesInCluster(const raft::handle_t& handle, +void countSamplesInCluster(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view L2NormX, diff --git a/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh b/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh index 2746b6f657..a9d8777304 100644 --- a/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_deprecated.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ #include #include -#include +#include #include #include #include @@ -360,7 +360,7 @@ static __global__ void divideCentroids(index_type_t d, * @return Zero if successful. Otherwise non-zero. */ template -static int chooseNewCentroid(handle_t const& handle, +static int chooseNewCentroid(raft::device_resources const& handle, index_type_t n, index_type_t d, value_type_t rand, @@ -457,7 +457,7 @@ static int chooseNewCentroid(handle_t const& handle, * @return Zero if successful. Otherwise non-zero. */ template -static int initializeCentroids(handle_t const& handle, +static int initializeCentroids(raft::device_resources const& handle, index_type_t n, index_type_t d, index_type_t k, @@ -568,7 +568,7 @@ static int initializeCentroids(handle_t const& handle, * @return Zero if successful. Otherwise non-zero. */ template -static int assignCentroids(handle_t const& handle, +static int assignCentroids(raft::device_resources const& handle, index_type_t n, index_type_t d, index_type_t k, @@ -640,7 +640,7 @@ static int assignCentroids(handle_t const& handle, * @return Zero if successful. Otherwise non-zero. */ template -static int updateCentroids(handle_t const& handle, +static int updateCentroids(raft::device_resources const& handle, index_type_t n, index_type_t d, index_type_t k, @@ -783,7 +783,7 @@ static int updateCentroids(handle_t const& handle, * @return error flag. */ template -int kmeans(handle_t const& handle, +int kmeans(raft::device_resources const& handle, index_type_t n, index_type_t d, index_type_t k, @@ -950,7 +950,7 @@ int kmeans(handle_t const& handle, * @return error flag */ template -int kmeans(handle_t const& handle, +int kmeans(raft::device_resources const& handle, index_type_t n, index_type_t d, index_type_t k, diff --git a/cpp/include/raft/cluster/detail/mst.cuh b/cpp/include/raft/cluster/detail/mst.cuh index 8143d21641..46e31b672e 100644 --- a/cpp/include/raft/cluster/detail/mst.cuh +++ b/cpp/include/raft/cluster/detail/mst.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -67,7 +67,7 @@ void merge_msts(sparse::solver::Graph_COO& coo1, */ template void connect_knn_graph( - const raft::handle_t& handle, + raft::device_resources const& handle, const value_t* X, sparse::solver::Graph_COO& msf, size_t m, @@ -130,7 +130,7 @@ void connect_knn_graph( */ template void build_sorted_mst( - const raft::handle_t& handle, + raft::device_resources const& handle, const value_t* X, const value_idx* indptr, const value_idx* indices, diff --git a/cpp/include/raft/cluster/detail/single_linkage.cuh b/cpp/include/raft/cluster/detail/single_linkage.cuh index d12db85e1b..473d858827 100644 --- a/cpp/include/raft/cluster/detail/single_linkage.cuh +++ b/cpp/include/raft/cluster/detail/single_linkage.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -49,7 +49,7 @@ static const size_t EMPTY = 0; * @param[in] n_clusters number of clusters to assign data samples */ template -void single_linkage(const raft::handle_t& handle, +void single_linkage(raft::device_resources const& handle, const value_t* X, size_t m, size_t n, diff --git a/cpp/include/raft/cluster/kmeans.cuh b/cpp/include/raft/cluster/kmeans.cuh index 4b912dc966..ac9e66d5da 100644 --- a/cpp/include/raft/cluster/kmeans.cuh +++ b/cpp/include/raft/cluster/kmeans.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -44,12 +44,12 @@ using KeyValueIndexOp = detail::KeyValueIndexOp; * k-means++ algorithm. * * @code{.cpp} - * #include + * #include * #include * #include * using namespace raft::cluster; * ... - * raft::handle_t handle; + * raft::device_resources handle; * raft::cluster::KMeansParams params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); @@ -83,7 +83,7 @@ using KeyValueIndexOp = detail::KeyValueIndexOp; * @param[out] n_iter Number of iterations run. */ template -void fit(handle_t const& handle, +void fit(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -98,12 +98,12 @@ void fit(handle_t const& handle, * @brief Predict the closest cluster each sample in X belongs to. * * @code{.cpp} - * #include + * #include * #include * #include * using namespace raft::cluster; * ... - * raft::handle_t handle; + * raft::device_resources handle; * raft::cluster::KMeansParams params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); @@ -147,7 +147,7 @@ void fit(handle_t const& handle, * their closest cluster center. */ template -void predict(handle_t const& handle, +void predict(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -165,12 +165,12 @@ void predict(handle_t const& handle, * in the input. * * @code{.cpp} - * #include + * #include * #include * #include * using namespace raft::cluster; * ...
- * raft::handle_t handle; + * raft::device_resources handle; * raft::cluster::KMeansParams params; * int n_features = 15, inertia, n_iter; * auto centroids = raft::make_device_matrix(handle, params.n_clusters, n_features); @@ -210,7 +210,7 @@ void predict(handle_t const& handle, * @param[out] n_iter Number of iterations run. */ template -void fit_predict(handle_t const& handle, +void fit_predict(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -239,7 +239,7 @@ void fit_predict(handle_t const& handle, * [dim = n_samples x n_features] */ template -void transform(const raft::handle_t& handle, +void transform(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -249,7 +249,7 @@ void transform(const raft::handle_t& handle, +void transform(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* centroids, @@ -281,7 +281,7 @@ void transform(const raft::handle_t& handle, * */ template -void sample_centroids(const raft::handle_t& handle, +void sample_centroids(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_vector_view minClusterDistance, raft::device_vector_view isSampleCentroid, @@ -308,7 +308,7 @@ void sample_centroids(const raft::handle_t& handle, * */ template -void cluster_cost(const raft::handle_t& handle, +void cluster_cost(raft::device_resources const& handle, raft::device_vector_view minClusterDistance, rmm::device_uvector& workspace, raft::device_scalar_view clusterCost, @@ -334,7 +334,7 @@ void cluster_cost(const raft::handle_t& handle, * @param[out] new_centroids: output matrix of updated centroids (size n_clusters, n_features) */ template -void update_centroids(const raft::handle_t& handle, +void update_centroids(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_vector_view sample_weights, raft::device_matrix_view centroids, @@ -375,7 +375,7 @@ void update_centroids(const raft::handle_t& handle, * */ template -void min_cluster_distance(const raft::handle_t& handle, +void min_cluster_distance(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view minClusterDistance, @@ -426,7 +426,7 @@ void min_cluster_distance(const raft::handle_t& handle, */ template void min_cluster_and_distance( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, raft::device_vector_view, IndexT> minClusterAndDistance, @@ -466,7 +466,7 @@ void min_cluster_and_distance( * */ template -void shuffle_and_gather(const raft::handle_t& handle, +void shuffle_and_gather(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, uint32_t n_samples_to_gather, @@ -495,7 +495,7 @@ void shuffle_and_gather(const raft::handle_t& handle, * */ template -void count_samples_in_cluster(const raft::handle_t& handle, +void count_samples_in_cluster(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view L2NormX, @@ -525,7 +525,7 @@ void count_samples_in_cluster(const raft::handle_t& handle, * @param[in] workspace Temporary workspace buffer which can get resized */ template -void init_plus_plus(const raft::handle_t& handle, +void init_plus_plus(raft::device_resources
const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -558,7 +558,7 @@ void init_plus_plus(const raft::handle_t& handle, * @param[in] workspace Temporary workspace buffer which can get resized */ template -void fit_main(const raft::handle_t& handle, +void fit_main(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view sample_weights, @@ -605,7 +605,7 @@ namespace raft::cluster { * @param[out] n_iter Number of iterations run. */ template -void kmeans_fit(handle_t const& handle, +void kmeans_fit(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -617,7 +617,7 @@ void kmeans_fit(handle_t const& handle, } template -void kmeans_fit(handle_t const& handle, +void kmeans_fit(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -652,7 +652,7 @@ void kmeans_fit(handle_t const& handle, * their closest cluster center. */ template -void kmeans_predict(handle_t const& handle, +void kmeans_predict(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -666,7 +666,7 @@ void kmeans_predict(handle_t const& handle, } template -void kmeans_predict(handle_t const& handle, +void kmeans_predict(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -717,7 +717,7 @@ void kmeans_predict(handle_t const& handle, * @param[out] n_iter Number of iterations run. */ template -void kmeans_fit_predict(handle_t const& handle, +void kmeans_fit_predict(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, std::optional> sample_weight, @@ -731,7 +731,7 @@ void kmeans_fit_predict(handle_t const& handle, } template -void kmeans_fit_predict(handle_t const& handle, +void kmeans_fit_predict(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* sample_weight, @@ -762,7 +762,7 @@ void kmeans_fit_predict(handle_t const& handle, * [dim = n_samples x n_features] */ template -void kmeans_transform(const raft::handle_t& handle, +void kmeans_transform(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -772,7 +772,7 @@ void kmeans_transform(const raft::handle_t& handle, } template -void kmeans_transform(const raft::handle_t& handle, +void kmeans_transform(raft::device_resources const& handle, const KMeansParams& params, const DataT* X, const DataT* centroids, @@ -809,7 +809,7 @@ using KeyValueIndexOp = kmeans::KeyValueIndexOp; * */ template -void sampleCentroids(const raft::handle_t& handle, +void sampleCentroids(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_vector_view minClusterDistance, raft::device_vector_view isSampleCentroid, @@ -836,7 +836,7 @@ void sampleCentroids(const raft::handle_t& handle, * */ template -void computeClusterCost(const raft::handle_t& handle, +void computeClusterCost(raft::device_resources const& handle, raft::device_vector_view minClusterDistance, rmm::device_uvector& workspace, raft::device_scalar_view clusterCost, @@ -867,7 +867,7 @@ void computeClusterCost(const raft::handle_t& handle, * */ template -void minClusterDistanceCompute(const raft::handle_t& handle, +void minClusterDistanceCompute(raft::device_resources const& handle, const 
KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -914,7 +914,7 @@ void minClusterDistanceCompute(const raft::handle_t& handle, */ template void minClusterAndDistanceCompute( - const raft::handle_t& handle, + raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroids, @@ -952,7 +952,7 @@ void minClusterAndDistanceCompute( * */ template -void shuffleAndGather(const raft::handle_t& handle, +void shuffleAndGather(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, uint32_t n_samples_to_gather, @@ -981,7 +981,7 @@ void shuffleAndGather(const raft::handle_t& handle, * */ template -void countSamplesInCluster(const raft::handle_t& handle, +void countSamplesInCluster(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view L2NormX, @@ -1012,7 +1012,7 @@ void countSamplesInCluster(const raft::handle_t& handle, * @param[in] workspace Temporary workspace buffer which can get resized */ template -void kmeansPlusPlus(const raft::handle_t& handle, +void kmeansPlusPlus(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_matrix_view centroidsRawData, @@ -1045,7 +1045,7 @@ void kmeansPlusPlus(const raft::handle_t& handle, * @param[in] workspace Temporary workspace buffer which can get resized */ template -void kmeans_fit_main(const raft::handle_t& handle, +void kmeans_fit_main(raft::device_resources const& handle, const KMeansParams& params, raft::device_matrix_view X, raft::device_vector_view weight, diff --git a/cpp/include/raft/cluster/kmeans_balanced.cuh b/cpp/include/raft/cluster/kmeans_balanced.cuh new file mode 100644 index 0000000000..405c7a8018 --- /dev/null +++ b/cpp/include/raft/cluster/kmeans_balanced.cuh @@ -0,0 +1,365 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include +#include +#include + +namespace raft::cluster::kmeans_balanced { + +/** + * @brief Find clusters of balanced sizes with a hierarchical k-means algorithm. + * + * This variant of the k-means algorithm first clusters the dataset in mesoclusters, then clusters + * the subsets associated to each mesocluster into fine clusters, and finally runs a few k-means + * iterations over the whole dataset and with all the centroids to obtain the final clusters. + * + * Each k-means iteration applies expectation-maximization-balancing: + * - Balancing: adjust centers for clusters that have a small number of entries. If the size of a + * cluster is below a threshold, the center is moved towards a bigger cluster. 
+ * - Expectation: predict the labels (i.e. find closest cluster centroid to each point) + * - Maximization: calculate optimal centroids (i.e. find the center of gravity of each cluster) + * + * The number of mesoclusters is chosen by rounding the square root of the number of clusters. E.g. + * for 512 clusters, we would have 23 mesoclusters. The number of fine clusters per mesocluster is + * chosen proportionally to the number of points in each mesocluster. + * + * This variant of k-means uses random initialization and a fixed number of iterations, though + * iterations can be repeated if the balancing step moved the centroids. + * + * Additionally, this algorithm supports quantized datasets in arbitrary types but the core part of + * the algorithm will work with a floating-point type, hence a conversion function can be provided + * to map the data type to the math type. + * + * @code{.cpp} + * #include + * #include + * #include + * ... + * raft::handle_t handle; + * raft::cluster::kmeans_balanced_params params; + * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); + * raft::cluster::kmeans_balanced::fit(handle, params, X, centroids.view()); + * @endcode + * + * @tparam DataT Type of the input data. + * @tparam MathT Type of the centroids and mapped data. + * @tparam IndexT Type used for indexing. + * @tparam MappingOpT Type of the mapping function. + * @param[in] handle The raft resources + * @param[in] params Structure containing the hyper-parameters + * @param[in] X Training instances to cluster. The data must be in row-major format. + * [dim = n_samples x n_features] + * @param[out] centroids The generated centroids [dim = n_clusters x n_features] + * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic + * datatype. If DataT == MathT, this must be the identity. + */ +template +void fit(const raft::device_resources& handle, + kmeans_balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + MappingOpT mapping_op = raft::identity_op()) +{ + RAFT_EXPECTS(X.extent(1) == centroids.extent(1), + "Number of features in dataset and centroids are different"); + RAFT_EXPECTS(static_cast(X.extent(0)) * static_cast(X.extent(1)) <= + static_cast(std::numeric_limits::max()), + "The chosen index type cannot represent all indices for the given dataset"); + RAFT_EXPECTS(centroids.extent(0) > IndexT{0} && centroids.extent(0) <= X.extent(0), + "The number of centroids must be strictly positive and cannot exceed the number of " + "points in the training dataset."); + + detail::build_hierarchical(handle, + params, + X.extent(1), + X.data_handle(), + X.extent(0), + centroids.data_handle(), + centroids.extent(0), + mapping_op); +} + +/** + * @brief Predict the closest cluster each sample in X belongs to. + * + * @code{.cpp} + * #include + * #include + * #include + * ... + * raft::handle_t handle; + * raft::cluster::kmeans_balanced_params params; + * auto labels = raft::make_device_vector(handle, n_rows); + * raft::cluster::kmeans_balanced::predict(handle, params, X, centroids, labels); + * @endcode + * + * @tparam DataT Type of the input data. + * @tparam MathT Type of the centroids and mapped data. + * @tparam IndexT Type used for indexing. + * @tparam LabelT Type of the output labels. + * @tparam MappingOpT Type of the mapping function.
+ * @param[in] handle The raft resources + * @param[in] params Structure containing the hyper-parameters + * @param[in] X Dataset for which to infer the closest clusters. + * [dim = n_samples x n_features] + * @param[in] centroids The input centroids [dim = n_clusters x n_features] + * @param[out] labels The output labels [dim = n_samples] + * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic + * datatype. If DataT == MathT, this must be the identity. + */ +template +void predict(const raft::device_resources& handle, + kmeans_balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + MappingOpT mapping_op = raft::identity_op()) +{ + RAFT_EXPECTS(X.extent(0) == labels.extent(0), + "Number of rows in dataset and labels are different"); + RAFT_EXPECTS(X.extent(1) == centroids.extent(1), + "Number of features in dataset and centroids are different"); + RAFT_EXPECTS(static_cast(X.extent(0)) * static_cast(X.extent(1)) <= + static_cast(std::numeric_limits::max()), + "The chosen index type cannot represent all indices for the given dataset"); + RAFT_EXPECTS(static_cast(centroids.extent(0)) <= + static_cast(std::numeric_limits::max()), + "The chosen label type cannot represent all cluster labels"); + + detail::predict(handle, + params, + centroids.data_handle(), + centroids.extent(0), + X.extent(1), + X.data_handle(), + X.extent(0), + labels.data_handle(), + mapping_op); +} + +/** + * @brief Compute hierarchical balanced k-means clustering and predict cluster index for each sample + * in the input. + * + * @code{.cpp} + * #include + * #include + * #include + * ... + * raft::handle_t handle; + * raft::cluster::kmeans_balanced_params params; + * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); + * auto labels = raft::make_device_vector(handle, n_rows); + * raft::cluster::kmeans_balanced::fit_predict( + * handle, params, X, centroids.view(), labels.view()); + * @endcode + * + * @tparam DataT Type of the input data. + * @tparam MathT Type of the centroids and mapped data. + * @tparam IndexT Type used for indexing. + * @tparam LabelT Type of the output labels. + * @tparam MappingOpT Type of the mapping function. + * @param[in] handle The raft resources + * @param[in] params Structure containing the hyper-parameters + * @param[in] X Training instances to cluster. The data must be in row-major format. + * [dim = n_samples x n_features] + * @param[out] centroids The output centroids [dim = n_clusters x n_features] + * @param[out] labels The output labels [dim = n_samples] + * @param[in] mapping_op (optional) Functor to convert from the input datatype to the arithmetic + * datatype. If DataT and MathT are the same, this must be the identity. 
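For illustration only (not part of this diff): a minimal sketch of driving the fit/predict pair above with a non-identity mapping op, as the documentation describes for quantized datasets. The functor name, sizes, and variable names are invented for the example; the call signatures are the ones shown in this file.

#include <raft/cluster/kmeans_balanced.cuh>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <cstdint>

// Illustrative DataT -> MathT conversion functor (DataT = int8_t, MathT = float).
struct int8_to_float {
  __host__ __device__ float operator()(int8_t x) const { return static_cast<float>(x); }
};

void cluster_quantized(raft::device_resources const& handle,
                       const int8_t* d_data, int n_rows, int n_cols, int n_clusters)
{
  raft::cluster::kmeans_balanced_params params;  // defaults: L2Expanded metric, 20 iterations
  auto X         = raft::make_device_matrix_view<const int8_t, int>(d_data, n_rows, n_cols);
  auto centroids = raft::make_device_matrix<float, int>(handle, n_clusters, n_cols);
  auto labels    = raft::make_device_vector<int, int>(handle, n_rows);

  // DataT != MathT here, so the mapping op cannot be the identity.
  raft::cluster::kmeans_balanced::fit(handle, params, X, centroids.view(), int8_to_float{});
  auto centroids_const = raft::make_device_matrix_view<const float, int>(
    centroids.data_handle(), centroids.extent(0), centroids.extent(1));
  raft::cluster::kmeans_balanced::predict(
    handle, params, X, centroids_const, labels.view(), int8_to_float{});
}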
+ */ +template +void fit_predict(const raft::device_resources& handle, + kmeans_balanced_params const& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + MappingOpT mapping_op = raft::identity_op()) +{ + auto centroids_const = raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)); + raft::cluster::kmeans_balanced::fit(handle, params, X, centroids, mapping_op); + raft::cluster::kmeans_balanced::predict(handle, params, X, centroids_const, labels, mapping_op); +} + +namespace helpers { + +/** + * @brief Randomly initialize centers and apply expectation-maximization-balancing iterations + * + * This is essentially the non-hierarchical balanced k-means algorithm which is used by the + * hierarchical algorithm once to build the mesoclusters and once per mesocluster to build the fine + * clusters. + * + * @code{.cpp} + * #include + * #include + * #include + * ... + * raft::handle_t handle; + * raft::cluster::kmeans_balanced_params params; + * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); + * auto labels = raft::make_device_vector(handle, n_samples); + * auto sizes = raft::make_device_vector(handle, n_clusters); + * raft::cluster::kmeans_balanced::build_clusters( + * handle, params, X, centroids.view(), labels.view(), sizes.view()); + * @endcode + * + * @tparam DataT Type of the input data. + * @tparam MathT Type of the centroids and mapped data. + * @tparam IndexT Type used for indexing. + * @tparam LabelT Type of the output labels. + * @tparam CounterT Counter type supported by CUDA's native atomicAdd. + * @tparam MappingOpT Type of the mapping function. + * @param[in] handle The raft resources + * @param[in] params Structure containing the hyper-parameters + * @param[in] X Training instances to cluster. The data must be in row-major format. + * [dim = n_samples x n_features] + * @param[out] centroids The output centroids [dim = n_clusters x n_features] + * @param[out] labels The output labels [dim = n_samples] + * @param[out] cluster_sizes Size of each cluster [dim = n_clusters] + * @param[in] mapping_op (optional) Functor to convert from the input datatype to the + * arithmetic datatype. If DataT == MathT, this must be the identity. + * @param[in] X_norm (optional) Dataset's row norms [dim = n_samples] + */ +template +void build_clusters(const raft::device_resources& handle, + const kmeans_balanced_params& params, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view labels, + raft::device_vector_view cluster_sizes, + MappingOpT mapping_op = raft::identity_op(), + std::optional> X_norm = std::nullopt) +{ + RAFT_EXPECTS(X.extent(0) == labels.extent(0), + "Number of rows in dataset and labels are different"); + RAFT_EXPECTS(X.extent(1) == centroids.extent(1), + "Number of features in dataset and centroids are different"); + RAFT_EXPECTS(centroids.extent(0) == cluster_sizes.extent(0), + "Number of rows in centroids and cluster_sizes are different"); + + detail::build_clusters(handle, + params, + X.extent(1), + X.data_handle(), + X.extent(0), + centroids.extent(0), + centroids.data_handle(), + labels.data_handle(), + cluster_sizes.data_handle(), + mapping_op, + handle.get_workspace_resource(), + X_norm.has_value() ? X_norm.value().data_handle() : nullptr); +} + +/** + * @brief Given the data and labels, calculate cluster centers and sizes in one sweep.
+ * + * Let `S_i = {x_k | x_k \in X & labels[k] == i}` be the vectors in the dataset with label i. + * + * On exit, + * `centers_i = (\sum_{x \in S_i} x + w_i * center_i) / (|S_i| + w_i)`, + * where `w_i = reset_counters ? 0 : cluster_size[i]`. + * + * In other words, the updated cluster centers are a weighted average of the existing cluster + * center, and the coordinates of the points labeled with i. _This allows calling this function + * multiple times with different datasets with the same effect as if calling this function once + * on the combined dataset_. + * + * @code{.cpp} + * #include + * #include + * ... + * raft::handle_t handle; + * auto centroids = raft::make_device_matrix(handle, n_clusters, n_features); + * auto sizes = raft::make_device_vector(handle, n_clusters); + * raft::cluster::kmeans_balanced::calc_centers_and_sizes( + * handle, X, labels, centroids.view(), sizes.view(), true); + * @endcode + * + * @tparam DataT Type of the input data. + * @tparam MathT Type of the centroids and mapped data. + * @tparam IndexT Type used for indexing. + * @tparam LabelT Type of the output labels. + * @tparam CounterT Counter type supported by CUDA's native atomicAdd. + * @tparam MappingOpT Type of the mapping function. + * @param[in] handle The raft resources + * @param[in] X Dataset for which to calculate cluster centers. The data must be in + * row-major format. [dim = n_samples x n_features] + * @param[in] labels The input labels [dim = n_samples] + * @param[out] centroids The output centroids [dim = n_clusters x n_features] + * @param[out] cluster_sizes Size of each cluster [dim = n_clusters] + * @param[in] reset_counters Whether to clear the output arrays before calculating. + * When set to `false`, this function may be used to update existing + * centers and sizes using the weighted average principle. + * @param[in] mapping_op (optional) Functor to convert from the input datatype to the + * arithmetic datatype. If DataT == MathT, this must be the identity. + */ +template +void calc_centers_and_sizes(const raft::device_resources& handle, + raft::device_matrix_view X, + raft::device_vector_view labels, + raft::device_matrix_view centroids, + raft::device_vector_view cluster_sizes, + bool reset_counters = true, + MappingOpT mapping_op = raft::identity_op()) +{ + RAFT_EXPECTS(X.extent(0) == labels.extent(0), + "Number of rows in dataset and labels are different"); + RAFT_EXPECTS(X.extent(1) == centroids.extent(1), + "Number of features in dataset and centroids are different"); + RAFT_EXPECTS(centroids.extent(0) == cluster_sizes.extent(0), + "Number of rows in centroids and cluster_sizes are different"); + + detail::calc_centers_and_sizes(handle, + centroids.data_handle(), + cluster_sizes.data_handle(), + centroids.extent(0), + X.extent(1), + X.data_handle(), + X.extent(0), + labels.data_handle(), + reset_counters, + mapping_op); +} + +} // namespace helpers + +} // namespace raft::cluster::kmeans_balanced diff --git a/cpp/include/raft/cluster/kmeans_balanced_types.hpp b/cpp/include/raft/cluster/kmeans_balanced_types.hpp new file mode 100644 index 0000000000..11b77e288a --- /dev/null +++ b/cpp/include/raft/cluster/kmeans_balanced_types.hpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License.
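A sketch (again not part of the diff; element types are illustrative) of the batch-update property documented for calc_centers_and_sizes above: two calls over two halves of a dataset, the second with reset_counters=false, leave the same centers and sizes as one call over the combined data.

#include <raft/cluster/kmeans_balanced.cuh>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>

void update_in_two_batches(raft::device_resources const& handle,
                           raft::device_matrix_view<const float, int> batch0,
                           raft::device_matrix_view<const float, int> batch1,
                           raft::device_vector_view<const int, int> labels0,
                           raft::device_vector_view<const int, int> labels1,
                           raft::device_matrix_view<float, int> centroids,
                           raft::device_vector_view<int, int> sizes)
{
  namespace kb = raft::cluster::kmeans_balanced::helpers;
  // reset_counters=true: clear centroids/sizes, then accumulate the first batch.
  kb::calc_centers_and_sizes(handle, batch0, labels0, centroids, sizes, true);
  // reset_counters=false: fold the second batch into the running weighted average,
  // with w_i = cluster_size[i] playing the role described in the formula above.
  kb::calc_centers_and_sizes(handle, batch1, labels1, centroids, sizes, false);
}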
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include + +namespace raft::cluster::kmeans_balanced { + +/** + * Simple object to specify hyper-parameters to the balanced k-means algorithm. + * + * The following metrics are currently supported in k-means balanced: + * - InnerProduct + * - L2Expanded + * - L2SqrtExpanded + */ +struct kmeans_balanced_params : kmeans_base_params { + /** + * Number of training iterations + */ + uint32_t n_iters = 20; +}; + +} // namespace raft::cluster::kmeans_balanced + +namespace raft::cluster { + +using kmeans_balanced::kmeans_balanced_params; + +} // namespace raft::cluster diff --git a/cpp/include/raft/cluster/kmeans_deprecated.cuh b/cpp/include/raft/cluster/kmeans_deprecated.cuh index a4cac4cb0f..8e0861ada1 100644 --- a/cpp/include/raft/cluster/kmeans_deprecated.cuh +++ b/cpp/include/raft/cluster/kmeans_deprecated.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,7 +46,7 @@ namespace cluster { * @return error flag */ template -int kmeans(handle_t const& handle, +int kmeans(raft::device_resources const& handle, index_type_t n, index_type_t d, index_type_t k, diff --git a/cpp/include/raft/cluster/kmeans_types.hpp b/cpp/include/raft/cluster/kmeans_types.hpp index b34f3320ad..4d956ad7a0 100644 --- a/cpp/include/raft/cluster/kmeans_types.hpp +++ b/cpp/include/raft/cluster/kmeans_types.hpp @@ -18,12 +18,24 @@ #include #include +namespace raft::cluster { + +/** Base structure for parameters that are common to all k-means algorithms */ +struct kmeans_base_params { + /** + * Metric to use for distance computation. The supported metrics can vary per algorithm. + */ + raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; +}; + +} // namespace raft::cluster + namespace raft::cluster::kmeans { /** * Simple object to specify hyper-parameters to the kmeans algorithm. */ -struct KMeansParams { +struct KMeansParams : kmeans_base_params { enum InitMethod { /** @@ -77,11 +89,6 @@ struct KMeansParams { */ raft::random::RngState rng_state{0}; - /** - * Metric to use for distance computation. - */ - raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded; - /** * Number of instance k-means algorithm will be run with different seeds. */ diff --git a/cpp/include/raft/cluster/single_linkage.cuh b/cpp/include/raft/cluster/single_linkage.cuh index 2d74c364b2..91241b853b 100644 --- a/cpp/include/raft/cluster/single_linkage.cuh +++ b/cpp/include/raft/cluster/single_linkage.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
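A small configuration sketch for the parameter structs introduced above (values are hypothetical): the distance metric lives in the shared kmeans_base_params base extracted in kmeans_types.hpp, while n_iters is specific to kmeans_balanced_params.

#include <raft/cluster/kmeans_balanced_types.hpp>
#include <raft/distance/distance_types.hpp>

raft::cluster::kmeans_balanced_params make_params()
{
  raft::cluster::kmeans_balanced_params params;
  params.metric  = raft::distance::DistanceType::L2SqrtExpanded;  // inherited from kmeans_base_params
  params.n_iters = 30;                                            // balanced-k-means-specific, default 20
  return params;
}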
@@ -50,7 +50,7 @@ namespace raft::cluster { template -void single_linkage(const raft::handle_t& handle, +void single_linkage(raft::device_resources const& handle, const value_t* X, size_t m, size_t n, @@ -87,7 +87,7 @@ constexpr int DEFAULT_CONST_C = 15; control of k. The algorithm will set `k = log(n) + c` */ template -void single_linkage(const raft::handle_t& handle, +void single_linkage(raft::device_resources const& handle, raft::device_matrix_view X, raft::device_matrix_view dendrogram, raft::device_vector_view labels, diff --git a/cpp/include/raft/comms/comms_test.hpp b/cpp/include/raft/comms/comms_test.hpp index c7e5dd3ab6..c61bb32f79 100644 --- a/cpp/include/raft/comms/comms_test.hpp +++ b/cpp/include/raft/comms/comms_test.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ #include #include -#include +#include namespace raft { namespace comms { @@ -31,7 +31,7 @@ namespace comms { * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_allreduce(const handle_t& handle, int root) +bool test_collective_allreduce(raft::device_resources const& handle, int root) { return detail::test_collective_allreduce(handle, root); } @@ -43,7 +43,7 @@ bool test_collective_allreduce(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_broadcast(const handle_t& handle, int root) +bool test_collective_broadcast(raft::device_resources const& handle, int root) { return detail::test_collective_broadcast(handle, root); } @@ -55,7 +55,7 @@ bool test_collective_broadcast(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_reduce(const handle_t& handle, int root) +bool test_collective_reduce(raft::device_resources const& handle, int root) { return detail::test_collective_reduce(handle, root); } @@ -67,7 +67,7 @@ bool test_collective_reduce(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_allgather(const handle_t& handle, int root) +bool test_collective_allgather(raft::device_resources const& handle, int root) { return detail::test_collective_allgather(handle, root); } @@ -79,7 +79,7 @@ bool test_collective_allgather(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_gather(const handle_t& handle, int root) +bool test_collective_gather(raft::device_resources const& handle, int root) { return detail::test_collective_gather(handle, root); } @@ -91,7 +91,7 @@ bool test_collective_gather(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_gatherv(const handle_t& handle, int root) +bool test_collective_gatherv(raft::device_resources const& handle, int root) { return detail::test_collective_gatherv(handle, root); } @@ -103,7 +103,7 @@ bool test_collective_gatherv(const handle_t& handle, int root) * initialized comms instance. 
* @param[in] root the root rank id */ -bool test_collective_reducescatter(const handle_t& handle, int root) +bool test_collective_reducescatter(raft::device_resources const& handle, int root) { return detail::test_collective_reducescatter(handle, root); } @@ -115,7 +115,7 @@ bool test_collective_reducescatter(const handle_t& handle, int root) * initialized comms instance. * @param[in] numTrials number of iterations of all-to-all messaging to perform */ -bool test_pointToPoint_simple_send_recv(const handle_t& h, int numTrials) +bool test_pointToPoint_simple_send_recv(raft::device_resources const& h, int numTrials) { return detail::test_pointToPoint_simple_send_recv(h, numTrials); } @@ -127,7 +127,7 @@ bool test_pointToPoint_simple_send_recv(const handle_t& h, int numTrials) * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_send_or_recv(const handle_t& h, int numTrials) +bool test_pointToPoint_device_send_or_recv(raft::device_resources const& h, int numTrials) { return detail::test_pointToPoint_device_send_or_recv(h, numTrials); } @@ -139,7 +139,7 @@ bool test_pointToPoint_device_send_or_recv(const handle_t& h, int numTrials) * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_sendrecv(const handle_t& h, int numTrials) +bool test_pointToPoint_device_sendrecv(raft::device_resources const& h, int numTrials) { return detail::test_pointToPoint_device_sendrecv(h, numTrials); } @@ -151,7 +151,7 @@ bool test_pointToPoint_device_sendrecv(const handle_t& h, int numTrials) * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_multicast_sendrecv(const handle_t& h, int numTrials) +bool test_pointToPoint_device_multicast_sendrecv(raft::device_resources const& h, int numTrials) { return detail::test_pointToPoint_device_multicast_sendrecv(h, numTrials); } @@ -163,6 +163,9 @@ bool test_pointToPoint_device_multicast_sendrecv(const handle_t& h, int numTrial * initialized comms instance. * @param n_colors number of different colors to test */ -bool test_commsplit(const handle_t& h, int n_colors) { return detail::test_commsplit(h, n_colors); } +bool test_commsplit(raft::device_resources const& h, int n_colors) +{ + return detail::test_commsplit(h, n_colors); +} } // namespace comms }; // namespace raft diff --git a/cpp/include/raft/comms/detail/mpi_comms.hpp b/cpp/include/raft/comms/detail/mpi_comms.hpp index 508a9ce717..4062389eea 100644 --- a/cpp/include/raft/comms/detail/mpi_comms.hpp +++ b/cpp/include/raft/comms/detail/mpi_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,8 +28,8 @@ #include #include +#include #include -#include #include #include #include diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 33892597d8..0db27f0a45 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2020-2022, NVIDIA CORPORATION. + * Copyright (c) 2020-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,7 +20,7 @@ #include #include -#include +#include #include #include diff --git a/cpp/include/raft/comms/detail/test.hpp b/cpp/include/raft/comms/detail/test.hpp index 6ba4be3886..2b12bf2d2a 100644 --- a/cpp/include/raft/comms/detail/test.hpp +++ b/cpp/include/raft/comms/detail/test.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include @@ -38,7 +38,7 @@ namespace detail { * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_allreduce(const handle_t& handle, int root) +bool test_collective_allreduce(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -69,7 +69,7 @@ bool test_collective_allreduce(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_broadcast(const handle_t& handle, int root) +bool test_collective_broadcast(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -104,7 +104,7 @@ bool test_collective_broadcast(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_reduce(const handle_t& handle, int root) +bool test_collective_reduce(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -140,7 +140,7 @@ bool test_collective_reduce(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_allgather(const handle_t& handle, int root) +bool test_collective_allgather(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -177,7 +177,7 @@ bool test_collective_allgather(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_gather(const handle_t& handle, int root) +bool test_collective_gather(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -214,7 +214,7 @@ bool test_collective_gather(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_gatherv(const handle_t& handle, int root) +bool test_collective_gatherv(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -273,7 +273,7 @@ bool test_collective_gatherv(const handle_t& handle, int root) * initialized comms instance. * @param[in] root the root rank id */ -bool test_collective_reducescatter(const handle_t& handle, int root) +bool test_collective_reducescatter(raft::device_resources const& handle, int root) { comms_t const& communicator = handle.get_comms(); @@ -308,7 +308,7 @@ bool test_collective_reducescatter(const handle_t& handle, int root) * initialized comms instance. 
* @param[in] numTrials number of iterations of all-to-all messaging to perform */ -bool test_pointToPoint_simple_send_recv(const handle_t& h, int numTrials) +bool test_pointToPoint_simple_send_recv(raft::device_resources const& h, int numTrials) { comms_t const& communicator = h.get_comms(); int const rank = communicator.get_rank(); @@ -373,7 +373,7 @@ bool test_pointToPoint_simple_send_recv(const handle_t& h, int numTrials) * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_send_or_recv(const handle_t& h, int numTrials) +bool test_pointToPoint_device_send_or_recv(raft::device_resources const& h, int numTrials) { comms_t const& communicator = h.get_comms(); int const rank = communicator.get_rank(); @@ -415,7 +415,7 @@ bool test_pointToPoint_device_send_or_recv(const handle_t& h, int numTrials) * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_sendrecv(const handle_t& h, int numTrials) +bool test_pointToPoint_device_sendrecv(raft::device_resources const& h, int numTrials) { comms_t const& communicator = h.get_comms(); int const rank = communicator.get_rank(); @@ -461,7 +461,7 @@ bool test_pointToPoint_device_sendrecv(const handle_t& h, int numTrials) * initialized comms instance. * @param numTrials number of iterations of send or receive messaging to perform */ -bool test_pointToPoint_device_multicast_sendrecv(const handle_t& h, int numTrials) +bool test_pointToPoint_device_multicast_sendrecv(raft::device_resources const& h, int numTrials) { comms_t const& communicator = h.get_comms(); int const rank = communicator.get_rank(); @@ -520,7 +520,7 @@ bool test_pointToPoint_device_multicast_sendrecv(const handle_t& h, int numTrial * initialized comms instance. * @param n_colors number of different colors to test */ -bool test_commsplit(const handle_t& h, int n_colors) +bool test_commsplit(raft::device_resources const& h, int n_colors) { comms_t const& communicator = h.get_comms(); int const rank = communicator.get_rank(); diff --git a/cpp/include/raft/comms/mpi_comms.hpp b/cpp/include/raft/comms/mpi_comms.hpp index b3ea62efd2..9076176ea6 100644 --- a/cpp/include/raft/comms/mpi_comms.hpp +++ b/cpp/include/raft/comms/mpi_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,7 +40,7 @@ using mpi_comms = detail::mpi_comms; * #include * * MPI_Comm mpi_comm; - * raft::handle_t handle; + * raft::device_resources handle; * * initialize_mpi_comms(&handle, mpi_comm); * ... @@ -55,7 +55,7 @@ using mpi_comms = detail::mpi_comms; * comm.sync_stream(handle.get_stream()); * @endcode */ -inline void initialize_mpi_comms(handle_t* handle, MPI_Comm comm) +inline void initialize_mpi_comms(device_resources* handle, MPI_Comm comm) { auto communicator = std::make_shared( std::unique_ptr(new mpi_comms(comm, false, handle->get_stream()))); diff --git a/cpp/include/raft/comms/std_comms.hpp b/cpp/include/raft/comms/std_comms.hpp index 5e619053da..6370d4a8e6 100644 --- a/cpp/include/raft/comms/std_comms.hpp +++ b/cpp/include/raft/comms/std_comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION.
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include @@ -39,7 +39,7 @@ using std_comms = detail::std_comms; * Factory function to construct a RAFT NCCL communicator and inject it into a * RAFT handle. * - * @param handle raft::handle_t for injecting the comms + * @param handle raft::device_resources for injecting the comms * @param nccl_comm initialized NCCL communicator to use for collectives * @param num_ranks number of ranks in communicator clique * @param rank rank of local instance * @code{.cpp} * #include * * ncclComm_t nccl_comm; - * raft::handle_t handle; + * raft::device_resources handle; * * build_comms_nccl_only(&handle, nccl_comm, 5, 0); * ... * @endcode */ -void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks, int rank) +void build_comms_nccl_only(device_resources* handle, ncclComm_t nccl_comm, int num_ranks, int rank) { cudaStream_t stream = handle->get_stream(); @@ -77,7 +77,7 @@ void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks * Factory function to construct a RAFT NCCL+UCX and inject it into a RAFT * handle. * - * @param handle raft::handle_t for injecting the comms + * @param handle raft::device_resources for injecting the comms * @param nccl_comm initialized NCCL communicator to use for collectives * @param ucp_worker of local process * Note: This is purposefully left as void* so that the ucp_worker_h * @code{.cpp} * #include * * ncclComm_t nccl_comm; - * raft::handle_t handle; + * raft::device_resources handle; * ucp_worker_h ucp_worker; * ucp_ep_h *ucp_endpoints_arr; * @@ -110,8 +110,12 @@ void build_comms_nccl_only(handle_t* handle, ncclComm_t nccl_comm, int num_ranks * comm.sync_stream(handle.get_stream()); * @endcode */ -void build_comms_nccl_ucx( - handle_t* handle, ncclComm_t nccl_comm, void* ucp_worker, void* eps, int num_ranks, int rank) +void build_comms_nccl_ucx(device_resources* handle, + ncclComm_t nccl_comm, + void* ucp_worker, + void* eps, + int num_ranks, + int rank) { auto eps_sp = std::make_shared(new ucp_ep_h[num_ranks]); diff --git a/cpp/include/raft/core/comms.hpp b/cpp/include/raft/core/comms.hpp index 35ab6680de..463c17f2f6 100644 --- a/cpp/include/raft/core/comms.hpp +++ b/cpp/include/raft/core/comms.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include diff --git a/cpp/include/raft/core/detail/device_mdarray.hpp b/cpp/include/raft/core/detail/device_mdarray.hpp index ad6831794e..31dfaba70a 100644 --- a/cpp/include/raft/core/detail/device_mdarray.hpp +++ b/cpp/include/raft/core/detail/device_mdarray.hpp @@ -6,7 +6,7 @@ */ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ */ #pragma once #include -#include +#include #include #include // dynamic_extent diff --git a/cpp/include/raft/core/device_mdarray.hpp b/cpp/include/raft/core/device_mdarray.hpp index 693e50a506..03cb09eecb 100644 --- a/cpp/include/raft/core/device_mdarray.hpp +++ b/cpp/include/raft/core/device_mdarray.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -72,7 +72,7 @@ using device_matrix = device_mdarray, Layo * @tparam ElementType the data type of the matrix elements * @tparam IndexType the index type of the extents * @tparam LayoutPolicy policy for strides and layout ordering - * @param handle raft::handle_t + * @param handle raft::device_resources * @param exts dimensionality of the array (series of integers) * @return raft::device_mdarray */ @@ -80,7 +80,7 @@ template -auto make_device_mdarray(const raft::handle_t& handle, extents exts) +auto make_device_mdarray(raft::device_resources const& handle, extents exts) { using mdarray_t = device_mdarray; @@ -95,7 +95,7 @@ auto make_device_mdarray(const raft::handle_t& handle, extents -auto make_device_mdarray(const raft::handle_t& handle, +auto make_device_mdarray(raft::device_resources const& handle, rmm::mr::device_memory_resource* mr, extents exts) { @@ -130,7 +130,7 @@ auto make_device_mdarray(const raft::handle_t& handle, template -auto make_device_matrix(raft::handle_t const& handle, IndexType n_rows, IndexType n_cols) +auto make_device_matrix(raft::device_resources const& handle, IndexType n_rows, IndexType n_cols) { return make_device_mdarray( handle.get_stream(), make_extents(n_rows, n_cols)); @@ -146,7 +146,7 @@ auto make_device_matrix(raft::handle_t const& handle, IndexType n_rows, IndexTyp * @return raft::device_scalar */ template -auto make_device_scalar(raft::handle_t const& handle, ElementType const& v) +auto make_device_scalar(raft::device_resources const& handle, ElementType const& v) { scalar_extent extents; using policy_t = typename device_scalar::container_policy_type; @@ -168,7 +168,7 @@ auto make_device_scalar(raft::handle_t const& handle, ElementType const& v) template -auto make_device_vector(raft::handle_t const& handle, IndexType n) +auto make_device_vector(raft::device_resources const& handle, IndexType n) { return make_device_mdarray(handle.get_stream(), make_extents(n)); diff --git a/cpp/include/raft/core/device_mdspan.hpp b/cpp/include/raft/core/device_mdspan.hpp index f64f15d0d5..f72ae36d64 100644 --- a/cpp/include/raft/core/device_mdspan.hpp +++ b/cpp/include/raft/core/device_mdspan.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
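The factory changes in device_mdarray.hpp above only swap the first parameter type from raft::handle_t to raft::device_resources; usage is otherwise unchanged. A brief sketch (sizes are arbitrary):

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>

void make_containers()
{
  raft::device_resources handle;
  auto mat = raft::make_device_matrix<float, int>(handle, 128, 64);  // 128 x 64, row-major
  auto vec = raft::make_device_vector<int, int>(handle, 128);
  auto scl = raft::make_device_scalar<float>(handle, 1.0f);
  handle.sync_stream();  // allocations above are ordered on handle.get_stream()
}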
@@ -197,7 +197,9 @@ auto make_device_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexTy detail::alignment::value>::data_handle_type; static_assert(std::is_same>::value || std::is_same>::value); - assert(ptr == alignTo(ptr, detail::alignment::value)); + assert(reinterpret_cast(ptr) == + std::experimental::details::alignTo(reinterpret_cast(ptr), + detail::alignment::value)); data_handle_type aligned_pointer = ptr; diff --git a/cpp/include/raft/core/device_resources.hpp b/cpp/include/raft/core/device_resources.hpp new file mode 100644 index 0000000000..68c56dc9b6 --- /dev/null +++ b/cpp/include/raft/core/device_resources.hpp @@ -0,0 +1,259 @@ +/* + * Copyright (c) 2019-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __RAFT_DEVICE_RESOURCES +#define __RAFT_DEVICE_RESOURCES + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft { + +/** + * @brief Main resource container object that stores all necessary resources + * used for calling necessary device functions, cuda kernels and/or libraries + */ +class device_resources : public resources { + public: + device_resources(const device_resources& handle, + rmm::mr::device_memory_resource* workspace_resource) + : resources{handle} + { + // replace the resource factory for the workspace_resources + resources::add_resource_factory( + std::make_shared(workspace_resource)); + } + + device_resources(const device_resources& handle) : resources{handle} {} + + device_resources(device_resources&&) = delete; + device_resources& operator=(device_resources&&) = delete; + + /** + * @brief Construct a resources instance with a stream view and stream pool + * + * @param[in] stream_view the default stream (which has the default per-thread stream if + * unspecified) + * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified) + * @param[in] workspace_resource an optional resource used by some functions for allocating + * temporary workspaces. 
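A construction sketch for the three-argument device_resources constructor documented above (pool sizes are arbitrary; the RMM types used are standard):

#include <raft/core/device_resources.hpp>
#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include <rmm/mr/device/cuda_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <memory>

void build_resources()
{
  rmm::cuda_stream stream;  // owning, non-default CUDA stream
  auto stream_pool = std::make_shared<rmm::cuda_stream_pool>(4);  // 4 worker streams
  rmm::mr::cuda_memory_resource upstream;
  rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> workspace_mr{&upstream};

  raft::device_resources res(stream.view(), stream_pool, &workspace_mr);
  res.sync_stream_pool();  // synchronize all pooled streams
}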
+ */ + device_resources(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread, + std::shared_ptr stream_pool = {nullptr}, + rmm::mr::device_memory_resource* workspace_resource = nullptr) + : resources{} + { + resources::add_resource_factory(std::make_shared()); + resources::add_resource_factory( + std::make_shared(stream_view)); + resources::add_resource_factory( + std::make_shared(stream_pool)); + resources::add_resource_factory( + std::make_shared(workspace_resource)); + } + + /** Destroys all held-up resources */ + virtual ~device_resources() {} + + int get_device() const { return resource::get_device_id(*this); } + + cublasHandle_t get_cublas_handle() const { return resource::get_cublas_handle(*this); } + + cusolverDnHandle_t get_cusolver_dn_handle() const + { + return resource::get_cusolver_dn_handle(*this); + } + + cusolverSpHandle_t get_cusolver_sp_handle() const + { + return resource::get_cusolver_sp_handle(*this); + } + + cusparseHandle_t get_cusparse_handle() const { return resource::get_cusparse_handle(*this); } + + rmm::exec_policy& get_thrust_policy() const { return resource::get_thrust_policy(*this); } + + /** + * @brief synchronize a stream on the current container + */ + void sync_stream(rmm::cuda_stream_view stream) const { resource::sync_stream(*this, stream); } + + /** + * @brief synchronize main stream on the current container + */ + void sync_stream() const { resource::sync_stream(*this); } + + /** + * @brief returns main stream on the current container + */ + rmm::cuda_stream_view get_stream() const { return resource::get_cuda_stream(*this); } + + /** + * @brief returns whether stream pool was initialized on the current container + */ + + bool is_stream_pool_initialized() const { return resource::is_stream_pool_initialized(*this); } + + /** + * @brief returns stream pool on the current container + */ + const rmm::cuda_stream_pool& get_stream_pool() const + { + return resource::get_cuda_stream_pool(*this); + } + + std::size_t get_stream_pool_size() const { return resource::get_stream_pool_size(*this); } + + /** + * @brief return stream from pool + */ + rmm::cuda_stream_view get_stream_from_stream_pool() const + { + return resource::get_stream_from_stream_pool(*this); + } + + /** + * @brief return stream from pool at index + */ + rmm::cuda_stream_view get_stream_from_stream_pool(std::size_t stream_idx) const + { + return resource::get_stream_from_stream_pool(*this, stream_idx); + } + + /** + * @brief return stream from pool if size > 0, else main stream on current container + */ + rmm::cuda_stream_view get_next_usable_stream() const + { + return resource::get_next_usable_stream(*this); + } + + /** + * @brief return stream from pool at index if size > 0, else main stream on current container + * + * @param[in] stream_idx the required index of the stream in the stream pool if available + */ + rmm::cuda_stream_view get_next_usable_stream(std::size_t stream_idx) const + { + return resource::get_next_usable_stream(*this, stream_idx); + } + + /** + * @brief synchronize the stream pool on the current container + */ + void sync_stream_pool() const { return resource::sync_stream_pool(*this); } + + /** + * @brief synchronize subset of stream pool + * + * @param[in] stream_indices the indices of the streams in the stream pool to synchronize + */ + void sync_stream_pool(const std::vector stream_indices) const + { + return resource::sync_stream_pool(*this, stream_indices); + } + + /** + * @brief ask stream pool to wait on last event in main stream + */ + void 
wait_stream_pool_on_stream() const { return resource::wait_stream_pool_on_stream(*this); } + + void set_comms(std::shared_ptr communicator) + { + resource::set_comms(*this, communicator); + } + + const comms::comms_t& get_comms() const { return resource::get_comms(*this); } + + void set_subcomm(std::string key, std::shared_ptr subcomm) + { + resource::set_subcomm(*this, key, subcomm); + } + + const comms::comms_t& get_subcomm(std::string key) const + { + return resource::get_subcomm(*this, key); + } + + rmm::mr::device_memory_resource* get_workspace_resource() const + { + return resource::get_workspace_resource(*this); + } + + bool comms_initialized() const { return resource::comms_initialized(*this); } + + const cudaDeviceProp& get_device_properties() const + { + return resource::get_device_properties(*this); + } +}; // class device_resources + +/** + * @brief RAII approach to synchronizing across all streams in the current container + */ +class stream_syncer { + public: + explicit stream_syncer(const device_resources& handle) : handle_(handle) + { + handle_.sync_stream(); + } + ~stream_syncer() + { + handle_.wait_stream_pool_on_stream(); + handle_.sync_stream_pool(); + } + + stream_syncer(const stream_syncer& other) = delete; + stream_syncer& operator=(const stream_syncer& other) = delete; + + private: + const device_resources& handle_; +}; // class stream_syncer + +} // namespace raft + +#endif \ No newline at end of file diff --git a/cpp/include/raft/core/handle.hpp b/cpp/include/raft/core/handle.hpp index 08cb812bb7..02efebec9e 100644 --- a/cpp/include/raft/core/handle.hpp +++ b/cpp/include/raft/core/handle.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -14,326 +14,52 @@ * limitations under the License. */ -#ifndef __RAFT_RT_HANDLE -#define __RAFT_RT_HANDLE - #pragma once -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include - -///@todo: enable once we have migrated cuml-comms layer too -//#include - -#include - -#include -#include -#include -#include -#include -#include -#include +#include namespace raft { /** - * @brief Main handle object that stores all necessary context used for calling - * necessary cuda kernels and/or libraries + * raft::handle_t is being kept around for backwards + * compatibility and will be removed in a future version. + * + * Extending the `raft::handle_t` instead of `using` to + * minimize needed changes downstream + * (e.g. existing forward declarations, etc...) + * + * Use of `raft::resources` or `raft::device_resources` is preferred.
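Because handle_t now derives from device_resources, existing code keeps compiling while new code targets the new type. A compatibility sketch (function names here are invented for illustration):

#include <raft/core/handle.hpp>

void takes_new_type(raft::device_resources const& res) { res.sync_stream(); }

void legacy_caller()
{
  raft::handle_t handle;   // legacy type, kept for backwards compatibility
  takes_new_type(handle);  // a handle_t is-a device_resources, so this just works
}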
diff --git a/cpp/include/raft/core/handle.hpp b/cpp/include/raft/core/handle.hpp
index 08cb812bb7..02efebec9e 100644
--- a/cpp/include/raft/core/handle.hpp
+++ b/cpp/include/raft/core/handle.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2019-2022, NVIDIA CORPORATION.
+ * Copyright (c) 2019-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -14,326 +14,52 @@
  * limitations under the License.
  */
 
-#ifndef __RAFT_RT_HANDLE
-#define __RAFT_RT_HANDLE
-
 #pragma once
 
-#include <memory>
-#include <mutex>
-#include <string>
-#include <unordered_map>
-#include <utility>
-#include <vector>
-
-#include <raft/core/comms.hpp>
-#include <raft/core/cublas_macros.hpp>
-#include <raft/core/cusolver_macros.hpp>
-#include <raft/core/cusparse_macros.hpp>
-#include <raft/core/interruptible.hpp>
-
-///@todo: enable once we have migrated cuml-comms layer too
-//#include <common/cuml_comms_int.hpp>
-
-#include <raft/core/error.hpp>
-
-#include <cublas_v2.h>
-#include <cusolverDn.h>
-#include <cusolverSp.h>
-#include <cusparse.h>
-#include <rmm/cuda_stream_pool.hpp>
-#include <rmm/cuda_stream_view.hpp>
-#include <rmm/exec_policy.hpp>
+#include <raft/core/device_resources.hpp>
 
 namespace raft {
 
 /**
- * @brief Main handle object that stores all necessary context used for calling
- *        necessary cuda kernels and/or libraries
+ * raft::handle_t is being kept around for backwards
+ * compatibility and will be removed in a future version.
+ *
+ * Extending `raft::device_resources` instead of `using` to
+ * minimize needed changes downstream
+ * (e.g. existing forward declarations, etc...)
+ *
+ * Use of `raft::resources` or `raft::device_resources` is preferred.
  */
-class handle_t {
+class handle_t : public raft::device_resources {
 public:
-  // delete copy/move constructors and assignment operators as
-  // copying and moving underlying resources is unsafe
-  handle_t(const handle_t&) = delete;
-  handle_t& operator=(const handle_t&) = delete;
-  handle_t(handle_t&&)                 = delete;
+  handle_t(const handle_t& handle, rmm::mr::device_memory_resource* workspace_resource)
+    : device_resources(handle, workspace_resource)
+  {
+  }
+
+  handle_t(const handle_t& handle) : device_resources{handle} {}
+
+  handle_t(handle_t&&) = delete;
  handle_t& operator=(handle_t&&) = delete;
 
  /**
-   * @brief Construct a handle with a stream view and stream pool
+   * @brief Construct a resources instance with a stream view and stream pool
   *
   * @param[in] stream_view the default stream (which has the default per-thread stream if
   *                        unspecified)
   * @param[in] stream_pool the stream pool used (which has default of nullptr if unspecified)
+   * @param[in] workspace_resource an optional resource used by some functions for allocating
+   *            temporary workspaces.
   */
-  handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread,
-           std::shared_ptr<rmm::cuda_stream_pool> stream_pool = {nullptr})
-    : dev_id_([]() -> int {
-        int cur_dev = -1;
-        RAFT_CUDA_TRY(cudaGetDevice(&cur_dev));
-        return cur_dev;
-      }()),
-      stream_view_{stream_view},
-      stream_pool_{stream_pool}
+  handle_t(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread,
+           std::shared_ptr<rmm::cuda_stream_pool> stream_pool = {nullptr},
+           rmm::mr::device_memory_resource* workspace_resource = nullptr)
+    : device_resources{stream_view, stream_pool, workspace_resource}
  {
-    create_resources();
  }
 
  /** Destroys all held-up resources */
-  virtual ~handle_t() { destroy_resources(); }
-
-  int get_device() const { return dev_id_; }
-
-  cublasHandle_t get_cublas_handle() const
-  {
-    std::lock_guard<std::mutex> _(mutex_);
-    if (!cublas_initialized_) {
-      RAFT_CUBLAS_TRY_NO_THROW(cublasCreate(&cublas_handle_));
-      RAFT_CUBLAS_TRY_NO_THROW(cublasSetStream(cublas_handle_, stream_view_));
-      cublas_initialized_ = true;
-    }
-    return cublas_handle_;
-  }
-
-  cusolverDnHandle_t get_cusolver_dn_handle() const
-  {
-    std::lock_guard<std::mutex> _(mutex_);
-    if (!cusolver_dn_initialized_) {
-      RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnCreate(&cusolver_dn_handle_));
-      RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnSetStream(cusolver_dn_handle_, stream_view_));
-      cusolver_dn_initialized_ = true;
-    }
-    return cusolver_dn_handle_;
-  }
-
-  cusolverSpHandle_t get_cusolver_sp_handle() const
-  {
-    std::lock_guard<std::mutex> _(mutex_);
-    if (!cusolver_sp_initialized_) {
-      RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpCreate(&cusolver_sp_handle_));
-      RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpSetStream(cusolver_sp_handle_, stream_view_));
-      cusolver_sp_initialized_ = true;
-    }
-    return cusolver_sp_handle_;
-  }
-
-  cusparseHandle_t get_cusparse_handle() const
-  {
-    std::lock_guard<std::mutex> _(mutex_);
-    if (!cusparse_initialized_) {
-      RAFT_CUSPARSE_TRY_NO_THROW(cusparseCreate(&cusparse_handle_));
-      RAFT_CUSPARSE_TRY_NO_THROW(cusparseSetStream(cusparse_handle_, stream_view_));
-      cusparse_initialized_ = true;
-    }
-    return cusparse_handle_;
-  }
-
-  rmm::exec_policy& get_thrust_policy() const { return *thrust_policy_; }
-
-  /**
-   * @brief synchronize a stream on the handle
-   */
-  void sync_stream(rmm::cuda_stream_view stream) const { interruptible::synchronize(stream); }
-
-  /**
-   * @brief synchronize main stream on the handle
-   */
-  void sync_stream() const { sync_stream(stream_view_); }
-
-  /**
-   * @brief returns main stream on the handle
-   */
-  rmm::cuda_stream_view get_stream() const { return stream_view_; }
-
-  /**
-   * @brief returns whether stream pool was initialized on the handle
-   */
-
-  bool is_stream_pool_initialized() const { return stream_pool_.get() != nullptr; }
-
-  /**
-   * @brief returns stream pool on the handle
-   */
-  const rmm::cuda_stream_pool& get_stream_pool() const
-  {
-    RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized");
-    return *stream_pool_;
-  }
-
-  std::size_t get_stream_pool_size() const
-  {
-    return is_stream_pool_initialized() ? stream_pool_->get_pool_size() : 0;
-  }
-
-  /**
-   * @brief return stream from pool
-   */
-  rmm::cuda_stream_view get_stream_from_stream_pool() const
-  {
-    RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized");
-    return stream_pool_->get_stream();
-  }
-
-  /**
-   * @brief return stream from pool at index
-   */
-  rmm::cuda_stream_view get_stream_from_stream_pool(std::size_t stream_idx) const
-  {
-    RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized");
-    return stream_pool_->get_stream(stream_idx);
-  }
-
-  /**
-   * @brief return stream from pool if size > 0, else main stream on handle
-   */
-  rmm::cuda_stream_view get_next_usable_stream() const
-  {
-    return is_stream_pool_initialized() ? get_stream_from_stream_pool() : stream_view_;
-  }
-
-  /**
-   * @brief return stream from pool at index if size > 0, else main stream on handle
-   *
-   * @param[in] stream_idx the required index of the stream in the stream pool if available
-   */
-  rmm::cuda_stream_view get_next_usable_stream(std::size_t stream_idx) const
-  {
-    return is_stream_pool_initialized() ? get_stream_from_stream_pool(stream_idx) : stream_view_;
-  }
-
-  /**
-   * @brief synchronize the stream pool on the handle
-   */
-  void sync_stream_pool() const
-  {
-    for (std::size_t i = 0; i < get_stream_pool_size(); i++) {
-      sync_stream(stream_pool_->get_stream(i));
-    }
-  }
-
-  /**
-   * @brief synchronize subset of stream pool
-   *
-   * @param[in] stream_indices the indices of the streams in the stream pool to synchronize
-   */
-  void sync_stream_pool(const std::vector<std::size_t> stream_indices) const
-  {
-    RAFT_EXPECTS(stream_pool_, "ERROR: rmm::cuda_stream_pool was not initialized");
-    for (const auto& stream_index : stream_indices) {
-      sync_stream(stream_pool_->get_stream(stream_index));
-    }
-  }
-
-  /**
-   * @brief ask stream pool to wait on last event in main stream
-   */
-  void wait_stream_pool_on_stream() const
-  {
-    RAFT_CUDA_TRY(cudaEventRecord(event_, stream_view_));
-    for (std::size_t i = 0; i < get_stream_pool_size(); i++) {
-      RAFT_CUDA_TRY(cudaStreamWaitEvent(stream_pool_->get_stream(i), event_, 0));
-    }
-  }
-
-  void set_comms(std::shared_ptr<comms::comms_t> communicator) { communicator_ = communicator; }
-
-  const comms::comms_t& get_comms() const
-  {
-    RAFT_EXPECTS(this->comms_initialized(), "ERROR: Communicator was not initialized\n");
-    return *communicator_;
-  }
-
-  void set_subcomm(std::string key, std::shared_ptr<comms::comms_t> subcomm)
-  {
-    subcomms_[key] = subcomm;
-  }
-
-  const comms::comms_t& get_subcomm(std::string key) const
-  {
-    RAFT_EXPECTS(
-      subcomms_.find(key) != subcomms_.end(), "%s was not found in subcommunicators.", key.c_str());
-
-    auto subcomm = subcomms_.at(key);
-
-    RAFT_EXPECTS(nullptr != subcomm.get(), "ERROR: Subcommunicator was not initialized");
-
-    return *subcomm;
-  }
-
-  bool comms_initialized() const { return (nullptr != communicator_.get()); }
-
-  const cudaDeviceProp& get_device_properties() const
-  {
-    std::lock_guard<std::mutex> _(mutex_);
-    if (!device_prop_initialized_) {
-      RAFT_CUDA_TRY_NO_THROW(cudaGetDeviceProperties(&prop_, dev_id_));
-      device_prop_initialized_ = true;
-    }
-    return prop_;
-  }
-
- private:
-  std::shared_ptr<comms::comms_t> communicator_;
-  std::unordered_map<std::string, std::shared_ptr<comms::comms_t>> subcomms_;
-
-  const int dev_id_;
-  mutable cublasHandle_t cublas_handle_;
-  mutable bool cublas_initialized_{false};
-  mutable cusolverDnHandle_t cusolver_dn_handle_;
-  mutable bool cusolver_dn_initialized_{false};
-  mutable cusolverSpHandle_t cusolver_sp_handle_;
-  mutable bool cusolver_sp_initialized_{false};
-  mutable cusparseHandle_t cusparse_handle_;
-  mutable bool cusparse_initialized_{false};
-  std::unique_ptr<rmm::exec_policy> thrust_policy_{nullptr};
-  rmm::cuda_stream_view stream_view_{rmm::cuda_stream_per_thread};
-  std::shared_ptr<rmm::cuda_stream_pool> stream_pool_{nullptr};
-  cudaEvent_t event_;
-  mutable cudaDeviceProp prop_;
-  mutable bool device_prop_initialized_{false};
-  mutable std::mutex mutex_;
-
-  void create_resources()
-  {
-    thrust_policy_ = std::make_unique<rmm::exec_policy>(stream_view_);
-
-    RAFT_CUDA_TRY(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
-  }
-
-  void destroy_resources()
-  {
-    if (cusparse_initialized_) { RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroy(cusparse_handle_)); }
-    if (cusolver_dn_initialized_) {
-      RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnDestroy(cusolver_dn_handle_));
-    }
-    if (cusolver_sp_initialized_) {
-      RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpDestroy(cusolver_sp_handle_));
-    }
-    if (cublas_initialized_) { RAFT_CUBLAS_TRY_NO_THROW(cublasDestroy(cublas_handle_)); }
-    RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(event_));
-  }
-};  // class handle_t
-
-/**
- * @brief RAII approach to synchronizing across all streams in the handle
- */
-class stream_syncer {
- public:
-  explicit stream_syncer(const handle_t& handle) : handle_(handle) { handle_.sync_stream(); }
-  ~stream_syncer()
-  {
-    handle_.wait_stream_pool_on_stream();
-    handle_.sync_stream_pool();
-  }
-
-  stream_syncer(const stream_syncer& other) = delete;
-  stream_syncer& operator=(const stream_syncer& other) = delete;
-
- private:
-  const handle_t& handle_;
-};  // class stream_syncer
-
-}  // namespace raft
+  ~handle_t() override {}
+};
 
-#endif
\ No newline at end of file
+}  // end NAMESPACE raft
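A hypothetical downstream function showing why the shim keeps old code compiling: `handle_t` still satisfies every existing call site, but now simply inherits the behavior from `device_resources`.

```cpp
#include <raft/core/handle.hpp>

void legacy_entry_point(const raft::handle_t& handle)
{
  // Same accessors as before the refactor, now inherited from device_resources.
  rmm::cuda_stream_view stream = handle.get_stream();
  handle.sync_stream(stream);
}
```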
diff --git a/cpp/include/raft/core/host_mdspan.hpp b/cpp/include/raft/core/host_mdspan.hpp
index 1a0ea6432f..a6cdec7a84 100644
--- a/cpp/include/raft/core/host_mdspan.hpp
+++ b/cpp/include/raft/core/host_mdspan.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -144,7 +144,9 @@ auto make_host_aligned_matrix_view(ElementType* ptr, IndexType n_rows, IndexType
   static_assert(std::is_same<LayoutPolicy, layout_left_padded<ElementType>>::value ||
                 std::is_same<LayoutPolicy, layout_right_padded<ElementType>>::value);
 
-  assert(ptr == alignTo(ptr, detail::alignment::value));
+  assert(reinterpret_cast<std::uintptr_t>(ptr) ==
+         std::experimental::details::alignTo(reinterpret_cast<std::uintptr_t>(ptr),
+                                             detail::alignment::value));
 
   data_handle_type aligned_pointer = ptr;
   matrix_extent<IndexType> extents{n_rows, n_cols};
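A small sketch (names hypothetical) of the invariant the reworked assert enforces: the pointer handed to `make_host_aligned_matrix_view` must already sit on an aligned boundary, i.e. aligning its address must be a no-op.

```cpp
#include <cstdint>

// alignTo(addr, a) rounds addr up to the next multiple of a, so
// alignTo(addr, a) == addr exactly when addr is already a-aligned.
bool is_aligned(const void* ptr, std::uintptr_t alignment)
{
  auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  return addr % alignment == 0;
}
```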
diff --git a/cpp/include/raft/core/kvp.hpp b/cpp/include/raft/core/kvp.hpp
index f6ea841dc4..8d3321eb77 100644
--- a/cpp/include/raft/core/kvp.hpp
+++ b/cpp/include/raft/core/kvp.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -20,6 +20,7 @@
 
 #ifdef _RAFT_HAS_CUDA
 #include <cub/cub.cuh>
+#include <raft/util/cuda_utils.cuh>
 #endif
 
 namespace raft {
 /**
@@ -58,5 +59,27 @@ struct KeyValuePair {
   {
     return (value != b.value) || (key != b.key);
   }
+
+  RAFT_INLINE_FUNCTION bool operator<(const KeyValuePair<_Key, _Value>& b) const
+  {
+    return (key < b.key) || ((key == b.key) && value < b.value);
+  }
+
+  RAFT_INLINE_FUNCTION bool operator>(const KeyValuePair<_Key, _Value>& b) const
+  {
+    return (key > b.key) || ((key == b.key) && value > b.value);
+  }
 };
+
+#ifdef _RAFT_HAS_CUDA
+template <typename _Key, typename _Value>
+RAFT_INLINE_FUNCTION KeyValuePair<_Key, _Value> shfl_xor(const KeyValuePair<_Key, _Value>& input,
+                                                         int laneMask,
+                                                         int width     = WarpSize,
+                                                         uint32_t mask = 0xffffffffu)
+{
+  return KeyValuePair<_Key, _Value>(shfl_xor(input.key, laneMask, width, mask),
+                                    shfl_xor(input.value, laneMask, width, mask));
+}
+#endif
 }  // end namespace raft
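Sketch of what the new comparison operators provide (usable on host and device): ordering is lexicographic, key first, then value.

```cpp
#include <raft/core/kvp.hpp>

bool kvp_less_example()
{
  raft::KeyValuePair<int, float> a{1, 0.5f};
  raft::KeyValuePair<int, float> b{1, 2.0f};
  return a < b;  // true: keys tie, so 0.5f < 2.0f decides
}
```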
diff --git a/cpp/include/raft/core/math.hpp b/cpp/include/raft/core/math.hpp
new file mode 100644
index 0000000000..c5f08b84b7
--- /dev/null
+++ b/cpp/include/raft/core/math.hpp
@@ -0,0 +1,320 @@
+/*
+ * Copyright (c) 2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include <cmath>
+#include <cstdint>
+#include <type_traits>
+
+#include <raft/core/detail/macros.hpp>
+
+namespace raft {
+
+/**
+ * @defgroup Absolute Absolute value
+ * @{
+ */
+template <typename T>
+RAFT_INLINE_FUNCTION auto abs(T x)
+  -> std::enable_if_t<std::is_same_v<float, T> || std::is_same_v<double, T> ||
+                        std::is_same_v<int, T> || std::is_same_v<long int, T> ||
+                        std::is_same_v<long long int, T>,
+                      T>
+{
+#ifdef __CUDA_ARCH__
+  return ::abs(x);
+#else
+  return std::abs(x);
+#endif
+}
+template <typename T>
+constexpr RAFT_INLINE_FUNCTION auto abs(T x)
+  -> std::enable_if_t<!std::is_same_v<float, T> && !std::is_same_v<double, T> &&
+                        !std::is_same_v<int, T> && !std::is_same_v<long int, T> &&
+                        !std::is_same_v<long long int, T>,
+                      T>
+{
+  return x < T{0} ? -x : x;
+}
+/** @} */
+
+/**
+ * @defgroup Trigonometry Trigonometry functions
+ * @{
+ */
+/** Inverse cosine */
+template <typename T>
+RAFT_INLINE_FUNCTION auto acos(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::acos(x);
+#else
+  return std::acos(x);
+#endif
+}
+
+/** Inverse sine */
+template <typename T>
+RAFT_INLINE_FUNCTION auto asin(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::asin(x);
+#else
+  return std::asin(x);
+#endif
+}
+
+/** Inverse hyperbolic tangent */
+template <typename T>
+RAFT_INLINE_FUNCTION auto atanh(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::atanh(x);
+#else
+  return std::atanh(x);
+#endif
+}
+
+/** Cosine */
+template <typename T>
+RAFT_INLINE_FUNCTION auto cos(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::cos(x);
+#else
+  return std::cos(x);
+#endif
+}
+
+/** Sine */
+template <typename T>
+RAFT_INLINE_FUNCTION auto sin(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::sin(x);
+#else
+  return std::sin(x);
+#endif
+}
+
+/** Sine and cosine */
+template <typename T>
+RAFT_INLINE_FUNCTION std::enable_if_t<std::is_same_v<T, float> || std::is_same_v<T, double>> sincos(
+  const T& x, T* s, T* c)
+{
+#ifdef __CUDA_ARCH__
+  ::sincos(x, s, c);
+#else
+  *s = std::sin(x);
+  *c = std::cos(x);
+#endif
+}
+
+/** Hyperbolic tangent */
+template <typename T>
+RAFT_INLINE_FUNCTION auto tanh(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::tanh(x);
+#else
+  return std::tanh(x);
+#endif
+}
+/** @} */
+
+/**
+ * @defgroup Exponential Exponential and logarithm
+ * @{
+ */
+/** Exponential function */
+template <typename T>
+RAFT_INLINE_FUNCTION auto exp(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::exp(x);
+#else
+  return std::exp(x);
+#endif
+}
+
+/** Natural logarithm */
+template <typename T>
+RAFT_INLINE_FUNCTION auto log(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::log(x);
+#else
+  return std::log(x);
+#endif
+}
+/** @} */
+
+/**
+ * @defgroup Maximum Maximum of two or more values.
+ *
+ * The CUDA Math API has overloads for all combinations of float/double. We provide similar
+ * functionality while wrapping around std::max, which only supports arguments of the same type.
+ * However, though the CUDA Math API supports combinations of unsigned and signed integers, this is
+ * very error-prone so we do not support that and require the user to cast instead. (e.g. the max of
+ * -1 and 1u is 4294967295u...)
+ *
+ * When no overload matches, we provide a generic implementation but require that both types be the
+ * same (and that the less-than operator be defined).
+ * @{
+ */
+template <typename T1, typename T2>
+RAFT_INLINE_FUNCTION auto max(const T1& x, const T2& y)
+{
+#ifdef __CUDA_ARCH__
+  // Combinations of types supported by the CUDA Math API
+  if constexpr ((std::is_integral_v<T1> && std::is_integral_v<T2> && std::is_same_v<T1, T2>) ||
+                ((std::is_same_v<T1, float> || std::is_same_v<T1, double>)&&(
+                  std::is_same_v<T2, float> || std::is_same_v<T2, double>))) {
+    return ::max(x, y);
+  }
+  // Else, check that the types are the same and provide a generic implementation
+  else {
+    static_assert(
+      std::is_same_v<T1, T2>,
+      "No native max overload for these types. Both argument types must be the same to use "
+      "the generic max. Please cast appropriately.");
+    return (x < y) ? y : x;
+  }
+#else
+  if constexpr (std::is_same_v<T1, float> && std::is_same_v<T2, double>) {
+    return std::max(static_cast<double>(x), y);
+  } else if constexpr (std::is_same_v<T1, double> && std::is_same_v<T2, float>) {
+    return std::max(x, static_cast<double>(y));
+  } else {
+    static_assert(
+      std::is_same_v<T1, T2>,
+      "std::max requires that both argument types be the same. Please cast appropriately.");
+    return std::max(x, y);
+  }
+#endif
+}
+
+/** Many-argument overload to avoid verbose nested calls or use with variadic arguments */
+template <typename T1, typename T2, typename... Args>
+RAFT_INLINE_FUNCTION auto max(const T1& x, const T2& y, Args&&... args)
+{
+  return raft::max(x, raft::max(y, std::forward<Args>(args)...));
+}
+
+/** One-argument overload for convenience when using with variadic arguments */
+template <typename T>
+constexpr RAFT_INLINE_FUNCTION auto max(const T& x)
+{
+  return x;
+}
+/** @} */
+
+/**
+ * @defgroup Minimum Minimum of two or more values.
+ *
+ * The CUDA Math API has overloads for all combinations of float/double. We provide similar
+ * functionality while wrapping around std::min, which only supports arguments of the same type.
+ * However, though the CUDA Math API supports combinations of unsigned and signed integers, this is
+ * very error-prone so we do not support that and require the user to cast instead. (e.g. the min of
+ * -1 and 1u is 1u...)
+ *
+ * When no overload matches, we provide a generic implementation but require that both types be the
+ * same (and that the less-than operator be defined).
+ * @{
+ */
+template <typename T1, typename T2>
+RAFT_INLINE_FUNCTION auto min(const T1& x, const T2& y)
+{
+#ifdef __CUDA_ARCH__
+  // Combinations of types supported by the CUDA Math API
+  if constexpr ((std::is_integral_v<T1> && std::is_integral_v<T2> && std::is_same_v<T1, T2>) ||
+                ((std::is_same_v<T1, float> || std::is_same_v<T1, double>)&&(
+                  std::is_same_v<T2, float> || std::is_same_v<T2, double>))) {
+    return ::min(x, y);
+  }
+  // Else, check that the types are the same and provide a generic implementation
+  else {
+    static_assert(
+      std::is_same_v<T1, T2>,
+      "No native min overload for these types. Both argument types must be the same to use "
+      "the generic min. Please cast appropriately.");
+    return (y < x) ? y : x;
+  }
+#else
+  if constexpr (std::is_same_v<T1, float> && std::is_same_v<T2, double>) {
+    return std::min(static_cast<double>(x), y);
+  } else if constexpr (std::is_same_v<T1, double> && std::is_same_v<T2, float>) {
+    return std::min(x, static_cast<double>(y));
+  } else {
+    static_assert(
+      std::is_same_v<T1, T2>,
+      "std::min requires that both argument types be the same. Please cast appropriately.");
+    return std::min(x, y);
+  }
+#endif
+}
+
+/** Many-argument overload to avoid verbose nested calls or use with variadic arguments */
+template <typename T1, typename T2, typename... Args>
+RAFT_INLINE_FUNCTION auto min(const T1& x, const T2& y, Args&&... args)
+{
+  return raft::min(x, raft::min(y, std::forward<Args>(args)...));
+}
+
+/** One-argument overload for convenience when using with variadic arguments */
+template <typename T>
+constexpr RAFT_INLINE_FUNCTION auto min(const T& x)
+{
+  return x;
+}
+/** @} */
+
+/**
+ * @defgroup Power Power and root functions
+ * @{
+ */
+/** Power */
+template <typename T1, typename T2>
+RAFT_INLINE_FUNCTION auto pow(T1 x, T2 y)
+{
+#ifdef __CUDA_ARCH__
+  return ::pow(x, y);
+#else
+  return std::pow(x, y);
+#endif
+}
+
+/** Square root */
+template <typename T>
+RAFT_INLINE_FUNCTION auto sqrt(T x)
+{
+#ifdef __CUDA_ARCH__
+  return ::sqrt(x);
+#else
+  return std::sqrt(x);
+#endif
+}
+/** @} */
+
+/** Sign */
+template <typename T>
+RAFT_INLINE_FUNCTION auto sgn(T val) -> int
+{
+  return (T(0) < val) - (val < T(0));
+}
+
+}  // namespace raft
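A sketch of what the host/device dispatch buys: the same expression compiles in both CUDA and host passes, so functors and kernels can share code. `clamp01` is a hypothetical helper, not part of the diff.

```cpp
#include <raft/core/math.hpp>

template <typename T>
RAFT_INLINE_FUNCTION T clamp01(T x)
{
  // raft::max/min forward to ::max/::min under __CUDA_ARCH__ and to std::max/std::min on the host.
  return raft::min(raft::max(x, T{0}), T{1});
}
```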
diff --git a/cpp/include/raft/core/mdspan.hpp b/cpp/include/raft/core/mdspan.hpp
index 786ce69f89..f805d20064 100644
--- a/cpp/include/raft/core/mdspan.hpp
+++ b/cpp/include/raft/core/mdspan.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -304,4 +304,52 @@ RAFT_INLINE_FUNCTION auto unravel_index(Idx idx,
   }
 }
 
+/**
+ * @brief Const accessor specialization for default_accessor
+ *
+ * @tparam ElementType
+ * @param a
+ * @return std::experimental::default_accessor<std::add_const_t<ElementType>>
+ */
+template <typename ElementType>
+std::experimental::default_accessor<std::add_const_t<ElementType>> accessor_of_const(
+  std::experimental::default_accessor<ElementType> a)
+{
+  return {a};
+}
+
+/**
+ * @brief Const accessor specialization for host_device_accessor
+ *
+ * @tparam ElementType the data type of the mdspan elements
+ * @tparam MemType the type of memory where the elements are stored.
+ * @param a host_device_accessor
+ * @return host_device_accessor<std::experimental::default_accessor<std::add_const_t<ElementType>>,
+ * MemType>
+ */
+template <typename ElementType, memory_type MemType>
+host_device_accessor<std::experimental::default_accessor<std::add_const_t<ElementType>>, MemType>
+accessor_of_const(host_device_accessor<std::experimental::default_accessor<ElementType>, MemType> a)
+{
+  return {a};
+}
+
+/**
+ * @brief Create a copy of the given mdspan with const element type
+ *
+ * @tparam ElementType the const-qualified data type of the mdspan elements
+ * @tparam Extents raft::extents for dimensions
+ * @tparam Layout policy for strides and layout ordering
+ * @tparam Accessor Accessor policy for the input and output
+ * @param mds raft::mdspan object
+ * @return raft::mdspan
+ */
+template <typename ElementType, typename Extents, typename Layout, typename Accessor>
+auto make_const_mdspan(mdspan<ElementType, Extents, Layout, Accessor> mds)
+{
+  auto acc_c = accessor_of_const(mds.accessor());
+  return mdspan<std::add_const_t<ElementType>, Extents, Layout, decltype(acc_c)>{
+    mds.data_handle(), mds.mapping(), acc_c};
+}
+
 }  // namespace raft
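Hypothetical usage of `make_const_mdspan`: obtaining a read-only view over the same memory without copying data.

```cpp
#include <raft/core/host_mdspan.hpp>
#include <raft/core/mdspan.hpp>
#include <type_traits>

void const_view_example(raft::host_matrix_view<float, int> view)
{
  auto cview = raft::make_const_mdspan(view);  // element type becomes float const
  static_assert(std::is_const_v<std::remove_reference_t<decltype(cview(0, 0))>>);
}
```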
diff --git a/cpp/include/raft/core/operators.hpp b/cpp/include/raft/core/operators.hpp
index de521cc945..7acc907c49 100644
--- a/cpp/include/raft/core/operators.hpp
+++ b/cpp/include/raft/core/operators.hpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
  *
  * Licensed under the Apache License, Version 2.0 (the "License");
  * you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@
 
 #include <raft/core/detail/macros.hpp>
 #include <raft/core/kvp.hpp>
+#include <raft/core/math.hpp>
 
 namespace raft {
 
@@ -40,6 +41,14 @@ struct identity_op {
   }
 };
 
+struct void_op {
+  template <typename... UnusedArgs>
+  constexpr RAFT_INLINE_FUNCTION void operator()(UnusedArgs...) const
+  {
+    return;
+  }
+};
+
 template <typename OutT>
 struct cast_op {
   template <typename InT>
@@ -67,9 +76,9 @@ struct value_op {
 
 struct sqrt_op {
   template <typename Type, typename... UnusedArgs>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const
+  RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const
   {
-    return std::sqrt(in);
+    return raft::sqrt(in);
   }
 };
 
@@ -83,9 +92,9 @@ struct nz_op {
 
 struct abs_op {
   template <typename Type, typename... UnusedArgs>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const
+  RAFT_INLINE_FUNCTION auto operator()(const Type& in, UnusedArgs...) const
   {
-    return std::abs(in);
+    return raft::abs(in);
   }
 };
 
@@ -130,37 +139,43 @@ struct div_op {
 };
 
 struct div_checkzero_op {
-  template <typename Type>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
   {
-    if (b == Type{0}) { return Type{0}; }
+    if (b == T2{0}) { return T1{0} / T2{1}; }
     return a / b;
   }
 };
 
 struct pow_op {
   template <typename Type>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
+  RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
+  {
+    return raft::pow(a, b);
+  }
+};
+
+struct mod_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
   {
-    return std::pow(a, b);
+    return a % b;
   }
 };
 
 struct min_op {
-  template <typename Type>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
+  template <typename... Args>
+  RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const
   {
-    if (a > b) { return b; }
-    return a;
+    return raft::min(std::forward<Args>(args)...);
   }
 };
 
 struct max_op {
-  template <typename Type>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
+  template <typename... Args>
+  RAFT_INLINE_FUNCTION auto operator()(Args&&... args) const
   {
-    if (b > a) { return b; }
-    return a;
+    return raft::max(std::forward<Args>(args)...);
   }
 };
 
@@ -182,17 +197,49 @@ struct argmax_op {
   }
 };
 
+struct greater_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
+  {
+    return a > b;
+  }
+};
+
+struct less_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
+  {
+    return a < b;
+  }
+};
+
+struct greater_or_equal_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
+  {
+    return a >= b;
+  }
+};
+
+struct less_or_equal_op {
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
+  {
+    return a <= b;
+  }
+};
+
 struct equal_op {
-  template <typename Type>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
   {
     return a == b;
   }
 };
 
 struct notequal_op {
-  template <typename Type>
-  constexpr RAFT_INLINE_FUNCTION auto operator()(const Type& a, const Type& b) const
+  template <typename T1, typename T2>
+  constexpr RAFT_INLINE_FUNCTION auto operator()(const T1& a, const T2& b) const
   {
     return a != b;
   }
@@ -263,6 +310,12 @@
 template <typename Type>
 using div_checkzero_const_op = plug_const_op<Type, div_checkzero_op>;
 
 template <typename Type>
 using pow_const_op = plug_const_op<Type, pow_op>;
 
+template <typename Type>
+using mod_const_op = plug_const_op<Type, mod_op>;
+
+template <typename Type>
+using equal_const_op = plug_const_op<Type, equal_op>;
+
 /**
  * @brief Constructs an operator by composing a chain of operators.
  *
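Sketch combining the new functors with the pre-existing `plug_const_op` machinery; the constant is bound as the second operand, and the values here are illustrative.

```cpp
#include <raft/core/operators.hpp>

bool op_example()
{
  raft::mod_const_op<int> mod3(3);  // computes x % 3
  raft::less_op lt;
  return lt(mod3(7), 2);  // (7 % 3) < 2 -> true
}
```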
diff --git a/cpp/include/raft/core/resource/comms.hpp b/cpp/include/raft/core/resource/comms.hpp
new file mode 100644
index 0000000000..73de166c14
--- /dev/null
+++ b/cpp/include/raft/core/resource/comms.hpp
@@ -0,0 +1,78 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <raft/core/comms.hpp>
+#include <raft/core/resource/resource_types.hpp>
+#include <raft/core/resources.hpp>
+
+namespace raft::resource {
+class comms_resource : public resource {
+ public:
+  comms_resource(std::shared_ptr<comms::comms_t> communicator) : communicator_(communicator) {}
+
+  void* get_resource() override { return &communicator_; }
+
+  ~comms_resource() override {}
+
+ private:
+  std::shared_ptr<comms::comms_t> communicator_;
+};
+
+/**
+ * Factory that knows how to construct a
+ * specific raft::resource to populate
+ * the res_t.
+ */
+class comms_resource_factory : public resource_factory {
+ public:
+  comms_resource_factory(std::shared_ptr<comms::comms_t> communicator) : communicator_(communicator)
+  {
+  }
+
+  resource_type get_resource_type() override { return resource_type::COMMUNICATOR; }
+
+  resource* make_resource() override { return new comms_resource(communicator_); }
+
+ private:
+  std::shared_ptr<comms::comms_t> communicator_;
+};
+
+/**
+ * @defgroup resource_comms Comms resource functions
+ * @{
+ */
+
+inline bool comms_initialized(resources const& res)
+{
+  return res.has_resource_factory(resource_type::COMMUNICATOR);
+}
+
+inline comms::comms_t const& get_comms(resources const& res)
+{
+  RAFT_EXPECTS(comms_initialized(res), "ERROR: Communicator was not initialized\n");
+  return *(*res.get_resource<std::shared_ptr<comms::comms_t>>(resource_type::COMMUNICATOR));
+}
+
+inline void set_comms(resources const& res, std::shared_ptr<comms::comms_t> communicator)
+{
+  res.add_resource_factory(std::make_shared<comms_resource_factory>(communicator));
+}
+
+/**
+ * @}
+ */
+}  // namespace raft::resource
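A hedged sketch of the accessor pair above; how the communicator is built (NCCL, UCX, ...) is out of scope here, so it arrives as an already-constructed `shared_ptr`.

```cpp
#include <raft/core/device_resources.hpp>

void comms_example(raft::device_resources& res, std::shared_ptr<raft::comms::comms_t> comm)
{
  raft::resource::set_comms(res, comm);
  if (raft::resource::comms_initialized(res)) {
    (void)raft::resource::get_comms(res).get_size();  // number of ranks
  }
}
```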
diff --git a/cpp/include/raft/core/resource/cublas_handle.hpp b/cpp/include/raft/core/resource/cublas_handle.hpp
new file mode 100644
index 0000000000..710fcc7e60
--- /dev/null
+++ b/cpp/include/raft/core/resource/cublas_handle.hpp
@@ -0,0 +1,81 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cublas_v2.h>
+#include <raft/core/cublas_macros.hpp>
+#include <raft/core/resource/cuda_stream.hpp>
+#include <raft/core/resource/resource_types.hpp>
+#include <raft/core/resources.hpp>
+
+namespace raft::resource {
+
+class cublas_resource : public resource {
+ public:
+  cublas_resource(rmm::cuda_stream_view stream)
+  {
+    RAFT_CUBLAS_TRY_NO_THROW(cublasCreate(&cublas_res));
+    RAFT_CUBLAS_TRY_NO_THROW(cublasSetStream(cublas_res, stream));
+  }
+
+  ~cublas_resource() override { RAFT_CUBLAS_TRY_NO_THROW(cublasDestroy(cublas_res)); }
+
+  void* get_resource() override { return &cublas_res; }
+
+ private:
+  cublasHandle_t cublas_res;
+};
+
+/**
+ * Factory that knows how to construct a
+ * specific raft::resource to populate
+ * the res_t.
+ */
+class cublas_resource_factory : public resource_factory {
+ public:
+  cublas_resource_factory(rmm::cuda_stream_view stream) : stream_(stream) {}
+  resource_type get_resource_type() override { return resource_type::CUBLAS_HANDLE; }
+  resource* make_resource() override { return new cublas_resource(stream_); }
+
+ private:
+  rmm::cuda_stream_view stream_;
+};
+
+/**
+ * @defgroup resource_cublas cuBLAS handle resource functions
+ * @{
+ */
+
+/**
+ * Load a cublasHandle_t from raft res if it exists, otherwise
+ * add it and return it.
+ * @param[in] res the raft resources object
+ * @return cublas handle
+ */
+inline cublasHandle_t get_cublas_handle(resources const& res)
+{
+  if (!res.has_resource_factory(resource_type::CUBLAS_HANDLE)) {
+    cudaStream_t stream = get_cuda_stream(res);
+    res.add_resource_factory(std::make_shared<cublas_resource_factory>(stream));
+  }
+  return *res.get_resource<cublasHandle_t>(resource_type::CUBLAS_HANDLE);
+};
+
+/**
+ * @}
+ */
+
+}  // namespace raft::resource
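Sketch of the lazy-creation contract shared by all the library-handle resources: the first call installs the factory and builds the handle on the main stream; later calls return the cached handle.

```cpp
#include <raft/core/device_resources.hpp>

void cublas_example(raft::device_resources const& res)
{
  cublasHandle_t h1 = raft::resource::get_cublas_handle(res);
  cublasHandle_t h2 = raft::resource::get_cublas_handle(res);
  // h1 == h2: the handle is constructed once per resources instance.
  (void)h1;
  (void)h2;
}
```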
diff --git a/cpp/include/raft/core/resource/cuda_event.hpp b/cpp/include/raft/core/resource/cuda_event.hpp
new file mode 100644
index 0000000000..4859d95ee9
--- /dev/null
+++ b/cpp/include/raft/core/resource/cuda_event.hpp
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <cuda_runtime.h>
+#include <raft/core/resource/resource_types.hpp>
+#include <raft/core/resources.hpp>
+#include <raft/util/cudart_utils.hpp>
+
+namespace raft::resource {
+
+class cuda_event_resource : public resource {
+ public:
+  cuda_event_resource()
+  {
+    RAFT_CUDA_TRY_NO_THROW(cudaEventCreateWithFlags(&event_, cudaEventDisableTiming));
+  }
+  void* get_resource() override { return &event_; }
+
+  ~cuda_event_resource() override { RAFT_CUDA_TRY_NO_THROW(cudaEventDestroy(event_)); }
+
+ private:
+  cudaEvent_t event_;
+};
+}  // namespace raft::resource
diff --git a/cpp/include/raft/core/resource/cuda_stream.hpp b/cpp/include/raft/core/resource/cuda_stream.hpp
new file mode 100644
index 0000000000..318252199e
--- /dev/null
+++ b/cpp/include/raft/core/resource/cuda_stream.hpp
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <memory>
+#include <raft/core/interruptible.hpp>
+#include <raft/core/resource/resource_types.hpp>
+#include <raft/core/resources.hpp>
+#include <raft/util/cudart_utils.hpp>
+#include <rmm/cuda_stream_view.hpp>
+
+namespace raft::resource {
+class cuda_stream_resource : public resource {
+ public:
+  cuda_stream_resource(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread)
+    : stream(stream_view)
+  {
+  }
+  void* get_resource() override { return &stream; }
+
+  ~cuda_stream_resource() override {}
+
+ private:
+  rmm::cuda_stream_view stream;
+};
+
+/**
+ * Factory that knows how to construct a specific raft::resource to populate
+ * the resources instance.
+ */
+class cuda_stream_resource_factory : public resource_factory {
+ public:
+  cuda_stream_resource_factory(rmm::cuda_stream_view stream_view = rmm::cuda_stream_per_thread)
+    : stream(stream_view)
+  {
+  }
+  resource_type get_resource_type() override { return resource_type::CUDA_STREAM_VIEW; }
+  resource* make_resource() override { return new cuda_stream_resource(stream); }
+
+ private:
+  rmm::cuda_stream_view stream;
+};
+
+/**
+ * @defgroup resource_cuda_stream CUDA stream resource functions
+ * @{
+ */
+/**
+ * Load a rmm::cuda_stream_view from a resources instance (and populate it on the res
+ * if needed).
+ * @param res raft res object for managing resources
+ * @return
+ */
+inline rmm::cuda_stream_view get_cuda_stream(resources const& res)
+{
+  if (!res.has_resource_factory(resource_type::CUDA_STREAM_VIEW)) {
+    res.add_resource_factory(std::make_shared<cuda_stream_resource_factory>());
+  }
+  return *res.get_resource<rmm::cuda_stream_view>(resource_type::CUDA_STREAM_VIEW);
+};
+
+/**
+ * Load a rmm::cuda_stream_view from a resources instance (and populate it on the res
+ * if needed).
+ * @param[in] res raft resources object for managing resources
+ * @param[in] stream_view cuda stream view
+ */
+inline void set_cuda_stream(resources const& res, rmm::cuda_stream_view stream_view)
+{
+  res.add_resource_factory(std::make_shared<cuda_stream_resource_factory>(stream_view));
+};
+
+/**
+ * @brief synchronize a specific stream
+ *
+ * @param[in] res the raft resources object
+ * @param[in] stream stream to synchronize
+ */
+inline void sync_stream(const resources& res, rmm::cuda_stream_view stream)
+{
+  interruptible::synchronize(stream);
+}
+
+/**
+ * @brief synchronize main stream on the resources instance
+ */
+inline void sync_stream(const resources& res) { sync_stream(res, get_cuda_stream(res)); }
+
+/**
+ * @}
+ */
+
+}  // namespace raft::resource
\ No newline at end of file
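Sketch of overriding the main stream on a `raft::resources` instance; the owning `rmm::cuda_stream` here is illustrative.

```cpp
#include <raft/core/device_resources.hpp>
#include <rmm/cuda_stream.hpp>

void stream_example(raft::resources& res)
{
  rmm::cuda_stream owned;  // RAII stream owned by the caller
  raft::resource::set_cuda_stream(res, owned.view());
  raft::resource::sync_stream(res);  // now synchronizes the stream set above
}
```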
diff --git a/cpp/include/raft/core/resource/cuda_stream_pool.hpp b/cpp/include/raft/core/resource/cuda_stream_pool.hpp
new file mode 100644
index 0000000000..dbce75b3a5
--- /dev/null
+++ b/cpp/include/raft/core/resource/cuda_stream_pool.hpp
@@ -0,0 +1,187 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include <raft/core/resource/cuda_stream.hpp>
+#include <raft/core/resource/detail/stream_sync_event.hpp>
+#include <raft/core/resource/resource_types.hpp>
+#include <raft/core/resources.hpp>
+
+#include <rmm/cuda_stream_pool.hpp>
+#include <rmm/cuda_stream_view.hpp>
+#include <vector>
+
+namespace raft::resource {
+
+class cuda_stream_pool_resource : public resource {
+ public:
+  cuda_stream_pool_resource(std::shared_ptr<rmm::cuda_stream_pool> stream_pool)
+    : stream_pool_(stream_pool)
+  {
+  }
+
+  ~cuda_stream_pool_resource() override {}
+  void* get_resource() override { return &stream_pool_; }
+
+ private:
+  std::shared_ptr<rmm::cuda_stream_pool> stream_pool_{nullptr};
+};
+
+/**
+ * Factory that knows how to construct a
+ * specific raft::resource to populate
+ * the res_t.
+ */
+class cuda_stream_pool_resource_factory : public resource_factory {
+ public:
+  cuda_stream_pool_resource_factory(std::shared_ptr<rmm::cuda_stream_pool> stream_pool = {nullptr})
+    : stream_pool_(stream_pool)
+  {
+  }
+
+  resource_type get_resource_type() override { return resource_type::CUDA_STREAM_POOL; }
+  resource* make_resource() override { return new cuda_stream_pool_resource(stream_pool_); }
+
+ private:
+  std::shared_ptr<rmm::cuda_stream_pool> stream_pool_{nullptr};
+};
+
+inline bool is_stream_pool_initialized(const resources& res)
+{
+  return *res.get_resource<std::shared_ptr<rmm::cuda_stream_pool>>(
+           resource_type::CUDA_STREAM_POOL) != nullptr;
+}
+
+/**
+ * @defgroup resource_stream_pool CUDA Stream pool resource functions
+ * @{
+ */
+
+/**
+ * Load a cuda_stream_pool, and create a new one if it doesn't already exist
+ * @param res raft res object for managing resources
+ * @return
+ */
+inline const rmm::cuda_stream_pool& get_cuda_stream_pool(const resources& res)
+{
+  if (!res.has_resource_factory(resource_type::CUDA_STREAM_POOL)) {
+    res.add_resource_factory(std::make_shared<cuda_stream_pool_resource_factory>());
+  }
+  return *(
+    *res.get_resource<std::shared_ptr<rmm::cuda_stream_pool>>(resource_type::CUDA_STREAM_POOL));
+};
+
+/**
+ * Explicitly set a stream pool on the current res. Note that this will overwrite
+ * an existing stream pool on the res.
+ * @param res
+ * @param stream_pool
+ */
+inline void set_cuda_stream_pool(const resources& res,
+                                 std::shared_ptr<rmm::cuda_stream_pool> stream_pool)
+{
+  res.add_resource_factory(std::make_shared<cuda_stream_pool_resource_factory>(stream_pool));
+};
+
+inline std::size_t get_stream_pool_size(const resources& res)
+{
+  return is_stream_pool_initialized(res) ? get_cuda_stream_pool(res).get_pool_size() : 0;
+}
+
+/**
+ * @brief return stream from pool
+ */
+inline rmm::cuda_stream_view get_stream_from_stream_pool(const resources& res)
+{
+  RAFT_EXPECTS(is_stream_pool_initialized(res), "ERROR: rmm::cuda_stream_pool was not initialized");
+  return get_cuda_stream_pool(res).get_stream();
+}
+
+/**
+ * @brief return stream from pool at index
+ */
+inline rmm::cuda_stream_view get_stream_from_stream_pool(const resources& res,
+                                                         std::size_t stream_idx)
+{
+  RAFT_EXPECTS(is_stream_pool_initialized(res), "ERROR: rmm::cuda_stream_pool was not initialized");
+  return get_cuda_stream_pool(res).get_stream(stream_idx);
+}
+
+/**
+ * @brief return stream from pool if size > 0, else main stream on res
+ */
+inline rmm::cuda_stream_view get_next_usable_stream(const resources& res)
+{
+  return is_stream_pool_initialized(res) ? get_stream_from_stream_pool(res) : get_cuda_stream(res);
+}
+
+/**
+ * @brief return stream from pool at index if size > 0, else main stream on res
+ *
+ * @param[in] res the raft resources object
+ * @param[in] stream_idx the required index of the stream in the stream pool if available
+ */
+inline rmm::cuda_stream_view get_next_usable_stream(const resources& res, std::size_t stream_idx)
+{
+  return is_stream_pool_initialized(res) ? get_stream_from_stream_pool(res, stream_idx)
+                                         : get_cuda_stream(res);
+}
+
+/**
+ * @brief synchronize the stream pool on the res
+ *
+ * @param[in] res the raft resources object
+ */
+inline void sync_stream_pool(const resources& res)
+{
+  for (std::size_t i = 0; i < get_stream_pool_size(res); i++) {
+    sync_stream(res, get_cuda_stream_pool(res).get_stream(i));
+  }
+}
+
+/**
+ * @brief synchronize subset of stream pool
+ *
+ * @param[in] res the raft resources object
+ * @param[in] stream_indices the indices of the streams in the stream pool to synchronize
+ */
+inline void sync_stream_pool(const resources& res, const std::vector<std::size_t> stream_indices)
+{
+  RAFT_EXPECTS(is_stream_pool_initialized(res), "ERROR: rmm::cuda_stream_pool was not initialized");
+  for (const auto& stream_index : stream_indices) {
+    sync_stream(res, get_cuda_stream_pool(res).get_stream(stream_index));
+  }
+}
+
+/**
+ * @brief ask stream pool to wait on last event in main stream
+ *
+ * @param[in] res the raft resources object
+ */
+inline void wait_stream_pool_on_stream(const resources& res)
+{
+  cudaEvent_t event = detail::get_cuda_stream_sync_event(res);
+  RAFT_CUDA_TRY(cudaEventRecord(event, get_cuda_stream(res)));
+  for (std::size_t i = 0; i < get_stream_pool_size(res); i++) {
+    RAFT_CUDA_TRY(cudaStreamWaitEvent(get_cuda_stream_pool(res).get_stream(i), event, 0));
+  }
+}
+
+/**
+ * @}
+ */
+
+}  // namespace raft::resource
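Sketch of the fork-join pattern these helpers exist for: pool streams first wait on the main stream's recorded event, independent work is scattered across them, then the pool is joined.

```cpp
#include <raft/core/device_resources.hpp>

void pool_example(raft::device_resources const& res)
{
  raft::resource::wait_stream_pool_on_stream(res);  // fork: pool waits on main stream
  for (std::size_t i = 0; i < raft::resource::get_stream_pool_size(res); ++i) {
    rmm::cuda_stream_view s = raft::resource::get_stream_from_stream_pool(res, i);
    (void)s;  // launch independent work on s here
  }
  raft::resource::sync_stream_pool(res);  // join
}
```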
diff --git a/cpp/include/raft/core/resource/cusolver_dn_handle.hpp b/cpp/include/raft/core/resource/cusolver_dn_handle.hpp
new file mode 100644
index 0000000000..7a33e2dd2a
--- /dev/null
+++ b/cpp/include/raft/core/resource/cusolver_dn_handle.hpp
@@ -0,0 +1,85 @@
+/*
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#pragma once
+
+#include "cuda_stream.hpp"
+#include <cusolverDn.h>
+#include <raft/core/cusolver_macros.hpp>
+#include <raft/core/resource/resource_types.hpp>
+#include <raft/core/resources.hpp>
+#include <rmm/cuda_stream_view.hpp>
+
+namespace raft::resource {
+
+/**
+ *
+ */
+class cusolver_dn_resource : public resource {
+ public:
+  cusolver_dn_resource(rmm::cuda_stream_view stream)
+  {
+    RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnCreate(&cusolver_res));
+    RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnSetStream(cusolver_res, stream));
+  }
+
+  void* get_resource() override { return &cusolver_res; }
+
+  ~cusolver_dn_resource() override { RAFT_CUSOLVER_TRY_NO_THROW(cusolverDnDestroy(cusolver_res)); }
+
+ private:
+  cusolverDnHandle_t cusolver_res;
+};
+
+/**
+ * @defgroup resource_cusolver_dn cuSolver DN handle resource functions
+ * @{
+ */
+
+/**
+ * Factory that knows how to construct a
+ * specific raft::resource to populate
+ * the res_t.
+ */ +class cusolver_dn_resource_factory : public resource_factory { + public: + cusolver_dn_resource_factory(rmm::cuda_stream_view stream) : stream_(stream) {} + resource_type get_resource_type() override { return resource_type::CUSOLVER_DN_HANDLE; } + resource* make_resource() override { return new cusolver_dn_resource(stream_); } + + private: + rmm::cuda_stream_view stream_; +}; + +/** + * Load a cusolverSpres_t from raft res if it exists, otherwise + * add it and return it. + * @param[in] res the raft resources object + * @return cusolver dn handle + */ +inline cusolverDnHandle_t get_cusolver_dn_handle(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUSOLVER_DN_HANDLE)) { + cudaStream_t stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared(stream)); + } + return *res.get_resource(resource_type::CUSOLVER_DN_HANDLE); +}; + +/** + * @} + */ + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cusolver_sp_handle.hpp b/cpp/include/raft/core/resource/cusolver_sp_handle.hpp new file mode 100644 index 0000000000..61fd95b44f --- /dev/null +++ b/cpp/include/raft/core/resource/cusolver_sp_handle.hpp @@ -0,0 +1,84 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource { + +/** + * + */ +class cusolver_sp_resource : public resource { + public: + cusolver_sp_resource(rmm::cuda_stream_view stream) + { + RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpCreate(&cusolver_res)); + RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpSetStream(cusolver_res, stream)); + } + + void* get_resource() override { return &cusolver_res; } + + ~cusolver_sp_resource() override { RAFT_CUSOLVER_TRY_NO_THROW(cusolverSpDestroy(cusolver_res)); } + + private: + cusolverSpHandle_t cusolver_res; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class cusolver_sp_resource_factory : public resource_factory { + public: + cusolver_sp_resource_factory(rmm::cuda_stream_view stream) : stream_(stream) {} + resource_type get_resource_type() override { return resource_type::CUSOLVER_SP_HANDLE; } + resource* make_resource() override { return new cusolver_sp_resource(stream_); } + + private: + rmm::cuda_stream_view stream_; +}; + +/** + * @defgroup resource_cusolver_sp cuSolver SP handle resource functions + * @{ + */ + +/** + * Load a cusolverSpres_t from raft res if it exists, otherwise + * add it and return it. 
+ * @param[in] res the raft resources object + * @return cusolver sp handle + */ +inline cusolverSpHandle_t get_cusolver_sp_handle(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUSOLVER_SP_HANDLE)) { + cudaStream_t stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared(stream)); + } + return *res.get_resource(resource_type::CUSOLVER_SP_HANDLE); +}; + +/** + * @} + */ + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/cusparse_handle.hpp b/cpp/include/raft/core/resource/cusparse_handle.hpp new file mode 100644 index 0000000000..9893ed2f86 --- /dev/null +++ b/cpp/include/raft/core/resource/cusparse_handle.hpp @@ -0,0 +1,79 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource { +class cusparse_resource : public resource { + public: + cusparse_resource(rmm::cuda_stream_view stream) + { + RAFT_CUSPARSE_TRY_NO_THROW(cusparseCreate(&cusparse_res)); + RAFT_CUSPARSE_TRY_NO_THROW(cusparseSetStream(cusparse_res, stream)); + } + + ~cusparse_resource() { RAFT_CUSPARSE_TRY_NO_THROW(cusparseDestroy(cusparse_res)); } + void* get_resource() override { return &cusparse_res; } + + private: + cusparseHandle_t cusparse_res; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class cusparse_resource_factory : public resource_factory { + public: + cusparse_resource_factory(rmm::cuda_stream_view stream) : stream_(stream) {} + resource_type get_resource_type() override { return resource_type::CUSPARSE_HANDLE; } + resource* make_resource() override { return new cusparse_resource(stream_); } + + private: + rmm::cuda_stream_view stream_; +}; + +/** + * @defgroup resource_cusparse cuSparse handle resource functions + * @{ + */ + +/** + * Load a cusparseres_t from raft res if it exists, otherwise + * add it and return it. + * @param[in] res the raft resources object + * @return cusparse handle + */ +inline cusparseHandle_t get_cusparse_handle(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUSPARSE_HANDLE)) { + rmm::cuda_stream_view stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared(stream)); + } + return *res.get_resource(resource_type::CUSPARSE_HANDLE); +}; + +/** + * @} + */ + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/detail/stream_sync_event.hpp b/cpp/include/raft/core/resource/detail/stream_sync_event.hpp new file mode 100644 index 0000000000..1d02fef20d --- /dev/null +++ b/cpp/include/raft/core/resource/detail/stream_sync_event.hpp @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource::detail { + +/** + * Factory that knows how to construct a specific raft::resource to populate + * the res_t. + */ +class cuda_stream_sync_event_resource_factory : public resource_factory { + public: + resource_type get_resource_type() override { return resource_type::CUDA_STREAM_SYNC_EVENT; } + resource* make_resource() override { return new cuda_event_resource(); } +}; + +/** + * Load a cudaEvent from a resources instance (and populate it on the resources instance) + * if needed) for syncing the main cuda stream. + * @param res raft resources instance for managing resources + * @return + */ +inline cudaEvent_t& get_cuda_stream_sync_event(resources const& res) +{ + if (!res.has_resource_factory(resource_type::CUDA_STREAM_SYNC_EVENT)) { + res.add_resource_factory(std::make_shared()); + } + return *res.get_resource(resource_type::CUDA_STREAM_SYNC_EVENT); +}; + +} // namespace raft::resource::detail diff --git a/cpp/include/raft/core/resource/device_id.hpp b/cpp/include/raft/core/resource/device_id.hpp new file mode 100644 index 0000000000..b55e56ca45 --- /dev/null +++ b/cpp/include/raft/core/resource/device_id.hpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include + +namespace raft::resource { + +class device_id_resource : public resource { + public: + device_id_resource() + : dev_id_([]() -> int { + int cur_dev = -1; + RAFT_CUDA_TRY_NO_THROW(cudaGetDevice(&cur_dev)); + return cur_dev; + }()) + { + } + void* get_resource() override { return &dev_id_; } + + ~device_id_resource() override {} + + private: + int dev_id_; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class device_id_resource_factory : public resource_factory { + public: + resource_type get_resource_type() override { return resource_type::DEVICE_ID; } + resource* make_resource() override { return new device_id_resource(); } +}; + +/** + * @defgroup resource_device_id Device ID resource functions + * @{ + */ + +/** + * Load a device id from a res (and populate it on the res if needed). 
+ * @param res raft res object for managing resources + * @return device id + */ +inline int get_device_id(resources const& res) +{ + if (!res.has_resource_factory(resource_type::DEVICE_ID)) { + res.add_resource_factory(std::make_shared()); + } + return *res.get_resource(resource_type::DEVICE_ID); +}; + +/** + * @} + */ +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp new file mode 100644 index 0000000000..35ae3d715f --- /dev/null +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -0,0 +1,75 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include + +namespace raft::resource { +class device_memory_resource : public resource { + public: + device_memory_resource(rmm::mr::device_memory_resource* mr_ = nullptr) : mr(mr_) + { + if (mr_ == nullptr) { mr = rmm::mr::get_current_device_resource(); } + } + void* get_resource() override { return mr; } + + ~device_memory_resource() override {} + + private: + rmm::mr::device_memory_resource* mr; +}; + +/** + * Factory that knows how to construct a specific raft::resource to populate + * the resources instance. + */ +class workspace_resource_factory : public resource_factory { + public: + workspace_resource_factory(rmm::mr::device_memory_resource* mr_ = nullptr) : mr(mr_) {} + resource_type get_resource_type() override { return resource_type::WORKSPACE_RESOURCE; } + resource* make_resource() override { return new device_memory_resource(mr); } + + private: + rmm::mr::device_memory_resource* mr; +}; + +/** + * Load a temp workspace resource from a resources instance (and populate it on the res + * if needed). + * @param res raft resources object for managing resources + * @return device memory resource object + */ +inline rmm::mr::device_memory_resource* get_workspace_resource(resources const& res) +{ + if (!res.has_resource_factory(resource_type::WORKSPACE_RESOURCE)) { + res.add_resource_factory(std::make_shared()); + } + return res.get_resource(resource_type::WORKSPACE_RESOURCE); +}; + +/** + * Set a temp workspace resource on a resources instance. + * + * @param res raft resources object for managing resources + * @param mr a valid rmm device_memory_resource + */ +inline void set_workspace_resource(resources const& res, rmm::mr::device_memory_resource* mr) +{ + res.add_resource_factory(std::make_shared(mr)); +}; +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/device_properties.hpp b/cpp/include/raft/core/resource/device_properties.hpp new file mode 100644 index 0000000000..c3b0b8f2b9 --- /dev/null +++ b/cpp/include/raft/core/resource/device_properties.hpp @@ -0,0 +1,77 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft::resource { + +class device_properties_resource : public resource { + public: + device_properties_resource(int dev_id) + { + RAFT_CUDA_TRY_NO_THROW(cudaGetDeviceProperties(&prop_, dev_id)); + } + void* get_resource() override { return &prop_; } + + ~device_properties_resource() override {} + + private: + cudaDeviceProp prop_; +}; + +/** + * @defgroup resource_device_props Device properties resource functions + * @{ + */ + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class device_properties_resource_factory : public resource_factory { + public: + device_properties_resource_factory(int dev_id) : dev_id_(dev_id) {} + resource_type get_resource_type() override { return resource_type::DEVICE_PROPERTIES; } + resource* make_resource() override { return new device_properties_resource(dev_id_); } + + private: + int dev_id_; +}; + +/** + * Load a cudaDeviceProp from a res (and populate it on the res if needed). + * @param res raft res object for managing resources + * @return populated cuda device properties instance + */ +inline cudaDeviceProp& get_device_properties(resources const& res) +{ + if (!res.has_resource_factory(resource_type::DEVICE_PROPERTIES)) { + int dev_id = get_device_id(res); + res.add_resource_factory(std::make_shared(dev_id)); + } + return *res.get_resource(resource_type::DEVICE_PROPERTIES); +}; + +/** + * @} + */ +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/resource_types.hpp b/cpp/include/raft/core/resource/resource_types.hpp new file mode 100644 index 0000000000..cf302e25f9 --- /dev/null +++ b/cpp/include/raft/core/resource/resource_types.hpp @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::resource { + +/** + * @defgroup resource_types Core resource vocabulary types + * @{ + */ + +/** + * @brief Resource types can apply to any resource and don't have to be host- or device-specific. 
+ */ +enum resource_type { + // device-specific resource types + CUBLAS_HANDLE = 0, // cublas handle + CUSOLVER_DN_HANDLE, // cusolver dn handle + CUSOLVER_SP_HANDLE, // cusolver sp handle + CUSPARSE_HANDLE, // cusparse handle + CUDA_STREAM_VIEW, // view of a cuda stream + CUDA_STREAM_POOL, // cuda stream pool + CUDA_STREAM_SYNC_EVENT, // cuda event for syncing streams + COMMUNICATOR, // raft communicator + SUB_COMMUNICATOR, // raft sub communicator + DEVICE_PROPERTIES, // cuda device properties + DEVICE_ID, // cuda device id + THRUST_POLICY, // thrust execution policy + WORKSPACE_RESOURCE, // rmm device memory resource + + LAST_KEY // reserved for the last key +}; + +/** + * @brief A resource constructs and contains an instance of + * some pre-determined object type and facades that object + * behind a common API. + */ +class resource { + public: + virtual void* get_resource() = 0; + + virtual ~resource() {} +}; + +class empty_resource : public resource { + public: + empty_resource() : resource() {} + + void* get_resource() override { return nullptr; } + + ~empty_resource() override {} +}; + +/** + * @brief A resource factory knows how to construct an instance of + * a specific raft::resource::resource. + */ +class resource_factory { + public: + /** + * @brief Return the resource_type associated with the current factory + * @return resource_type corresponding to the current factory + */ + virtual resource_type get_resource_type() = 0; + + /** + * @brief Construct an instance of the factory's underlying resource. + * @return resource instance + */ + virtual resource* make_resource() = 0; +}; + +/** + * @brief A placeholder factory that returns an empty resource. It is used to + * pre-populate the factory table before any real factories are registered. + */ +class empty_resource_factory : public resource_factory { + public: + empty_resource_factory() : resource_factory() {} + /** + * @brief Return the resource_type associated with the current factory + * @return resource_type corresponding to the current factory + */ + resource_type get_resource_type() override { return resource_type::LAST_KEY; } + + /** + * @brief Construct an instance of the factory's underlying resource. + * @return resource instance + */ + resource* make_resource() override { return &res; } + + private: + empty_resource res; +}; + +/** + * @} + */ + +} // namespace raft::resource diff --git a/cpp/include/raft/core/resource/sub_comms.hpp b/cpp/include/raft/core/resource/sub_comms.hpp new file mode 100644 index 0000000000..7070b61c54 --- /dev/null +++ b/cpp/include/raft/core/resource/sub_comms.hpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
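To make the vocabulary concrete, here is a hedged sketch of a third-party resource/factory pair following the pattern above. The `call_counter` names are hypothetical, and reusing `LAST_KEY` is only a placeholder; a real integration would add its own enumerator ahead of `LAST_KEY`:

```cpp
#include <raft/core/resource/resource_types.hpp>

namespace example {

struct call_counter {
  int calls = 0;
};

// Owns one call_counter instance and exposes it through the common facade.
class counter_resource : public raft::resource::resource {
 public:
  void* get_resource() override { return &counter_; }
  ~counter_resource() override {}

 private:
  call_counter counter_;
};

// Constructs counter_resource instances on demand.
class counter_resource_factory : public raft::resource::resource_factory {
 public:
  raft::resource::resource_type get_resource_type() override
  {
    return raft::resource::resource_type::LAST_KEY;  // placeholder slot only
  }
  raft::resource::resource* make_resource() override { return new counter_resource(); }
};

}  // namespace example
```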
+ */ +#pragma once + +#include <raft/core/comms.hpp> +#include <raft/core/resource/resource_types.hpp> +#include <raft/core/resources.hpp> + +namespace raft::resource { +class sub_comms_resource : public resource { + public: + sub_comms_resource() : communicators_() {} + void* get_resource() override { return &communicators_; } + + ~sub_comms_resource() override {} + + private: + std::unordered_map<std::string, std::shared_ptr<comms::comms_t>> communicators_; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t. + */ +class sub_comms_resource_factory : public resource_factory { + public: + resource_type get_resource_type() override { return resource_type::SUB_COMMUNICATOR; } + resource* make_resource() override { return new sub_comms_resource(); } +}; + +/** + * @defgroup resource_subcomms Subcommunicator resource functions + * @{ + */ + +inline const comms::comms_t& get_subcomm(const resources& res, std::string key) +{ + if (!res.has_resource_factory(resource_type::SUB_COMMUNICATOR)) { + res.add_resource_factory(std::make_shared<sub_comms_resource_factory>()); + } + + auto sub_comms = + res.get_resource<std::unordered_map<std::string, std::shared_ptr<comms::comms_t>>>( + resource_type::SUB_COMMUNICATOR); + auto sub_comm = sub_comms->at(key); + RAFT_EXPECTS(nullptr != sub_comm.get(), "ERROR: Subcommunicator was not initialized"); + + return *sub_comm; +} + +inline void set_subcomm(resources const& res, + std::string key, + std::shared_ptr<comms::comms_t> subcomm) +{ + if (!res.has_resource_factory(resource_type::SUB_COMMUNICATOR)) { + res.add_resource_factory(std::make_shared<sub_comms_resource_factory>()); + } + auto sub_comms = + res.get_resource<std::unordered_map<std::string, std::shared_ptr<comms::comms_t>>>( + resource_type::SUB_COMMUNICATOR); + sub_comms->insert(std::make_pair(key, subcomm)); +} + +/** + * @} + */ + +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resource/thrust_policy.hpp b/cpp/include/raft/core/resource/thrust_policy.hpp new file mode 100644 index 0000000000..1e7441e5e4 --- /dev/null +++ b/cpp/include/raft/core/resource/thrust_policy.hpp @@ -0,0 +1,74 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include <raft/core/resource/cuda_stream.hpp> +#include <raft/core/resource/resource_types.hpp> +#include <rmm/exec_policy.hpp> +namespace raft::resource { +class thrust_policy_resource : public resource { + public: + thrust_policy_resource(rmm::cuda_stream_view stream_view) + : thrust_policy_(std::make_unique<rmm::exec_policy>(stream_view)) + { + } + void* get_resource() override { return thrust_policy_.get(); } + + ~thrust_policy_resource() override {} + + private: + std::unique_ptr<rmm::exec_policy> thrust_policy_; +}; + +/** + * Factory that knows how to construct a + * specific raft::resource to populate + * the res_t.
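A hedged usage sketch for the subcommunicator accessors above; constructing the `comms_t` itself (e.g., via a NCCL-based initializer) is assumed to happen elsewhere and is not shown:

```cpp
#include <memory>

#include <raft/core/comms.hpp>
#include <raft/core/resource/sub_comms.hpp>
#include <raft/core/resources.hpp>

// Registers a hypothetical row-wise subcommunicator under the key "row" and
// immediately reads it back through the resource container.
void register_row_comm(raft::resources const& res,
                       std::shared_ptr<raft::comms::comms_t> row_comm)
{
  raft::resource::set_subcomm(res, "row", row_comm);

  // Lookup by key; at() throws if the key was never registered.
  raft::comms::comms_t const& comm = raft::resource::get_subcomm(res, "row");
  (void)comm;
}
```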
+ */ +class thrust_policy_resource_factory : public resource_factory { + public: + thrust_policy_resource_factory(rmm::cuda_stream_view stream_view) : stream_view_(stream_view) {} + resource_type get_resource_type() override { return resource_type::THRUST_POLICY; } + resource* make_resource() override { return new thrust_policy_resource(stream_view_); } + + private: + rmm::cuda_stream_view stream_view_; +}; + +/** + * @defgroup resource_thrust_policy Thrust policy resource functions + * @{ + */ + +/** + * Load a thrust policy from a res (and populate it on the res if needed). + * @param res raft res object for managing resources + * @return thrust execution policy + */ +inline rmm::exec_policy& get_thrust_policy(resources const& res) +{ + if (!res.has_resource_factory(resource_type::THRUST_POLICY)) { + rmm::cuda_stream_view stream = get_cuda_stream(res); + res.add_resource_factory(std::make_shared<thrust_policy_resource_factory>(stream)); + } + return *res.get_resource<rmm::exec_policy>(resource_type::THRUST_POLICY); +}; + +/** + * @} + */ + +} // namespace raft::resource \ No newline at end of file diff --git a/cpp/include/raft/core/resources.hpp b/cpp/include/raft/core/resources.hpp new file mode 100644 index 0000000000..64e281e934 --- /dev/null +++ b/cpp/include/raft/core/resources.hpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "resource/resource_types.hpp" +#include <algorithm> +#include <mutex> +#include <raft/core/error.hpp> +#include <string> +#include <vector> + +namespace raft { + +/** + * @brief Resource container which allows lazy-loading and registration + * of resource_factory implementations, which in turn generate resource instances. + * + * This class is intended to be agnostic of the resources it contains and + * does not, itself, differentiate between host and device resources. Downstream + * accessor functions can then register and load resources as needed in order + * to keep its usage somewhat opaque to end-users. + * + * @code{.cpp} + * #include <raft/core/resources.hpp> + * #include <raft/core/resource/cuda_stream.hpp> + * #include <raft/core/resource/cublas_handle.hpp> + * + * raft::resources res; + * auto stream = raft::resource::get_cuda_stream(res); + * auto cublas_handle = raft::resource::get_cublas_handle(res); + * @endcode + */ +class resources { + public: + template <typename T> + using pair_res = std::pair<resource::resource_type, std::shared_ptr<T>>; + + using pair_res_factory = pair_res<resource::resource_factory>; + using pair_resource = pair_res<resource::resource>; + + resources() + : factories_(resource::resource_type::LAST_KEY), resources_(resource::resource_type::LAST_KEY) + { + for (int i = 0; i < resource::resource_type::LAST_KEY; ++i) { + factories_.at(i) = std::make_pair(resource::resource_type::LAST_KEY, + std::make_shared<resource::empty_resource_factory>()); + resources_.at(i) = std::make_pair(resource::resource_type::LAST_KEY, + std::make_shared<resource::empty_resource>()); + } + } + + /** + * @brief Shallow copy of underlying resources instance. + * Note that this does not create any new resources.
+ */ + resources(const resources& res) : factories_(res.factories_), resources_(res.resources_) {} + resources(resources&&) = delete; + resources& operator=(resources&&) = delete; + + /** + * @brief Returns true if a resource_factory has been registered for the + * given resource_type, false otherwise. + * @param resource_type resource type to check + * @return true if resource_factory is registered for the given resource_type + */ + bool has_resource_factory(resource::resource_type resource_type) const + { + std::lock_guard<std::mutex> _(mutex_); + return factories_.at(resource_type).first != resource::resource_type::LAST_KEY; + } + + /** + * @brief Register a resource_factory with the current instance. + * This will overwrite any existing resource factories. + * @param factory resource factory to register on the current instance + */ + void add_resource_factory(std::shared_ptr<resource::resource_factory> factory) const + { + std::lock_guard<std::mutex> _(mutex_); + resource::resource_type rtype = factory.get()->get_resource_type(); + RAFT_EXPECTS(rtype != resource::resource_type::LAST_KEY, + "LAST_KEY is a placeholder and not a valid resource factory type."); + factories_.at(rtype) = std::make_pair(rtype, factory); + } + + /** + * @brief Retrieve a resource for the given resource_type and cast to given pointer type. + * Note that the resources are loaded lazily on-demand and resources which don't yet + * exist on the current instance will be created using the corresponding factory, if + * it exists. + * @tparam res_t pointer type for which retrieved resource will be casted + * @param resource_type resource type to retrieve + * @return the given resource, if it exists. + */ + template <typename res_t> + res_t* get_resource(resource::resource_type resource_type) const + { + std::lock_guard<std::mutex> _(mutex_); + + if (resources_.at(resource_type).first == resource::resource_type::LAST_KEY) { + RAFT_EXPECTS(factories_.at(resource_type).first != resource::resource_type::LAST_KEY, + "No resource factory has been registered for the given resource %d.", + resource_type); + resource::resource_factory* factory = factories_.at(resource_type).second.get(); + resources_.at(resource_type) = std::make_pair( + resource_type, std::shared_ptr<resource::resource>(factory->make_resource())); + } + + resource::resource* res = resources_.at(resource_type).second.get(); + return reinterpret_cast<res_t*>(res->get_resource()); + } + + protected: + mutable std::mutex mutex_; + mutable std::vector<pair_res_factory> factories_; + mutable std::vector<pair_resource> resources_; +}; +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/distance/detail/canberra.cuh b/cpp/include/raft/distance/detail/canberra.cuh index 90ed3940e1..f17a26dc4b 100644 --- a/cpp/include/raft/distance/detail/canberra.cuh +++ b/cpp/include/raft/distance/detail/canberra.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
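A small sketch of the lazy-loading and shallow-copy semantics just defined, assuming the `cuda_stream` accessor added elsewhere in this PR; the function is illustrative only:

```cpp
#include <cassert>

#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resources.hpp>

void copy_shares_resources()
{
  raft::resources res;

  // First access registers the default factory and materializes the stream.
  rmm::cuda_stream_view stream = raft::resource::get_cuda_stream(res);

  // The copy constructor copies the factory/resource tables (shared_ptrs), so
  // the copy hands back the same underlying stream rather than a new one.
  raft::resources copy = res;
  assert(raft::resource::get_cuda_stream(copy).value() == stream.value());
}
```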
@@ -73,19 +73,15 @@ static void canberraImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::myAbs(x - y); - const auto add = raft::myAbs(x) + raft::myAbs(y); + const auto diff = raft::abs(x - y); + const auto add = raft::abs(x) + raft::abs(y); // deal with potential for 0 in denominator by // forcing 1/0 instead acc += ((add != 0) * diff / (add + (add == 0))); }; // epilogue operation lambda for final value calculation - auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { return; }; + auto epilog_lambda = raft::void_op(); if (isRowMajor) { auto canberraRowMajor = pairwiseDistanceMatKernel #include #include namespace raft { @@ -72,16 +73,12 @@ static void chebyshevImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::myAbs(x - y); - acc = raft::myMax(acc, diff); + const auto diff = raft::abs(x - y); + acc = raft::max(acc, diff); }; // epilogue operation lambda for final value calculation - auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { return; }; + auto epilog_lambda = raft::void_op(); if (isRowMajor) { auto chebyshevRowMajor = pairwiseDistanceMatKernel +#include +#include +#include + +namespace raft::distance::detail { + +/** + * @brief Compress 2D boolean matrix to bitfield + * + * Utility kernel for maskedL2NN. + * + * @tparam T + * + * @param[in] in An `m x n` boolean matrix. Row major. + * @param[out] out An `(m / bits_per_elem) x n` matrix with elements of + * type T, where T is of size `bits_per_elem` bits. + * Note: the division (`/`) is a ceilDiv. + */ +template <typename T = uint64_t, typename = std::enable_if_t<std::is_integral<T>::value>> +__global__ void compress_to_bits_kernel( + raft::device_matrix_view<const bool, int, raft::layout_c_contiguous> in, + raft::device_matrix_view<T, int, raft::layout_c_contiguous> out) +{ + constexpr int bits_per_element = 8 * sizeof(T); + constexpr int tile_dim_m = bits_per_element; + constexpr int nthreads = 128; + constexpr int tile_dim_n = nthreads; // read 128 bools at once = 1 sector + + // Tile in shared memory is transposed + __shared__ bool smem[tile_dim_n][tile_dim_m]; + + const int num_tiles_per_m = raft::ceildiv(in.extent(0), tile_dim_m); + const int num_tiles_per_n = raft::ceildiv(in.extent(1), tile_dim_n); + + for (int lin_tile_idx = blockIdx.x; true; lin_tile_idx += gridDim.x) { + const int tile_idx_n = tile_dim_n * (lin_tile_idx % num_tiles_per_n); + const int tile_idx_m = tile_dim_m * (lin_tile_idx / num_tiles_per_n); + + if (in.extent(0) <= tile_idx_m) { break; } + // Fill shared memory tile + bool reg_buf[tile_dim_m]; +#pragma unroll + for (int i = 0; i < tile_dim_m; ++i) { + const int in_m = tile_idx_m + i; + const int in_n = tile_idx_n + threadIdx.x; + bool in_bounds = in_m < in.extent(0) && in_n < in.extent(1); + reg_buf[i] = in_bounds ? in(in_m, in_n) : false; + smem[threadIdx.x][i] = reg_buf[i]; + } + __syncthreads(); + + // Drain memory tile into single output element out_elem. + T out_elem{0}; +#pragma unroll + for (int j = 0; j < tile_dim_n; ++j) { + if (smem[threadIdx.x][j]) { out_elem |= T(1) << j; } + } + __syncthreads(); + + // Write output.
+ int out_m = tile_idx_m / bits_per_element; + int out_n = tile_idx_n + threadIdx.x; + + if (out_m < out.extent(0) && out_n < out.extent(1)) { out(out_m, out_n) = out_elem; } + } +} + +/** + * @brief Compress 2D boolean matrix to bitfield + * + * Utility kernel for maskedL2NN. + * + * @tparam T + * + * @param[in] in An `m x n` boolean matrix. Row major. + * @param[out] out An `(m / bits_per_elem) x n` matrix with elements of + * type T, where T is of size `bits_per_elem` bits. + * Note: the division (`/`) is a ceilDiv. + */ +template <typename T = uint64_t, typename = std::enable_if_t<std::is_integral<T>::value>> +void compress_to_bits(raft::device_resources const& handle, + raft::device_matrix_view<const bool, int, raft::layout_c_contiguous> in, + raft::device_matrix_view<T, int, raft::layout_c_contiguous> out) +{ + auto stream = handle.get_stream(); + constexpr int bits_per_element = 8 * sizeof(T); + + RAFT_EXPECTS(raft::ceildiv(in.extent(0), bits_per_element) == out.extent(0), + "Number of output rows must be ceildiv(input rows, bits_per_elem)"); + RAFT_EXPECTS(in.extent(1) == out.extent(1), "Number of output columns must equal input columns."); + + const int num_SMs = raft::getMultiProcessorCount(); + int blocks_per_sm = 0; + constexpr int num_threads = 128; + constexpr int dyn_smem_size = 0; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &blocks_per_sm, compress_to_bits_kernel<T>, num_threads, dyn_smem_size)); + + dim3 grid(num_SMs * blocks_per_sm); + dim3 block(128); + compress_to_bits_kernel<T><<<grid, block, 0, stream>>>(in, out); + RAFT_CUDA_TRY(cudaGetLastError()); +} + +}; // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/detail/correlation.cuh b/cpp/include/raft/distance/detail/correlation.cuh index 9bdbbf112c..f7fe3678e6 100644 --- a/cpp/include/raft/distance/detail/correlation.cuh +++ b/cpp/include/raft/distance/detail/correlation.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -125,7 +125,7 @@ static void correlationImpl(const DataT* x, auto Q_denom = k * regx2n[i] - (regxn[i] * regxn[i]); auto R_denom = k * regy2n[j] - (regyn[j] * regyn[j]); - acc[i][j] = 1 - (numer / raft::mySqrt(Q_denom * R_denom)); + acc[i][j] = 1 - (numer / raft::sqrt(Q_denom * R_denom)); } } }; diff --git a/cpp/include/raft/distance/detail/euclidean.cuh b/cpp/include/raft/distance/detail/euclidean.cuh index 4184810fff..1a2db63f5c 100644 --- a/cpp/include/raft/distance/detail/euclidean.cuh +++ b/cpp/include/raft/distance/detail/euclidean.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ struct L2ExpandedOp { __device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept { AccT outVal = aNorm + bNorm - DataT(2.0) * accVal; - return sqrt ? raft::mySqrt(outVal) : outVal; + return sqrt ?
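For orientation, a hedged sketch of driving `compress_to_bits` from host code; the buffer setup and function name are illustrative, not part of this diff:

```cpp
#include <cstdint>

#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/distance/detail/compress_to_bits.cuh>
#include <raft/util/cuda_utils.cuh>

void pack_adjacency(raft::device_resources const& handle, int m, int n)
{
  // Boolean m x n adjacency matrix and its packed ceil(m/64) x n counterpart.
  auto adj    = raft::make_device_matrix<bool, int>(handle, m, n);
  auto packed = raft::make_device_matrix<uint64_t, int>(handle, raft::ceildiv(m, 64), n);

  // compress_to_bits expects a const view of the boolean input.
  auto adj_const = raft::make_device_matrix_view<const bool, int>(adj.data_handle(), m, n);
  raft::distance::detail::compress_to_bits(handle, adj_const, packed.view());
}
```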
raft::sqrt(outVal) : outVal; } __device__ AccT operator()(DataT aData) const noexcept { return aData; } @@ -130,7 +130,7 @@ void euclideanExpImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); + acc[i][j] = raft::sqrt(acc[i][j]); } } } @@ -350,7 +350,7 @@ void euclideanUnExpImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(acc[i][j]); + acc[i][j] = raft::sqrt(acc[i][j]); } } } diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 5311a26d19..4f5e224a19 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -39,6 +39,7 @@ template <typename LabelT, typename DataT> struct KVPMinReduceImpl { typedef raft::KeyValuePair<LabelT, DataT> KVP; DI KVP operator()(LabelT rit, const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } + DI KVP operator()(const KVP& a, const KVP& b) { return b.value < a.value ? b : a; } }; // KVPMinReduce @@ -185,7 +186,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, #pragma unroll for (int j = 0; j < P::AccColsPerTh; ++j) { auto acc_ij = acc[i][j]; - acc[i][j] = acc_ij > DataT{0} ? raft::mySqrt(acc_ij) : DataT{0}; + acc[i][j] = acc_ij > DataT{0} ? raft::sqrt(acc_ij) : DataT{0}; } } } diff --git a/cpp/include/raft/distance/detail/hellinger.cuh b/cpp/include/raft/distance/detail/hellinger.cuh index 51f462ab36..13507fe84f 100644 --- a/cpp/include/raft/distance/detail/hellinger.cuh +++ b/cpp/include/raft/distance/detail/hellinger.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -105,7 +105,7 @@ static void hellingerImpl(const DataT* x, // Adjust to replace NaN in sqrt with 0 if input to sqrt is negative const auto finalVal = (1 - acc[i][j]); const auto rectifier = (!signbit(finalVal)); - acc[i][j] = raft::mySqrt(rectifier * finalVal); + acc[i][j] = raft::sqrt(rectifier * finalVal); } } }; diff --git a/cpp/include/raft/distance/detail/jensen_shannon.cuh b/cpp/include/raft/distance/detail/jensen_shannon.cuh index 92ee071cf5..f96da01b87 100644 --- a/cpp/include/raft/distance/detail/jensen_shannon.cuh +++ b/cpp/include/raft/distance/detail/jensen_shannon.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
@@ -78,11 +78,11 @@ static void jensenShannonImpl(const DataT* x, auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { const DataT m = 0.5f * (x + y); const bool m_zero = (m == 0); - const auto logM = (!m_zero) * raft::myLog(m + m_zero); + const auto logM = (!m_zero) * raft::log(m + m_zero); const bool x_zero = (x == 0); const bool y_zero = (y == 0); - acc += (-x * (logM - raft::myLog(x + x_zero))) + (-y * (logM - raft::myLog(y + y_zero))); + acc += (-x * (logM - raft::log(x + x_zero))) + (-y * (logM - raft::log(y + y_zero))); }; // epilogue operation lambda for final value calculation @@ -95,7 +95,7 @@ static void jensenShannonImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::mySqrt(0.5 * acc[i][j]); + acc[i][j] = raft::sqrt(0.5 * acc[i][j]); } } }; diff --git a/cpp/include/raft/distance/detail/kl_divergence.cuh b/cpp/include/raft/distance/detail/kl_divergence.cuh index 4c0c4b6ace..7ebeaf4de9 100644 --- a/cpp/include/raft/distance/detail/kl_divergence.cuh +++ b/cpp/include/raft/distance/detail/kl_divergence.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -81,10 +81,10 @@ static void klDivergenceImpl(const DataT* x, auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { if (isRowMajor) { const bool x_zero = (x == 0); - acc += x * (raft::myLog(x + x_zero) - y); + acc += x * (raft::log(x + x_zero) - y); } else { const bool y_zero = (y == 0); - acc += y * (raft::myLog(y + y_zero) - x); + acc += y * (raft::log(y + y_zero) - x); } }; @@ -92,23 +92,23 @@ static void klDivergenceImpl(const DataT* x, if (isRowMajor) { const bool x_zero = (x == 0); const bool y_zero = (y == 0); - acc += x * (raft::myLog(x + x_zero) - (!y_zero) * raft::myLog(y + y_zero)); + acc += x * (raft::log(x + x_zero) - (!y_zero) * raft::log(y + y_zero)); } else { const bool y_zero = (y == 0); const bool x_zero = (x == 0); - acc += y * (raft::myLog(y + y_zero) - (!x_zero) * raft::myLog(x + x_zero)); + acc += y * (raft::log(y + y_zero) - (!x_zero) * raft::log(x + x_zero)); } }; auto unaryOp_lambda = [] __device__(DataT input) { const bool x_zero = (input == 0); - return (!x_zero) * raft::myLog(input + x_zero); + return (!x_zero) * raft::log(input + x_zero); }; auto unaryOp_lambda_reverse = [] __device__(DataT input) { // reverse previous log (x) back to x using (e ^ log(x)) const bool x_zero = (input == 0); - return (!x_zero) * raft::myExp(input); + return (!x_zero) * raft::exp(input); }; // epilogue operation lambda for final value calculation diff --git a/cpp/include/raft/distance/detail/l1.cuh b/cpp/include/raft/distance/detail/l1.cuh index 95514db60b..bf10651b60 100644 --- a/cpp/include/raft/distance/detail/l1.cuh +++ b/cpp/include/raft/distance/detail/l1.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
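The Jensen-Shannon and KL-divergence lambdas above share one branch-free zero guard. Restated as a scalar helper (plain C++, not part of the diff) to make the trick explicit:

```cpp
#include <cmath>

// Branch-free guard: when x == 0, `x + x_zero` evaluates to 1 so log() stays
// finite, and the (!x_zero) factor forces the whole term to exactly zero.
inline float guarded_log(float x)
{
  const bool x_zero = (x == 0.0f);
  return (!x_zero) * std::log(x + x_zero);
}
```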
@@ -71,16 +71,12 @@ static void l1Impl(const DataT* x, // Accumulation operation lambda auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::myAbs(x - y); + const auto diff = raft::abs(x - y); acc += diff; }; // epilogue operation lambda for final value calculation - auto epilog_lambda = [] __device__(AccT acc[KPolicy::AccRowsPerTh][KPolicy::AccColsPerTh], - DataT * regxn, - DataT * regyn, - IdxT gridStrideX, - IdxT gridStrideY) { return; }; + auto epilog_lambda = raft::void_op(); if (isRowMajor) { auto l1RowMajor = pairwiseDistanceMatKernel +#include +#include +#include + +#include + +namespace raft { +namespace distance { +namespace detail { + +/** + * @brief Device class for masked nearest neighbor computations. + * + * @tparam useNorms whether norms are needed + * @tparam DataT input data-type (for x and y matrices) + * @tparam AccT accumulation data-type + * @tparam IdxT index data-type + * @tparam Policy struct which tunes the Contraction kernel + * @tparam CoreLambda tells how to accumulate an x and y into + acc. Its signature is: + template <typename AccT, typename DataT> void core_lambda(AccT& acc, + const DataT& x, const DataT& y) + * @tparam EpilogueLambda applies an elementwise function to compute final + values. Its signature is: + template <typename AccT, typename DataT> void epilogue_lambda + (AccT acc[][], DataT* regxn, DataT* regyn); + * @tparam FinalLambda the final lambda called on final distance value + * @tparam rowEpilogueLambda epilog lambda that executes when a full row has + * been processed. + * + * @param[in] x input matrix + * @param[in] y input matrix + * @param[in] m number of rows of x + * @param[in] n number of columns of y + * @param[in] k number of cols of x and y + * @param[in] lda leading dimension of x + * @param[in] ldb leading dimension of y + * @param[in] ldd parameter to keep Contractions_NT happy. + * @param[in] xn row norms of input matrix A. Required for expanded L2, cosine + * @param[in] yn row norms of input matrix B. Required for expanded L2, cosine + * @param[in] adj An adjacency matrix encoded as a bitfield indicating for each + * row of `x` and each group in `y` whether to compute the + * distance. Dim = `(m / 64) x num_groups`. + * @param[in] group_idxs An array containing the *end* indices of each group + * in `y`. The value of group_idxs[j] indicates the + * start of group j + 1, i.e., it is the inclusive + * scan of the group lengths. The first group is + * always assumed to start at index 0 and the last + * group typically ends at index `n`. Length = + * `num_groups`. + * @param[in] num_groups The number of groups in group_idxs. + * @param[in] smem shared mem buffer for intermediate storage of x, y, xn & yn. + * @param core_op the core accumulation operation lambda + * @param epilog_op the epilog operation lambda + * @param fin_op the final gemm epilogue lambda + * @param rowEpilog_op epilog lambda that executes when a full row has been processed.
+ */ +template <bool useNorms, + typename DataT, + typename AccT, + typename IdxT, + typename Policy, + typename CoreLambda, + typename EpilogueLambda, + typename FinalLambda, + typename rowEpilogueLambda, + bool isRowMajor = true, + typename BaseClass = raft::linalg::Contractions_NT<DataT, IdxT, Policy, isRowMajor>> +struct MaskedDistances : public BaseClass { + private: + typedef Policy P; + const DataT* xn; + const DataT* yn; + const DataT* const yBase; + const uint64_t* adj; + const IdxT* group_idxs; + IdxT num_groups; + char* smem; + CoreLambda core_op; + EpilogueLambda epilog_op; + FinalLambda fin_op; + rowEpilogueLambda rowEpilog_op; + + AccT acc[P::AccRowsPerTh][P::AccColsPerTh]; + + public: + // Constructor + DI MaskedDistances(const DataT* _x, + const DataT* _y, + IdxT _m, + IdxT _n, + IdxT _k, + IdxT _lda, + IdxT _ldb, + IdxT _ldd, + const DataT* _xn, + const DataT* _yn, + const uint64_t* _adj, + const IdxT* _group_idxs, + IdxT _num_groups, + char* _smem, + CoreLambda _core_op, + EpilogueLambda _epilog_op, + FinalLambda _fin_op, + rowEpilogueLambda _rowEpilog_op) + : BaseClass(_x, _y, _m, _n, _k, _lda, _ldb, _ldd, _smem), + xn(_xn), + yn(_yn), + yBase(_y), + adj(_adj), + group_idxs(_group_idxs), + num_groups(_num_groups), + smem(_smem), + core_op(_core_op), + epilog_op(_epilog_op), + fin_op(_fin_op), + rowEpilog_op(_rowEpilog_op) + { + } + + DI void run() + { + const auto grid_stride_m = (P::Mblk * gridDim.y); + const auto grid_offset_m = (P::Mblk * blockIdx.y); + + const auto grid_stride_g = gridDim.x; + const auto grid_offset_g = blockIdx.x; + + for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { + // Start loop over groups + for (auto idx_g = grid_offset_g; idx_g < this->num_groups; idx_g += grid_stride_g) { + const uint64_t block_adj = get_block_adjacency(adj, tile_idx_m, idx_g); + // block_adj is a bitfield that contains a 1 if a row is adjacent to the + // current group. All zero means we can skip this group. + if (block_adj == 0) { continue; } + + // thread_adj is a bitfield that contains a 1 at location i iff we must + // compute row i of acc (the accumulator register tile). That is, + // for i = 0,.., AccRowsPerTh and j = 0,.., AccColsPerTh: + // + // ((1 << i) & thread_adj) > 0 <=> acc[i][j] must be computed. + // + // We precompute this information because it is used in various + // locations to skip thread-local computations, specifically: + // + // 1. To skip computations if thread_adj == 0, i.e., none of the values + // of `acc` have to be computed. + // + // 2. In epilog_op, to consider only values of `acc` to be reduced that + // are not masked off. + // + // Note 1: Even when the computation can be skipped for a specific thread, + // the thread still participates in synchronization operations. + // + // Note 2: In theory, it should be possible to skip computations for + // specific rows of `acc`. In practice, however, this does not improve + // performance. + int thread_adj = compute_thread_adjacency(block_adj); + + auto tile_idx_n = idx_g == 0 ? 0 : group_idxs[idx_g - 1]; + const auto group_end_n = group_idxs[idx_g]; + for (; tile_idx_n < group_end_n; tile_idx_n += P::Nblk) { + // We provide group_end_n to limit the number of unnecessary data + // points that are loaded from y. + this->ldgXY(tile_idx_m, tile_idx_n, 0, group_end_n); + + reset_accumulator(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(tile_idx_m, tile_idx_n, kidx, group_end_n); + // Process all data in shared memory (previous k-block) and + // accumulate in registers.
+ if (thread_adj != 0) { accumulate(); } + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); + } + if (thread_adj != 0) { + accumulate(); // last iteration + } + // The pre-condition for the loop over tile_idx_n is that write_buffer + // and read_buffer point to the same buffer. This flips read_buffer + // back so that it satisfies the pre-condition of this loop. + this->switch_read_buffer(); + + if (useNorms) { + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + load_norms(tile_idx_m, tile_idx_n, group_end_n, regxn, regyn); + if (thread_adj != 0) { + epilog_op(acc, thread_adj, regxn, regyn, tile_idx_n, tile_idx_m, group_end_n); + } + } else { + if (thread_adj != 0) { + epilog_op(acc, thread_adj, nullptr, nullptr, tile_idx_n, tile_idx_m, group_end_n); + } + } + } // tile_idx_n + } // idx_g + rowEpilog_op(tile_idx_m); + } // tile_idx_m + } + + private: + DI uint64_t get_block_adjacency(const uint64_t* adj, IdxT tile_idx_m, IdxT idx_group) + { + // A single element of `adj` contains exactly enough bits to indicate which + // rows in the current tile to skip and which to compute. + static_assert(P::Mblk == 8 * sizeof(adj[0]), + "maskedL2NN only supports a policy with 64 rows per block."); + IdxT block_flag_idx = tile_idx_m / P::Mblk; + // Index into adj at row tile_idx_m / 64 and column idx_group. + return adj[block_flag_idx * this->num_groups + idx_group]; + } + + DI uint32_t compute_thread_adjacency(const uint64_t block_adj) + { + // thread_adj is a bitfield that contains a 1 at location i iff we must + // compute row i of acc (the accumulator register tile). It is described in + // more detail in the run() method. + uint32_t thread_adj = 0; +#pragma unroll + for (int thread_row_idx = 0; thread_row_idx < P::AccRowsPerTh; ++thread_row_idx) { + // Index `thread_row_idx` refers to a row of the current threads' register + // tile `acc`, i.e., acc[i][:]. Index `block_row_idx` refers to the + // corresponding row of the current block tile in shared memory. + const int block_row_idx = this->accrowid + thread_row_idx * P::AccThRows; + + // block_row_is_adjacent is true if the current block_row_idx is adjacent + // to the current group. + const uint64_t block_mask = 1ull << block_row_idx; + const bool block_row_is_adjacent = (block_adj & block_mask) != 0; + if (block_row_is_adjacent) { + // If block row is adjacent, write a 1 bit to thread_adj at location + // `thread_row_idx`. + const uint32_t thread_mask = 1 << thread_row_idx; + thread_adj |= thread_mask; + } + } + return thread_adj; + } + + DI void reset_accumulator() + { + // Reset accumulator registers to zero. 
+#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = BaseClass::Zero; + } + } + } + + DI void accumulate() + { +#pragma unroll + for (int ki = 0; ki < P::Kblk; ki += P::Veclen) { + this->ldsXY(ki); +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { +#pragma unroll + for (int v = 0; v < P::Veclen; ++v) { + core_op(acc[i][j], this->regx[i][v], this->regy[j][v]); + } + } + } + } + } + + DI void load_norms(IdxT tile_idx_m, + IdxT tile_idx_n, + IdxT end_n, + DataT (®xn)[P::AccRowsPerTh], + DataT (®yn)[P::AccColsPerTh]) + { + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = tile_idx_m + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; + } + + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = tile_idx_n + i; + syNorm[i] = idx < end_n ? yn[idx] : 0; + } + __syncthreads(); + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } +#pragma unroll + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; + } + } +}; // struct MaskedDistances + +}; // namespace detail +}; // namespace distance +}; // namespace raft diff --git a/cpp/include/raft/distance/detail/masked_nn.cuh b/cpp/include/raft/distance/detail/masked_nn.cuh new file mode 100644 index 0000000000..1c92de16fc --- /dev/null +++ b/cpp/include/raft/distance/detail/masked_nn.cuh @@ -0,0 +1,325 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
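A host-side restatement of the `compute_thread_adjacency` bit manipulation from the struct above, with the policy constants passed as plain arguments (illustrative only, not part of the diff):

```cpp
#include <cstdint>

// Maps a 64-row block mask to a per-thread mask: bit i of the result is set
// iff the accumulator row i owned by this thread is adjacent to the group.
inline uint32_t thread_adjacency(uint64_t block_adj,
                                 int acc_row_id,       // this->accrowid
                                 int acc_th_rows,      // P::AccThRows
                                 int acc_rows_per_th)  // P::AccRowsPerTh
{
  uint32_t thread_adj = 0;
  for (int i = 0; i < acc_rows_per_th; ++i) {
    const int block_row_idx = acc_row_id + i * acc_th_rows;
    if (block_adj & (uint64_t{1} << block_row_idx)) { thread_adj |= (1u << i); }
  }
  return thread_adj;
}
```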
+ */ + +#pragma once + +#include <limits> +#include <stdint.h> + +#include <raft/core/device_resources.hpp> +#include <raft/distance/detail/compress_to_bits.cuh> +#include <raft/distance/detail/fused_l2_nn.cuh> +#include <raft/distance/detail/masked_distance_base.cuh> +#include <raft/linalg/contractions.cuh> +#include <raft/util/cuda_utils.cuh> + +namespace raft { +namespace distance { +namespace detail { + +template <typename DataT, + typename OutT, + typename IdxT, + typename P, + typename ReduceOpT, + typename KVPReduceOpT, + typename CoreLambda, + typename FinalLambda> +__global__ __launch_bounds__(P::Nthreads, 2) void maskedL2NNkernel(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + const uint64_t* adj, + const IdxT* group_idxs, + IdxT num_groups, + IdxT m, + IdxT n, + IdxT k, + bool sqrt, + DataT maxVal, + int* mutex, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + CoreLambda core_op, + FinalLambda fin_op) +{ + extern __shared__ char smem[]; + + typedef raft::KeyValuePair<IdxT, DataT> KVPair; + KVPair val[P::AccRowsPerTh]; +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } + + // epilogue operation lambda for final value calculation + auto epilog_lambda = [pairRedOp, &val, maxVal, sqrt] __device__( + DataT acc[P::AccRowsPerTh][P::AccColsPerTh], + int thread_adj, + DataT* regxn, + DataT* regyn, + IdxT tile_idx_n, + IdxT tile_idx_m, + IdxT tile_end_n) { + KVPReduceOpT pairRed_op(pairRedOp); + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; + } + } + if (sqrt) { +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + acc[i][j] = raft::sqrt(acc[i][j]); + } + } + } + + // intra thread reduce + const auto acccolid = threadIdx.x % P::AccThCols; + const auto accrowid = threadIdx.x / P::AccThCols; + +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + // thread_adj is a bitfield that contains a 1 at location i iff we must + // compute row i of acc (the accumulator register tile). It is described in + // more detail in the maskedDistances.run() method. + const bool ignore = (thread_adj & (1 << i)) == 0; + if (ignore) { continue; } +#pragma unroll + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto tmpkey = acccolid + j * P::AccThCols + tile_idx_n; + if (tile_end_n <= tmpkey) { + // Do not process beyond end of tile. + continue; + } + KVPair tmp = {tmpkey, acc[i][j]}; + if (tmpkey < tile_end_n) { + val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); + } + } + } + }; + + auto rowEpilog_lambda = + [m, mutex, min, pairRedOp, redOp, &val, maxVal] __device__(IdxT tile_idx_m) { + KVPReduceOpT pairRed_op(pairRedOp); + ReduceOpT red_op(redOp); + + const auto accrowid = threadIdx.x / P::AccThCols; + const auto lid = raft::laneId(); + // reduce +#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { +#pragma unroll + for (int j = P::AccThCols / 2; j > 0; j >>= 1) { + auto tmpkey = raft::shfl(val[i].key, lid + j); + auto tmpvalue = raft::shfl(val[i].value, lid + j); + KVPair tmp = {tmpkey, tmpvalue}; + val[i] = pairRed_op(accrowid + i * P::AccThRows + tile_idx_m, tmp, val[i]); + } + } + + updateReducedVal(mutex, min, val, red_op, m, tile_idx_m); + + // reset the val array.
+#pragma unroll + for (int i = 0; i < P::AccRowsPerTh; ++i) { + val[i] = {-1, maxVal}; + } + }; + + IdxT lda = k, ldb = k, ldd = n; + MaskedDistances<true, + DataT, + DataT, + IdxT, + P, + CoreLambda, + decltype(epilog_lambda), + FinalLambda, + decltype(rowEpilog_lambda)> + obj(x, + y, + m, + n, + k, + lda, + ldb, + ldd, + xn, + yn, + adj, + group_idxs, + num_groups, + smem, + core_op, + epilog_lambda, + fin_op, + rowEpilog_lambda); + obj.run(); +} + +/** + * @brief Wrapper for maskedL2NNkernel + * + * Responsibilities: + * - Allocate (and initialize) workspace memory for: + * - mutexes used in nearest neighbor update step + * - adjacency matrix bitfield + * - Compress adjacency matrix to bitfield + * - Initialize output buffer (conditional on `initOutBuffer`) + * - Specify core and final operations for the L2 norm + * - Determine optimal launch configuration for kernel. + * - Launch kernel and check for errors. + * + * @tparam DataT Input data-type (for x and y matrices). + * @tparam OutT Output data-type (for key-value pairs). + * @tparam IdxT Index data-type. + * @tparam ReduceOpT A struct to perform the final needed reduction + * operation and also to initialize the output array + * elements with the appropriate initial value needed for + * reduction. + * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. + * + * @param handle RAFT handle for managing expensive resources + * @param[out] out Will contain reduced output (nn key-value pairs) + * @param[in] x First matrix. Row major. Dim = `m x k`. (on device) + * @param[in] y Second matrix. Row major. Dim = `n x k`. (on device) + * @param[in] xn L2 squared norm of `x`. Length = `m`. + * @param[in] yn L2 squared norm of `y`. Length = `n`. + * @param[in] adj A boolean adjacency matrix indicating for each + * row of `x` and each group in `y` whether to compute the + * distance. Dim = `m x num_groups`. + * @param[in] group_idxs An array containing the *end* indices of each group + * in `y`. The value of group_idxs[j] indicates the + * start of group j + 1, i.e., it is the inclusive + * scan of the group lengths. The first group is + * always assumed to start at index 0 and the last + * group typically ends at index `n`. Length = + * `num_groups`. + * @param[in] num_groups Length of `group_idxs`. + * @param m Rows of `x`. + * @param n Rows of `y`. + * @param k Cols of `x` and `y`. + * @param redOp Reduction operator in the epilogue + * @param pairRedOp Reduction operation on key value pairs + * @param sqrt Whether to compute the squared or actual (i.e. sqrt) L2 norm.
+ * @param initOutBuffer Whether to initialize the output buffer + * + * + */ +template <typename DataT, + typename OutT, + typename IdxT, + typename ReduceOpT, + typename KVPReduceOpT> +void maskedL2NNImpl(raft::device_resources const& handle, + OutT* out, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + const bool* adj, + const IdxT* group_idxs, + IdxT num_groups, + IdxT m, + IdxT n, + IdxT k, + ReduceOpT redOp, + KVPReduceOpT pairRedOp, + bool sqrt, + bool initOutBuffer) +{ + typedef typename linalg::Policy4x4<DataT, 1>::Policy P; + + static_assert(P::Mblk == 64, "maskedL2NNImpl only supports a policy with 64 rows per block."); + + // Get stream and workspace memory resource + rmm::mr::device_memory_resource* ws_mr = + dynamic_cast<rmm::mr::device_memory_resource*>(handle.get_workspace_resource()); + auto stream = handle.get_stream(); + + // Acquire temporary buffers and initialize to zero: + // 1) Adjacency matrix bitfield + // 2) Workspace for fused nearest neighbor operation + size_t m_div_64 = raft::ceildiv(m, IdxT(64)); + rmm::device_uvector<uint64_t> ws_adj64{m_div_64 * num_groups, stream, ws_mr}; + rmm::device_uvector<int32_t> ws_fused_nn{size_t(m), stream, ws_mr}; + RAFT_CUDA_TRY(cudaMemsetAsync(ws_adj64.data(), 0, ws_adj64.size() * sizeof(uint64_t), stream)); + RAFT_CUDA_TRY(cudaMemsetAsync(ws_fused_nn.data(), 0, ws_fused_nn.size() * sizeof(int), stream)); + + // Compress boolean adjacency matrix to bitfield. + auto adj_view = raft::make_device_matrix_view<const bool, int>(adj, m, num_groups); + auto adj64_view = + raft::make_device_matrix_view<uint64_t, int>(ws_adj64.data(), m_div_64, num_groups); + compress_to_bits(handle, adj_view, adj64_view); + + // Initialize output buffer with keyvalue pairs as determined by the reduction + // operator (it will be called with maxVal). + constexpr auto maxVal = std::numeric_limits<DataT>::max(); + if (initOutBuffer) { + dim3 grid(raft::ceildiv(m, P::Nthreads)); + dim3 block(P::Nthreads); + + initKernel<DataT, OutT, IdxT, ReduceOpT><<<grid, block, 0, stream>>>(out, m, maxVal, redOp); + RAFT_CUDA_TRY(cudaGetLastError()); + } + + // Accumulation operation lambda + auto core_lambda = [] __device__(DataT & acc, DataT & x, DataT & y) { acc += x * y; }; + auto fin_op = raft::identity_op{}; + + auto kernel = maskedL2NNkernel<DataT, OutT, IdxT, P, ReduceOpT, KVPReduceOpT, decltype(core_lambda), decltype(fin_op)>; + constexpr size_t smemSize = P::SmemSize + ((P::Mblk + P::Nblk) * sizeof(DataT)); + dim3 block(P::Nthreads); + dim3 grid = launchConfigGenerator<P>(m, n, smemSize, kernel); + + kernel<<<grid, block, smemSize, stream>>>(out, + x, + y, + xn, + yn, + ws_adj64.data(), + group_idxs, + num_groups, + m, + n, + k, + sqrt, + maxVal, + ws_fused_nn.data(), + redOp, + pairRedOp, + core_lambda, + fin_op); + + RAFT_CUDA_TRY(cudaGetLastError()); +} + +} // namespace detail +} // namespace distance +} // namespace raft diff --git a/cpp/include/raft/distance/detail/minkowski.cuh b/cpp/include/raft/distance/detail/minkowski.cuh index bda83babf1..42af8cd281 100644 --- a/cpp/include/raft/distance/detail/minkowski.cuh +++ b/cpp/include/raft/distance/detail/minkowski.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -74,8 +74,8 @@ void minkowskiUnExpImpl(const DataT* x, // Accumulation operation lambda auto core_lambda = [p] __device__(AccT & acc, DataT & x, DataT & y) { - const auto diff = raft::myAbs(x - y); - acc += raft::myPow(diff, p); + const auto diff = raft::abs(x - y); + acc += raft::pow(diff, p); }; // epilogue operation lambda for final value calculation @@ -89,7 +89,7 @@ void minkowskiUnExpImpl(const DataT* x, for (int i = 0; i < KPolicy::AccRowsPerTh; ++i) { #pragma unroll for (int j = 0; j < KPolicy::AccColsPerTh; ++j) { - acc[i][j] = raft::myPow(acc[i][j], one_over_p); + acc[i][j] = raft::pow(acc[i][j], one_over_p); } } }; diff --git a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh index 69bb83d29a..d849b23999 100644 --- a/cpp/include/raft/distance/detail/pairwise_distance_base.cuh +++ b/cpp/include/raft/distance/detail/pairwise_distance_base.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -59,6 +59,7 @@ namespace detail { * @param core_op the core accumulation operation lambda * @param epilog_op the epilog operation lambda * @param fin_op the final gemm epilogue lambda + * @param rowEpilog_op epilog lambda that executes when a full row has been processed */ template m; - gridStrideY += P::Mblk * gridDim.y) { - for (auto gridStrideX = blockIdx.x * P::Nblk; gridStrideX < this->n; - gridStrideX += P::Nblk * gridDim.x) { - prolog(gridStrideX, gridStrideY); - loop(); - epilog(gridStrideX, gridStrideY); + for (auto tile_idx_m = grid_offset_m; tile_idx_m < this->m; tile_idx_m += grid_stride_m) { + this->ldgXY(tile_idx_m, grid_offset_n, 0); + for (auto tile_idx_n = grid_offset_n; tile_idx_n < this->n; tile_idx_n += grid_stride_n) { + // Prolog: + reset_accumulator(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + + // Main loop: + for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { + this->ldgXY(tile_idx_m, tile_idx_n, kidx); + // Process all data in shared memory (previous k-block) and + // accumulate in registers. + accumulate(); + this->stsXY(); + __syncthreads(); + this->switch_write_buffer(); + this->switch_read_buffer(); + } + accumulate(); // last iteration + // The pre-condition for the loop over tile_idx_n is that write_buffer + // and read_buffer point to the same buffer. This flips read_buffer back + // so that it satisfies the pre-condition of this loop.
+ this->switch_read_buffer(); + + // Epilog: + if (useNorms) { + DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; + load_norms(tile_idx_m, tile_idx_n, regxn, regyn); + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, regxn, regyn, tile_idx_n, tile_idx_m); + } else { + // Overlap ldg with epilog computation + ldgNextGridStride(tile_idx_m, tile_idx_n); + epilog_op(acc, nullptr, nullptr, tile_idx_n, tile_idx_m); + } + if (writeOut) { store_output(tile_idx_m, tile_idx_n); } } - rowEpilog_op(gridStrideY); + rowEpilog_op(tile_idx_m); } } private: - DI void updateIndicesY() - { - const auto stride = P::Nblk * gridDim.x; - if (isRowMajor) { - this->y += stride * this->ldb; - } else { - this->y += stride; - } - this->yrowid += stride; - } - - DI void updateIndicesXY() - { - const auto stride = P::Mblk * gridDim.y; - if (isRowMajor) { - this->x += stride * this->lda; - this->yrowid = IdxT(blockIdx.x) * P::Nblk + this->srowid; - this->y = yBase + this->yrowid * this->ldb; - } else { - this->x += stride; - this->yrowid = IdxT(blockIdx.x) * P::Nblk; - this->y = yBase + this->yrowid + this->srowid * this->ldb; - } - this->xrowid += stride; - } - - DI void ldgNextGridStride(IdxT gridStrideX, IdxT gridStrideY) + DI void ldgNextGridStride(IdxT tile_idx_m, IdxT tile_idx_n) { // Fetch next grid stride ldg if within range - if ((gridStrideX + gridDim.x * P::Nblk) < this->n) { - updateIndicesY(); - this->ldgXY(0); - } else if ((gridStrideY + gridDim.y * P::Mblk) < this->m) { - updateIndicesXY(); - this->ldgXY(0); + const auto next_tile_tile_idx_n = tile_idx_n + grid_stride_n; + const auto next_tile_tile_idx_m = tile_idx_m + grid_stride_m; + if ((next_tile_tile_idx_n) < this->n) { + this->ldgXY(tile_idx_m, next_tile_tile_idx_n, 0); + } else if ((next_tile_tile_idx_m) < this->m) { + this->ldgXY(next_tile_tile_idx_m, grid_offset_n, 0); } } - DI void prolog(IdxT gridStrideX, IdxT gridStrideY) + DI void reset_accumulator() { - if (gridStrideX == blockIdx.x * P::Nblk) { this->ldgXY(0); } - + // Reset accumulator registers to zero. #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { #pragma unroll @@ -184,28 +199,6 @@ struct PairwiseDistances : public BaseClass { acc[i][j] = BaseClass::Zero; } } - - this->stsXY(); - __syncthreads(); - this->pageWr ^= 1; - } - - DI void loop() - { - for (int kidx = P::Kblk; kidx < this->k; kidx += P::Kblk) { - this->ldgXY(kidx); - accumulate(); // on the previous k-block - this->stsXY(); - __syncthreads(); - this->pageWr ^= 1; - this->pageRd ^= 1; - } - accumulate(); // last iteration - // This is needed for making sure next grid stride of - // non-norm based metrics uses previously accumulated buffer so - // it doesn't make shmem dirty until previous iteration - // is complete. - this->pageRd ^= 1; } DI void accumulate() @@ -226,60 +219,52 @@ struct PairwiseDistances : public BaseClass { } } - DI void epilog(IdxT gridStrideX, IdxT gridStrideY) + DI void load_norms(IdxT tile_idx_m, + IdxT tile_idx_n, + DataT (®xn)[P::AccRowsPerTh], + DataT (®yn)[P::AccColsPerTh]) { - if (useNorms) { - DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); - DataT* syNorm = (&sxNorm[P::Mblk]); - - // Load x & y norms required by this threadblock in shmem buffer - if (gridStrideX == blockIdx.x * P::Nblk) { - for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { - auto idx = gridStrideY + i; - sxNorm[i] = idx < this->m ? 
xn[idx] : 0; - } - } - - for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { - auto idx = gridStrideX + i; - syNorm[i] = idx < this->n ? yn[idx] : 0; + DataT* sxNorm = (DataT*)(&smem[P::SmemSize]); + DataT* syNorm = (&sxNorm[P::Mblk]); + + // Load x & y norms required by this threadblock in shmem buffer + if (tile_idx_n == blockIdx.x * P::Nblk) { + for (int i = threadIdx.x; i < P::Mblk; i += P::Nthreads) { + auto idx = tile_idx_m + i; + sxNorm[i] = idx < this->m ? xn[idx] : 0; } + } - __syncthreads(); + for (int i = threadIdx.x; i < P::Nblk; i += P::Nthreads) { + auto idx = tile_idx_n + i; + syNorm[i] = idx < this->n ? yn[idx] : 0; + } + __syncthreads(); - DataT regxn[P::AccRowsPerTh], regyn[P::AccColsPerTh]; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; - } + for (int i = 0; i < P::AccRowsPerTh; ++i) { + regxn[i] = sxNorm[i * P::AccThRows + (threadIdx.x / P::AccThCols)]; + } #pragma unroll - for (int i = 0; i < P::AccColsPerTh; ++i) { - regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; - } - - // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, regxn, regyn, gridStrideX, gridStrideY); - } else { - // Overlap ldg with epilog computation - ldgNextGridStride(gridStrideX, gridStrideY); - epilog_op(acc, nullptr, nullptr, gridStrideX, gridStrideY); + for (int i = 0; i < P::AccColsPerTh; ++i) { + regyn[i] = syNorm[i * P::AccThCols + (threadIdx.x % P::AccThCols)]; } + } - if (writeOut) { - IdxT starty = gridStrideY + this->accrowid; - IdxT startx = gridStrideX + this->acccolid; + DI void store_output(IdxT tile_idx_m, IdxT tile_idx_n) + { + IdxT starty = tile_idx_m + this->accrowid; + IdxT startx = tile_idx_n + this->acccolid; #pragma unroll - for (int i = 0; i < P::AccRowsPerTh; ++i) { - auto rowId = starty + i * P::AccThRows; + for (int i = 0; i < P::AccRowsPerTh; ++i) { + auto rowId = starty + i * P::AccThRows; #pragma unroll - for (int j = 0; j < P::AccColsPerTh; ++j) { - auto colId = startx + j * P::AccThCols; - if (rowId < this->m && colId < this->n) { - // Promote to 64 bit index for final write, as output array can be > 2^31 - dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); - } + for (int j = 0; j < P::AccColsPerTh; ++j) { + auto colId = startx + j * P::AccThCols; + if (rowId < this->m && colId < this->n) { + // Promote to 64 bit index for final write, as output array can be > 2^31 + dOutput[std::size_t(rowId) * this->n + colId] = fin_op(acc[i][j], 0); } } } diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh index 1c069fc397..93a5ce7f1a 100644 --- a/cpp/include/raft/distance/distance.cuh +++ b/cpp/include/raft/distance/distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
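The restructured `run()` above fixes the traversal order; sketched on the host with the policy members as plain parameters (illustrative only, not part of the diff):

```cpp
// Grid-stride traversal: each block walks row tiles (m) along the y grid
// dimension and column tiles (n) along the x grid dimension, consuming k in
// Kblk-sized chunks per tile pair before running the epilog.
void tile_traversal(
  int m, int n, int k, int Mblk, int Nblk, int Kblk, int grid_y, int grid_x, int bid_y, int bid_x)
{
  for (int tile_m = bid_y * Mblk; tile_m < m; tile_m += Mblk * grid_y) {
    for (int tile_n = bid_x * Nblk; tile_n < n; tile_n += Nblk * grid_x) {
      // prolog: reset accumulators, stage the first k-chunk in shared memory
      for (int kidx = 0; kidx < k; kidx += Kblk) {
        // accumulate x[tile_m.., kidx..] against y[tile_n.., kidx..]
      }
      // epilog: apply norms if needed, then store the output tile
    }
  }
}
```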
@@ -18,7 +18,7 @@ #pragma once -#include <raft/core/handle.hpp> +#include <raft/core/device_resources.hpp> #include #include #include @@ -238,7 +238,7 @@ void distance(const InType* x, * @param metric_arg metric argument (used for Minkowski distance) */ template <typename Type, typename Index_ = int> -void pairwise_distance(const raft::handle_t& handle, +void pairwise_distance(raft::device_resources const& handle, const Type* x, const Type* y, Type* dist, @@ -333,7 +333,7 @@ void pairwise_distance(const raft::handle_t& handle, * @param metric_arg metric argument (used for Minkowski distance) */ template <typename Type, typename Index_ = int> -void pairwise_distance(const raft::handle_t& handle, +void pairwise_distance(raft::device_resources const& handle, const Type* x, const Type* y, Type* dist, @@ -363,12 +363,12 @@ void pairwise_distance(const raft::handle_t& handle, * * Usage example: * @code{.cpp} - * #include <raft/core/handle.hpp> + * #include <raft/core/device_resources.hpp> * #include <raft/core/device_mdarray.hpp> * #include <raft/random/make_blobs.cuh> * #include <raft/distance/distance.cuh> * - * raft::handle_t handle; + * raft::device_resources handle; * int n_samples = 5000; * int n_features = 50; * @@ -398,7 +398,7 @@ template -void distance(raft::handle_t const& handle, +void distance(raft::device_resources const& handle, raft::device_matrix_view const x, raft::device_matrix_view const y, raft::device_matrix_view dist, @@ -441,7 +441,7 @@ void distance(raft::handle_t const& handle, * @param metric_arg metric argument (used for Minkowski distance) */ template -void pairwise_distance(raft::handle_t const& handle, +void pairwise_distance(raft::device_resources const& handle, device_matrix_view const x, device_matrix_view const y, device_matrix_view dist, diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index 1f677e919d..e832bcb020 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ #include #include -#include <raft/core/handle.hpp> +#include <raft/core/device_resources.hpp> #include #include #include @@ -50,7 +50,8 @@ using MinReduceOp = detail::MinReduceOpImpl<LabelT, DataT>; * Initialize array using init value from reduction op */ template <typename DataT, typename OutT, typename IdxT, typename ReduceOpT> -void initialize(const raft::handle_t& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +void initialize( + raft::device_resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) { detail::initialize(min, m, maxVal, redOp, handle.get_stream()); } diff --git a/cpp/include/raft/distance/masked_nn.cuh b/cpp/include/raft/distance/masked_nn.cuh new file mode 100644 index 0000000000..ea2e10a304 --- /dev/null +++ b/cpp/include/raft/distance/masked_nn.cuh @@ -0,0 +1,199 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License.
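A hedged usage sketch for the `device_resources`-based overloads shown in the distance.cuh hunk above; buffer construction mirrors the header's own example, and the function name is illustrative:

```cpp
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/distance/distance.cuh>

void l1_pairwise(raft::device_resources const& handle, int n_samples, int n_features)
{
  auto x    = raft::make_device_matrix<float, int>(handle, n_samples, n_features);
  auto y    = raft::make_device_matrix<float, int>(handle, n_samples, n_features);
  auto dist = raft::make_device_matrix<float, int>(handle, n_samples, n_samples);

  // ... fill x and y ...

  raft::distance::pairwise_distance(
    handle, x.view(), y.view(), dist.view(), raft::distance::DistanceType::L1);
}
```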
+ */ + +#ifndef __MASKED_L2_NN_H +#define __MASKED_L2_NN_H + +#pragma once + +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace distance { +/** + * \defgroup masked_nn Masked 1-nearest neighbors + * @{ + */ + +/** + * @brief Parameter struct for maskedL2NN function + * + * @tparam ReduceOpT Type of reduction operator in the epilogue. + * @tparam KVPReduceOpT Type of Reduction operation on key value pairs. + * + * Usage example: + * @code{.cpp} + * #include + * + * using IdxT = int; + * using DataT = float; + * using RedOpT = raft::distance::MinAndDistanceReduceOp; + * using PairRedOpT = raft::distance::KVPMinReduce; + * using ParamT = raft::distance::MaskedL2NNParams; + * + * bool init_out = true; + * bool sqrt = false; + * + * ParamT masked_l2_params{RedOpT{}, PairRedOpT{}, sqrt, init_out}; + * @endcode + * + * Prescribes how to reduce a distance to an intermediate type (`redOp`), and + * how to reduce two intermediate types (`pairRedOp`). Typically, a distance is + * mapped to an (index, value) pair and the (index, value) pair with the lowest + * value (distance) is selected. + * + * In addition, prescribes whether to compute the square root of the distance + * (`sqrt`) and whether to initialize the output buffer (`initOutBuffer`). + */ +template +struct MaskedL2NNParams { + /** Reduction operator in the epilogue */ + ReduceOpT redOp; + /** Reduction operation on key value pairs */ + KVPReduceOpT pairRedOp; + /** Whether the output `minDist` should contain L2-sqrt */ + bool sqrt; + /** Whether to initialize the output buffer before the main kernel launch */ + bool initOutBuffer; +}; + +/** + * @brief Masked L2 distance and 1-nearest-neighbor computation in a single call. + * + * This function enables faster computation of nearest neighbors if the + * computation of distances between certain point pairs can be skipped. + * + * We use an adjacency matrix that describes which distances to calculate. The + * points in `y` are divided into groups, and the adjacency matrix indicates + * whether to compute distances between points in `x` and groups in `y`. In other + * words, if `adj[i,k]` is true, then the distance between point `x_i` and points in + * `group_k` will be calculated. + * + * **Performance considerations** + * + * The points in `x` are processed in tiles of `M` points (`M` is currently 64, + * but may change in the future). As a result, the largest compute time + * reduction occurs if all `M` points can skip a group. If only part of the `M` + * points can skip a group, then at most a minor compute time reduction and a + * modest energy use reduction can be expected. + * + * The points in `y` are also grouped into tiles of `N` points (`N` is currently + * 64, but may change in the future). As a result, group sizes should be larger + * than `N` to avoid wasting computational resources. If the group sizes are + * evenly divisible by `N`, then the computation is most efficient, although for + * larger group sizes this effect is minor. + * + * + * **Comparison to SDDMM** + * + * [SDDMM](https://ieeexplore.ieee.org/document/8638042) (sampled dense-dense + * matrix multiplication) is a matrix-matrix multiplication where only part of + * the output is computed. Compared to maskedL2NN, there are a few differences: + * + * - The output of maskedL2NN is a single vector (of nearest neighbors) and not + * a sparse matrix. + * + * - The sampling in maskedL2NN is expressed through intermediate "groups" + * rather than a CSR format.
+ * + * @tparam DataT data type + * @tparam OutT output type to either store 1-NN indices and their minimum + * distances or store only the min distances. Accordingly, one + * has to pass an appropriate `ReduceOpT` + * @tparam IdxT indexing arithmetic type + * @tparam ReduceOpT A struct to perform the final needed reduction operation + * and also to initialize the output array elements with the + * appropriate initial value needed for reduction. + * + * @param handle RAFT handle for managing expensive resources + * @param params Parameter struct specifying the reduction operations. + * @param[in] x First matrix. Row major. Dim = `m x k`. + * (on device). + * @param[in] y Second matrix. Row major. Dim = `n x k`. + * (on device). + * @param[in] x_norm L2 squared norm of `x`. Length = `m`. (on device). + * @param[in] y_norm L2 squared norm of `y`. Length = `n`. (on device) + * @param[in] adj A boolean adjacency matrix indicating for each + * row of `x` and each group in `y` whether to compute the + * distance. Dim = `m x num_groups`. + * @param[in] group_idxs An array containing the *end* indices of each group + * in `y`. The value of group_idxs[j] indicates the + * start of group j + 1, i.e., it is the inclusive + * scan of the group lengths. The first group is + * always assumed to start at index 0 and the last + * group typically ends at index `n`. Length = + * `num_groups`. + * @param[out] out will contain the reduced output (Length = `m`) + * (on device) + */ +template +void maskedL2NN(raft::device_resources const& handle, + raft::distance::MaskedL2NNParams params, + raft::device_matrix_view x, + raft::device_matrix_view y, + raft::device_vector_view x_norm, + raft::device_vector_view y_norm, + raft::device_matrix_view adj, + raft::device_vector_view group_idxs, + raft::device_vector_view out) +{ + IdxT m = x.extent(0); + IdxT n = y.extent(0); + IdxT k = x.extent(1); + IdxT num_groups = group_idxs.extent(0); + + // Match k dimension of x, y + RAFT_EXPECTS(x.extent(1) == y.extent(1), "Dimension of vectors in x and y must be equal."); + // Match x, x_norm and y, y_norm + RAFT_EXPECTS(m == x_norm.extent(0), "Length of `x_norm` must match input `x`."); + RAFT_EXPECTS(n == y_norm.extent(0), "Length of `y_norm` must match input `y`."); + // Match adj to x and group_idxs + RAFT_EXPECTS(m == adj.extent(0), "#rows in `adj` must match input `x`."); + RAFT_EXPECTS(num_groups == adj.extent(1), "#cols in `adj` must match length of `group_idxs`."); + // NOTE: We do not check if all indices in group_idxs actually point *inside* y. + + // If there is no work to be done, return immediately. + if (m == 0 || n == 0 || k == 0 || num_groups == 0) { return; } + + detail::maskedL2NNImpl(handle, + out.data_handle(), + x.data_handle(), + y.data_handle(), + x_norm.data_handle(), + y_norm.data_handle(), + adj.data_handle(), + group_idxs.data_handle(), + num_groups, + m, + n, + k, + params.redOp, + params.pairRedOp, + params.sqrt, + params.initOutBuffer); +} + +/** @} */ + +} // namespace distance +} // namespace raft + +#endif diff --git a/cpp/include/raft/handle.hpp b/cpp/include/raft/handle.hpp index 4525af49d2..caa68061db 100644 --- a/cpp/include/raft/handle.hpp +++ b/cpp/include/raft/handle.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
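Putting the new masked_nn.cuh API together, the following end-to-end sketch shows how a caller might invoke maskedL2NN. The KeyValuePair output type and the template arguments of MinAndDistanceReduceOp/KVPMinReduce are assumptions patterned on the fused_l2_nn operators (the example in the header elides them), so treat this as a sketch rather than the exact signatures.
@code{.cpp}
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/kvp.hpp>
#include <raft/distance/masked_nn.cuh>

using IdxT  = int;
using DataT = float;
using OutT       = raft::KeyValuePair<IdxT, DataT>;                      // assumed 1-NN output
using RedOpT     = raft::distance::MinAndDistanceReduceOp<IdxT, DataT>;  // assumed args
using PairRedOpT = raft::distance::KVPMinReduce<IdxT, DataT>;            // assumed args

raft::device_resources handle;
IdxT m = 1024, n = 2048, k = 64, num_groups = 16;

auto x          = raft::make_device_matrix<DataT, IdxT>(handle, m, k);
auto y          = raft::make_device_matrix<DataT, IdxT>(handle, n, k);
auto x_norm     = raft::make_device_vector<DataT, IdxT>(handle, m);
auto y_norm     = raft::make_device_vector<DataT, IdxT>(handle, n);
auto adj        = raft::make_device_matrix<bool, IdxT>(handle, m, num_groups);
auto group_idxs = raft::make_device_vector<IdxT, IdxT>(handle, num_groups);
auto out        = raft::make_device_vector<OutT, IdxT>(handle, m);

// ... fill x and y, compute their squared L2 row norms, set adj and group_idxs ...

raft::distance::MaskedL2NNParams<RedOpT, PairRedOpT> params{
  RedOpT{}, PairRedOpT{}, /*sqrt=*/false, /*initOutBuffer=*/true};

raft::distance::maskedL2NN<DataT, OutT, IdxT>(
  handle, params, x.view(), y.view(), x_norm.view(), y_norm.view(),
  adj.view(), group_idxs.view(), out.view());
@endcode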
@@ -21,4 +21,4 @@ #pragma once -#include +#include diff --git a/cpp/include/raft/linalg/add.cuh b/cpp/include/raft/linalg/add.cuh index 27ab24abe8..608c63e1a9 100644 --- a/cpp/include/raft/linalg/add.cuh +++ b/cpp/include/raft/linalg/add.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -95,7 +95,7 @@ void addDevScalar( * @brief Elementwise add operation * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in1 First Input * @param[in] in2 Second Input * @param[out] out Output @@ -104,7 +104,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void add(const raft::handle_t& handle, InType in1, InType in2, OutType out) +void add(raft::device_resources const& handle, InType in1, InType in2, OutType out) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -135,7 +135,7 @@ void add(const raft::handle_t& handle, InType in1, InType in2, OutType out) * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[in] scalar raft::device_scalar_view * @param[in] out Output @@ -145,7 +145,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void add_scalar(const raft::handle_t& handle, +void add_scalar(raft::device_resources const& handle, InType in, OutType out, raft::device_scalar_view scalar) @@ -177,7 +177,7 @@ void add_scalar(const raft::handle_t& handle, * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[in] scalar raft::host_scalar_view * @param[in] out Output @@ -187,7 +187,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void add_scalar(const raft::handle_t& handle, +void add_scalar(raft::device_resources const& handle, const InType in, OutType out, raft::host_scalar_view scalar) diff --git a/cpp/include/raft/linalg/axpy.cuh b/cpp/include/raft/linalg/axpy.cuh index 35a34bc2b5..9b3af73234 100644 --- a/cpp/include/raft/linalg/axpy.cuh +++ b/cpp/include/raft/linalg/axpy.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
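For reference, the mdspan `add` overload above can be driven as in the following hedged sketch; the mdarray factory helpers and include paths are assumed, and the element-wise semantics follow the doc comment.
@code{.cpp}
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/linalg/add.cuh>

raft::device_resources handle;
int n = 1000;

auto in1 = raft::make_device_vector<float, int>(handle, n);
auto in2 = raft::make_device_vector<float, int>(handle, n);
auto out = raft::make_device_vector<float, int>(handle, n);

// ... fill in1 and in2 ...

// out = in1 + in2, element-wise, on the stream owned by the handle
raft::linalg::add(handle, in1.view(), in2.view(), out.view());
@endcode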
@@ -41,7 +41,7 @@ namespace raft::linalg { * @param [in] stream */ template -void axpy(const raft::handle_t& handle, +void axpy(raft::device_resources const& handle, const int n, const T* alpha, const T* x, @@ -62,7 +62,7 @@ void axpy(const raft::handle_t& handle, * @brief axpy function * It computes the following equation: y = alpha * x + y * - * @param [in] handle raft::handle_t + * @param [in] handle raft::device_resources * @param [in] alpha raft::device_scalar_view * @param [in] x Input vector * @param [inout] y Output vector @@ -72,7 +72,7 @@ template -void axpy(const raft::handle_t& handle, +void axpy(raft::device_resources const& handle, raft::device_scalar_view alpha, raft::device_vector_view x, raft::device_vector_view y) @@ -92,7 +92,7 @@ void axpy(const raft::handle_t& handle, /** * @brief axpy function * It computes the following equation: y = alpha * x + y - * @param [in] handle raft::handle_t + * @param [in] handle raft::device_resources * @param [in] alpha raft::device_scalar_view * @param [in] x Input vector * @param [inout] y Output vector @@ -102,7 +102,7 @@ template -void axpy(const raft::handle_t& handle, +void axpy(raft::device_resources const& handle, raft::host_scalar_view alpha, raft::device_vector_view x, raft::device_vector_view y) diff --git a/cpp/include/raft/linalg/binary_op.cuh b/cpp/include/raft/linalg/binary_op.cuh index 693ef961c2..966e84965d 100644 --- a/cpp/include/raft/linalg/binary_op.cuh +++ b/cpp/include/raft/linalg/binary_op.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ #include "detail/binary_op.cuh" #include -#include +#include #include #include @@ -65,7 +65,7 @@ void binaryOp( * @tparam InType Input Type raft::device_mdspan * @tparam Lambda the device-lambda performing the actual operation * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in1 First input * @param[in] in2 Second input * @param[out] out Output @@ -78,7 +78,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void binary_op(const raft::handle_t& handle, InType in1, InType in2, OutType out, Lambda op) +void binary_op(raft::device_resources const& handle, InType in1, InType in2, OutType out, Lambda op) { RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous"); diff --git a/cpp/include/raft/linalg/cholesky_r1_update.cuh b/cpp/include/raft/linalg/cholesky_r1_update.cuh index af8d12d873..e10f43653b 100644 --- a/cpp/include/raft/linalg/cholesky_r1_update.cuh +++ b/cpp/include/raft/linalg/cholesky_r1_update.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -121,7 +121,7 @@ namespace linalg { * conditioned systems. Negative values mean no regularization.
*/ template -void choleskyRank1Update(const raft::handle_t& handle, +void choleskyRank1Update(raft::device_resources const& handle, math_t* L, int n, int ld, diff --git a/cpp/include/raft/linalg/coalesced_reduction.cuh b/cpp/include/raft/linalg/coalesced_reduction.cuh index 45cd640edc..674be207d8 100644 --- a/cpp/include/raft/linalg/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/coalesced_reduction.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ #include "detail/coalesced_reduction.cuh" #include -#include +#include #include namespace raft { @@ -101,7 +101,7 @@ void coalescedReduction(OutType* dots, * @tparam FinalLambda the final lambda applied before STG (eg: Sqrt for L2 norm) * It must be a 'callable' supporting the following input and output: *
OutType (*FinalLambda)(OutType);
- * @param handle raft::handle_t + * @param handle raft::device_resources * @param[in] data Input of type raft::device_matrix_view * @param[out] dots Output of type raft::device_matrix_view * @param[in] init initial value to use for the reduction @@ -117,7 +117,7 @@ template -void coalesced_reduction(const raft::handle_t& handle, +void coalesced_reduction(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutValueType init, diff --git a/cpp/include/raft/linalg/detail/axpy.cuh b/cpp/include/raft/linalg/detail/axpy.cuh index f3e1a177c8..5747e840c4 100644 --- a/cpp/include/raft/linalg/detail/axpy.cuh +++ b/cpp/include/raft/linalg/detail/axpy.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,12 +20,12 @@ #include "cublas_wrappers.hpp" -#include +#include namespace raft::linalg::detail { template -void axpy(const raft::handle_t& handle, +void axpy(raft::device_resources const& handle, const int n, const T* alpha, const T* x, diff --git a/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh b/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh index 47937815bd..afa9155753 100644 --- a/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh +++ b/cpp/include/raft/linalg/detail/cholesky_r1_update.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ #include "cublas_wrappers.hpp" #include "cusolver_wrappers.hpp" -#include +#include #include namespace raft { @@ -26,7 +26,7 @@ namespace linalg { namespace detail { template -void choleskyRank1Update(const raft::handle_t& handle, +void choleskyRank1Update(raft::device_resources const& handle, math_t* L, int n, int ld, diff --git a/cpp/include/raft/linalg/detail/contractions.cuh b/cpp/include/raft/linalg/detail/contractions.cuh index 5d83f88e71..b15cb222b4 100644 --- a/cpp/include/raft/linalg/detail/contractions.cuh +++ b/cpp/include/raft/linalg/detail/contractions.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -40,14 +40,10 @@ struct Contractions_NT { /** leading dimension in Output D */ IdxT ldd; - /** current thread's global mem row id for X data */ - IdxT xrowid; - /** current thread's global mem row id for Y data */ - IdxT yrowid; /** global memory pointer to X matrix */ - const DataT* x; + const DataT* x_base; /** global memory pointer to Y matrix */ - const DataT* y; + const DataT* y_base; /** current thread's smem row id */ int srowid; @@ -94,10 +90,8 @@ struct Contractions_NT { k(_k), lda(_k), ldb(_k), - xrowid(IdxT(blockIdx.x) * P::Mblk + threadIdx.x / P::LdgThRow), - yrowid(IdxT(blockIdx.y) * P::Nblk + threadIdx.x / P::LdgThRow), - x(_x + xrowid * lda), - y(_y + yrowid * ldb), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -133,6 +127,8 @@ struct Contractions_NT { lda(_lda), ldb(_ldb), ldd(_ldd), + x_base(_x), + y_base(_y), srowid(threadIdx.x / P::LdgThRow), scolid((threadIdx.x % P::LdgThRow) * P::Veclen), accrowid(threadIdx.x / P::AccThCols), @@ -142,17 +138,6 @@ struct Contractions_NT { pageWr(0), pageRd(0) { - if (isRowMajor) { - xrowid = IdxT(blockIdx.y) * P::Mblk + srowid; - yrowid = IdxT(blockIdx.x) * P::Nblk + srowid; - x = _x + xrowid * lda; - y = _y + yrowid * ldb; - } else { - xrowid = IdxT(blockIdx.y) * P::Mblk; - yrowid = IdxT(blockIdx.x) * P::Nblk; - x = _x + xrowid + srowid * lda; - y = _y + yrowid + srowid * ldb; - } } protected: @@ -160,10 +145,16 @@ struct Contractions_NT { * @brief Load current block of X/Y from global memory to registers * @param[in] kidx current start index of k to be loaded */ - DI void ldgXY(IdxT kidx) + DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx) + { + ldgX(tile_idx_m, kidx); + ldgY(tile_idx_n, kidx); + } + + DI void ldgXY(IdxT tile_idx_m, IdxT tile_idx_n, IdxT kidx, IdxT tile_end_n) { - ldgX(kidx); - ldgY(kidx); + ldgX(tile_idx_m, kidx); + ldgY(tile_idx_n, kidx, tile_end_n); } /** @@ -186,9 +177,16 @@ struct Contractions_NT { ldsY(kidx, sy + pageRd * P::SmemPage); } + DI void switch_read_buffer() { this->pageRd ^= 1; } + + DI void switch_write_buffer() { this->pageWr ^= 1; } + private: - DI void ldgX(IdxT kidx) + DI void ldgX(IdxT tile_idx_m, IdxT kidx) { + IdxT xrowid = isRowMajor ? tile_idx_m + srowid : tile_idx_m; + auto x = isRowMajor ? x_base + xrowid * lda : x_base + xrowid + srowid * lda; + if (isRowMajor) { auto numRows = m; auto koffset = kidx + scolid; @@ -220,10 +218,15 @@ struct Contractions_NT { } } - DI void ldgY(IdxT kidx) + DI void ldgY(IdxT tile_idx_n, IdxT kidx) { ldgY(tile_idx_n, kidx, n); } + + DI void ldgY(IdxT tile_idx_n, IdxT kidx, IdxT end_n) { + IdxT yrowid = isRowMajor ? tile_idx_n + srowid : tile_idx_n; + auto y = isRowMajor ? 
y_base + yrowid * ldb : y_base + yrowid + srowid * ldb; + if (isRowMajor) { - auto numRows = n; + auto numRows = end_n; auto koffset = kidx + scolid; #pragma unroll for (int i = 0; i < P::LdgPerThY; ++i) { @@ -241,7 +244,7 @@ struct Contractions_NT { auto koffset = scolid; #pragma unroll for (int i = 0; i < P::LdgPerThY; ++i) { - if ((koffset + yrowid) < ldb && (srowid + kidx + i * P::LdgRowsY) < numRows) { + if ((koffset + yrowid) < end_n && (srowid + kidx + i * P::LdgRowsY) < numRows) { ldg(ldgDataY[i], y + (kidx + i * P::LdgRowsY) * ldb + koffset); } else { #pragma unroll @@ -315,4 +318,4 @@ struct Contractions_NT { } // namespace detail } // namespace linalg -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/linalg/detail/eig.cuh b/cpp/include/raft/linalg/detail/eig.cuh index d48b42fc57..94493efb24 100644 --- a/cpp/include/raft/linalg/detail/eig.cuh +++ b/cpp/include/raft/linalg/detail/eig.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ #include "cusolver_wrappers.hpp" #include -#include +#include #include #include #include @@ -29,7 +29,7 @@ namespace linalg { namespace detail { template -void eigDC_legacy(const raft::handle_t& handle, +void eigDC_legacy(raft::device_resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -74,7 +74,7 @@ void eigDC_legacy(const raft::handle_t& handle, } template -void eigDC(const raft::handle_t& handle, +void eigDC(raft::device_resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -137,7 +137,7 @@ void eigDC(const raft::handle_t& handle, enum EigVecMemUsage { OVERWRITE_INPUT, COPY_INPUT }; template -void eigSelDC(const raft::handle_t& handle, +void eigSelDC(raft::device_resources const& handle, math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -228,7 +228,7 @@ void eigSelDC(const raft::handle_t& handle, } template -void eigJacobi(const raft::handle_t& handle, +void eigJacobi(raft::device_resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, diff --git a/cpp/include/raft/linalg/detail/gemm.hpp b/cpp/include/raft/linalg/detail/gemm.hpp index baa066984b..ba9496c3b9 100644 --- a/cpp/include/raft/linalg/detail/gemm.hpp +++ b/cpp/include/raft/linalg/detail/gemm.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
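The Contractions_NT hunks above also make the shared-memory double buffering explicit: switch_read_buffer() and switch_write_buffer() flip pageRd and pageWr so the next tile can be prefetched into one smem page while the math consumes the other. The following self-contained kernel is a hypothetical illustration of that idiom, not RAFT code; it assumes a single block of 256 threads.
@code{.cpp}
// Each thread sums in[t], in[256 + t], ... using two shared-memory pages:
// while one page is read, the next tile is loaded into the other.
__global__ void double_buffered_sum(const float* in, float* out, int n)
{
  __shared__ float page[2][256];
  int pageWr = 0;
  int pageRd = 0;
  float acc  = 0.f;

  // Prime the pipeline: stage the first tile in page 0.
  int tid           = threadIdx.x;
  page[pageWr][tid] = tid < n ? in[tid] : 0.f;
  pageWr ^= 1;  // switch_write_buffer()
  __syncthreads();

  for (int base = 256; base < n; base += 256) {
    // Prefetch the next tile into the write page...
    page[pageWr][tid] = (base + tid) < n ? in[base + tid] : 0.f;
    // ...while consuming the tile already staged in the read page.
    acc += page[pageRd][tid];
    __syncthreads();
    pageWr ^= 1;  // switch_write_buffer()
    pageRd ^= 1;  // switch_read_buffer()
  }
  acc += page[pageRd][tid];  // drain the last staged tile

  out[tid] = acc;  // launch as double_buffered_sum<<<1, 256>>>(in, out, n)
}
@endcode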
@@ -20,7 +20,7 @@ #include "cublas_wrappers.hpp" -#include +#include namespace raft { namespace linalg { @@ -49,7 +49,7 @@ namespace detail { * @param [in] stream */ template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, const bool trans_a, const bool trans_b, const int m, @@ -103,7 +103,7 @@ void gemm(const raft::handle_t& handle, * @param stream cuda stream */ template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -130,7 +130,7 @@ void gemm(const raft::handle_t& handle, } template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -149,7 +149,7 @@ void gemm(const raft::handle_t& handle, } template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, T* z, T* x, T* y, diff --git a/cpp/include/raft/linalg/detail/gemv.hpp b/cpp/include/raft/linalg/detail/gemv.hpp index 38fcdcd82e..b3e001a851 100644 --- a/cpp/include/raft/linalg/detail/gemv.hpp +++ b/cpp/include/raft/linalg/detail/gemv.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,14 +20,14 @@ #include "cublas_wrappers.hpp" -#include +#include namespace raft { namespace linalg { namespace detail { template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const bool trans_a, const int m, const int n, @@ -59,7 +59,7 @@ void gemv(const raft::handle_t& handle, } template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows, const int n_cols, @@ -76,7 +76,7 @@ void gemv(const raft::handle_t& handle, } template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -91,7 +91,7 @@ void gemv(const raft::handle_t& handle, } template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -107,7 +107,7 @@ void gemv(const raft::handle_t& handle, } template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -126,7 +126,7 @@ void gemv(const raft::handle_t& handle, } template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, diff --git a/cpp/include/raft/linalg/detail/lanczos.cuh b/cpp/include/raft/linalg/detail/lanczos.cuh index 5a3c595512..8c0cfeba28 100644 --- a/cpp/include/raft/linalg/detail/lanczos.cuh +++ b/cpp/include/raft/linalg/detail/lanczos.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -26,7 +26,7 @@ #include #include "cublas_wrappers.hpp" -#include +#include #include #include #include @@ -82,7 +82,7 @@ inline curandStatus_t curandGenerateNormalX( * @return Zero if successful. Otherwise non-zero. 
*/ template -int performLanczosIteration(handle_t const& handle, +int performLanczosIteration(raft::device_resources const& handle, spectral::matrix::sparse_matrix_t const* A, index_type_t* iter, index_type_t maxIter, @@ -540,7 +540,7 @@ static int francisQRIteration(index_type_t n, * @return error flag. */ template -static int lanczosRestart(handle_t const& handle, +static int lanczosRestart(raft::device_resources const& handle, index_type_t n, index_type_t iter, index_type_t iter_new, @@ -743,7 +743,7 @@ static int lanczosRestart(handle_t const& handle, */ template int computeSmallestEigenvectors( - handle_t const& handle, + raft::device_resources const& handle, spectral::matrix::sparse_matrix_t const* A, index_type_t nEigVecs, index_type_t maxIter, @@ -984,7 +984,7 @@ int computeSmallestEigenvectors( template int computeSmallestEigenvectors( - handle_t const& handle, + raft::device_resources const& handle, spectral::matrix::sparse_matrix_t const& A, index_type_t nEigVecs, index_type_t maxIter, @@ -1087,7 +1087,7 @@ int computeSmallestEigenvectors( */ template int computeLargestEigenvectors( - handle_t const& handle, + raft::device_resources const& handle, spectral::matrix::sparse_matrix_t const* A, index_type_t nEigVecs, index_type_t maxIter, @@ -1331,7 +1331,7 @@ int computeLargestEigenvectors( template int computeLargestEigenvectors( - handle_t const& handle, + raft::device_resources const& handle, spectral::matrix::sparse_matrix_t const& A, index_type_t nEigVecs, index_type_t maxIter, diff --git a/cpp/include/raft/linalg/detail/lstsq.cuh b/cpp/include/raft/linalg/detail/lstsq.cuh index 1273956b21..207bcefc32 100644 --- a/cpp/include/raft/linalg/detail/lstsq.cuh +++ b/cpp/include/raft/linalg/detail/lstsq.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -104,7 +104,7 @@ struct DivideByNonZero { operator()(const math_t a, const math_t b) const { - return raft::myAbs(b) >= eps ? a / b : a; + return raft::abs(b) >= eps ? a / b : a; } }; @@ -117,7 +117,7 @@ struct DivideByNonZero { * so it's not guaranteed to stay unmodified. */ template -void lstsqSvdQR(const raft::handle_t& handle, +void lstsqSvdQR(raft::device_resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -177,7 +177,7 @@ void lstsqSvdQR(const raft::handle_t& handle, * so it's not guaranteed to stay unmodified. */ template -void lstsqSvdJacobi(const raft::handle_t& handle, +void lstsqSvdJacobi(raft::device_resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -248,7 +248,7 @@ void lstsqSvdJacobi(const raft::handle_t& handle, * (`w = (A^T A)^-1 A^T b`) */ template -void lstsqEig(const raft::handle_t& handle, +void lstsqEig(raft::device_resources const& handle, const math_t* A, const int n_rows, const int n_cols, @@ -352,7 +352,7 @@ void lstsqEig(const raft::handle_t& handle, * Warning: the content of this vector is modified by the cuSOLVER routines. */ template -void lstsqQR(const raft::handle_t& handle, +void lstsqQR(raft::device_resources const& handle, math_t* A, const int n_rows, const int n_cols, diff --git a/cpp/include/raft/linalg/detail/map.cuh b/cpp/include/raft/linalg/detail/map.cuh index add003eb52..e0b473bdd4 100644 --- a/cpp/include/raft/linalg/detail/map.cuh +++ b/cpp/include/raft/linalg/detail/map.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. 
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include diff --git a/cpp/include/raft/linalg/detail/map_then_reduce.cuh b/cpp/include/raft/linalg/detail/map_then_reduce.cuh index 7ef9ca1c43..70bb2df4f5 100644 --- a/cpp/include/raft/linalg/detail/map_then_reduce.cuh +++ b/cpp/include/raft/linalg/detail/map_then_reduce.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include diff --git a/cpp/include/raft/linalg/detail/qr.cuh b/cpp/include/raft/linalg/detail/qr.cuh index 74e9c3e1aa..4cba028d87 100644 --- a/cpp/include/raft/linalg/detail/qr.cuh +++ b/cpp/include/raft/linalg/detail/qr.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -42,7 +42,7 @@ namespace detail { */ template void qrGetQ_inplace( - const raft::handle_t& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream) + raft::device_resources const& handle, math_t* Q, int n_rows, int n_cols, cudaStream_t stream) { RAFT_EXPECTS(n_rows >= n_cols, "QR decomposition expects n_rows >= n_cols."); cusolverDnHandle_t cusolver = handle.get_cusolver_dn_handle(); @@ -83,7 +83,7 @@ void qrGetQ_inplace( } template -void qrGetQ(const raft::handle_t& handle, +void qrGetQ(raft::device_resources const& handle, const math_t* M, math_t* Q, int n_rows, @@ -95,7 +95,7 @@ void qrGetQ(const raft::handle_t& handle, } template -void qrGetQR(const raft::handle_t& handle, +void qrGetQR(raft::device_resources const& handle, math_t* M, math_t* Q, math_t* R, diff --git a/cpp/include/raft/linalg/detail/rsvd.cuh b/cpp/include/raft/linalg/detail/rsvd.cuh index f96598d9e6..a66a23179b 100644 --- a/cpp/include/raft/linalg/detail/rsvd.cuh +++ b/cpp/include/raft/linalg/detail/rsvd.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -54,7 +54,7 @@ namespace detail { * @param stream cuda stream */ template -void rsvdFixedRank(const raft::handle_t& handle, +void rsvdFixedRank(raft::device_resources const& handle, math_t* M, int n_rows, int n_cols, @@ -371,7 +371,7 @@ void rsvdFixedRank(const raft::handle_t& handle, * @param stream cuda stream */ template -void rsvdPerc(const raft::handle_t& handle, +void rsvdPerc(raft::device_resources const& handle, math_t* M, int n_rows, int n_cols, diff --git a/cpp/include/raft/linalg/detail/svd.cuh b/cpp/include/raft/linalg/detail/svd.cuh index 8626c7888b..4850744f51 100644 --- a/cpp/include/raft/linalg/detail/svd.cuh +++ b/cpp/include/raft/linalg/detail/svd.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -23,7 +23,7 @@ #include #include -#include +#include #include #include #include @@ -36,7 +36,7 @@ namespace linalg { namespace detail { template -void svdQR(const raft::handle_t& handle, +void svdQR(raft::device_resources const& handle, T* in, int n_rows, int n_cols, @@ -102,7 +102,7 @@ void svdQR(const raft::handle_t& handle, } template -void svdEig(const raft::handle_t& handle, +void svdEig(raft::device_resources const& handle, math_t* in, idx_t n_rows, idx_t n_cols, @@ -162,7 +162,7 @@ void svdEig(const raft::handle_t& handle, } template -void svdJacobi(const raft::handle_t& handle, +void svdJacobi(raft::device_resources const& handle, math_t* in, int n_rows, int n_cols, @@ -232,7 +232,7 @@ void svdJacobi(const raft::handle_t& handle, } template -void svdReconstruction(const raft::handle_t& handle, +void svdReconstruction(raft::device_resources const& handle, math_t* U, math_t* S, math_t* V, @@ -263,7 +263,7 @@ void svdReconstruction(const raft::handle_t& handle, } template -bool evaluateSVDByL2Norm(const raft::handle_t& handle, +bool evaluateSVDByL2Norm(raft::device_resources const& handle, math_t* A_d, math_t* U, math_t* S_vec, diff --git a/cpp/include/raft/linalg/detail/transpose.cuh b/cpp/include/raft/linalg/detail/transpose.cuh index ef5551ea7e..9e7b236fed 100644 --- a/cpp/include/raft/linalg/detail/transpose.cuh +++ b/cpp/include/raft/linalg/detail/transpose.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,7 +19,7 @@ #include "cublas_wrappers.hpp" #include -#include +#include #include #include #include @@ -29,7 +29,7 @@ namespace linalg { namespace detail { template -void transpose(const raft::handle_t& handle, +void transpose(raft::device_resources const& handle, math_t* in, math_t* out, int n_rows, @@ -82,7 +82,7 @@ void transpose(math_t* inout, int n, cudaStream_t stream) template void transpose_row_major_impl( - handle_t const& handle, + raft::device_resources const& handle, raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { @@ -108,7 +108,7 @@ void transpose_row_major_impl( template void transpose_col_major_impl( - handle_t const& handle, + raft::device_resources const& handle, raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) { diff --git a/cpp/include/raft/linalg/divide.cuh b/cpp/include/raft/linalg/divide.cuh index 526d8a9716..0b18e6175c 100644 --- a/cpp/include/raft/linalg/divide.cuh +++ b/cpp/include/raft/linalg/divide.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -56,7 +56,7 @@ void divideScalar(OutT* out, const InT* in, InT scalar, IdxType len, cudaStream_ * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[in] scalar raft::host_scalar_view * @param[out] out Output @@ -66,7 +66,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void divide_scalar(const raft::handle_t& handle, +void divide_scalar(raft::device_resources const& handle, InType in, OutType out, raft::host_scalar_view scalar) diff --git a/cpp/include/raft/linalg/dot.cuh b/cpp/include/raft/linalg/dot.cuh index 4b1bc913e1..917188d695 100644 --- a/cpp/include/raft/linalg/dot.cuh +++ b/cpp/include/raft/linalg/dot.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ #include #include -#include +#include #include namespace raft::linalg { @@ -33,7 +33,7 @@ namespace raft::linalg { /** * @brief Computes the dot product of two vectors. - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] x First input vector * @param[in] y Second input vector * @param[out] out The output dot product between the x and y vectors. @@ -43,7 +43,7 @@ template -void dot(const raft::handle_t& handle, +void dot(raft::device_resources const& handle, raft::device_vector_view x, raft::device_vector_view y, raft::device_scalar_view out) @@ -63,7 +63,7 @@ void dot(const raft::handle_t& handle, /** * @brief Computes the dot product of two vectors. - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] x First input vector * @param[in] y Second input vector * @param[out] out The output dot product between the x and y vectors. @@ -73,7 +73,7 @@ template -void dot(const raft::handle_t& handle, +void dot(raft::device_resources const& handle, raft::device_vector_view x, raft::device_vector_view y, raft::host_scalar_view out) diff --git a/cpp/include/raft/linalg/eig.cuh b/cpp/include/raft/linalg/eig.cuh index 271ff13db5..03e94a10b1 100644 --- a/cpp/include/raft/linalg/eig.cuh +++ b/cpp/include/raft/linalg/eig.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -38,7 +38,7 @@ namespace linalg { * @param stream cuda stream */ template -void eigDC(const raft::handle_t& handle, +void eigDC(raft::device_resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -68,7 +68,7 @@ using detail::OVERWRITE_INPUT; * @param stream cuda stream */ template -void eigSelDC(const raft::handle_t& handle, +void eigSelDC(raft::device_resources const& handle, math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -97,7 +97,7 @@ void eigSelDC(const raft::handle_t& handle, * accuracy. 
*/ template -void eigJacobi(const raft::handle_t& handle, +void eigJacobi(raft::device_resources const& handle, const math_t* in, std::size_t n_rows, std::size_t n_cols, @@ -120,14 +120,14 @@ void eigJacobi(const raft::handle_t& handle, * symmetric matrices * @tparam ValueType the data-type of input and output * @tparam IntegerType Integer used for addressing - * @param handle raft::handle_t + * @param handle raft::device_resources * @param[in] in input raft::device_matrix_view (symmetric matrix that has real eig values and * vectors) * @param[out] eig_vectors: eigenvectors output of type raft::device_matrix_view * @param[out] eig_vals: eigen values output of type raft::device_vector_view */ template -void eig_dc(const raft::handle_t& handle, +void eig_dc(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view eig_vectors, raft::device_vector_view eig_vals) @@ -149,7 +149,7 @@ void eig_dc(const raft::handle_t& handle, * for the column-major symmetric matrices * @tparam ValueType the data-type of input and output * @tparam IntegerType Integer used for addressing - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in input raft::device_matrix_view (symmetric matrix that has real eig values and * vectors) * @param[out] eig_vectors: eigenvectors output of type raft::device_matrix_view @@ -158,7 +158,7 @@ void eig_dc(const raft::handle_t& handle, * @param[in] memUsage: the memory selection for eig vector output */ template -void eig_dc_selective(const raft::handle_t& handle, +void eig_dc_selective(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view eig_vectors, raft::device_vector_view eig_vals, @@ -185,7 +185,7 @@ void eig_dc_selective(const raft::handle_t& handle, * column-major symmetric matrices (in parameter) * @tparam ValueType the data-type of input and output * @tparam IntegerType Integer used for addressing - * @param handle raft::handle_t + * @param handle raft::device_resources * @param[in] in input raft::device_matrix_view (symmetric matrix that has real eig values and * vectors) * @param[out] eig_vectors: eigenvectors output of type raft::device_matrix_view @@ -196,7 +196,7 @@ void eig_dc_selective(const raft::handle_t& handle, * accuracy. */ template -void eig_jacobi(const raft::handle_t& handle, +void eig_jacobi(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view eig_vectors, raft::device_vector_view eig_vals, diff --git a/cpp/include/raft/linalg/gemm.cuh b/cpp/include/raft/linalg/gemm.cuh index f2354da6c6..d5dc5ffab5 100644 --- a/cpp/include/raft/linalg/gemm.cuh +++ b/cpp/include/raft/linalg/gemm.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
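As a usage illustration for the mdspan eig_dc overload above, a hedged sketch follows; it assumes the mdarray factories and raft::make_const_mdspan are available, and uses the column-major symmetric input that the doc comment requires.
@code{.cpp}
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/linalg/eig.cuh>

raft::device_resources handle;
int n = 64;

auto A        = raft::make_device_matrix<float, int, raft::col_major>(handle, n, n);
auto eig_vecs = raft::make_device_matrix<float, int, raft::col_major>(handle, n, n);
auto eig_vals = raft::make_device_vector<float, int>(handle, n);

// ... fill A with a symmetric matrix ...

raft::linalg::eig_dc(
  handle, raft::make_const_mdspan(A.view()), eig_vecs.view(), eig_vals.view());
@endcode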
@@ -52,7 +52,7 @@ namespace linalg { * @param [in] stream */ template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, const bool trans_a, const bool trans_b, const int m, @@ -91,7 +91,7 @@ void gemm(const raft::handle_t& handle, * @param stream cuda stream */ template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -126,7 +126,7 @@ void gemm(const raft::handle_t& handle, * @param stream cuda stream */ template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, const math_t* a, int n_rows_a, int n_cols_a, @@ -161,7 +161,7 @@ void gemm(const raft::handle_t& handle, * @param beta scalar */ template -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, T* z, T* x, T* y, @@ -213,7 +213,7 @@ template >, std::is_same>>>> -void gemm(const raft::handle_t& handle, +void gemm(raft::device_resources const& handle, raft::device_matrix_view x, raft::device_matrix_view y, raft::device_matrix_view z, diff --git a/cpp/include/raft/linalg/gemv.cuh b/cpp/include/raft/linalg/gemv.cuh index 8132a742f8..96846003f6 100644 --- a/cpp/include/raft/linalg/gemv.cuh +++ b/cpp/include/raft/linalg/gemv.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ namespace linalg { * @param [in] stream */ template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const bool trans_a, const int m, const int n, @@ -69,7 +69,7 @@ void gemv(const raft::handle_t& handle, } template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows, const int n_cols, @@ -103,7 +103,7 @@ void gemv(const raft::handle_t& handle, * @param stream stream on which this function is run */ template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -133,7 +133,7 @@ void gemv(const raft::handle_t& handle, * @param stream stream on which this function is run */ template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -165,7 +165,7 @@ void gemv(const raft::handle_t& handle, * @param stream stream on which this function is run */ template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -199,7 +199,7 @@ void gemv(const raft::handle_t& handle, * */ template -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, const math_t* A, const int n_rows_a, const int n_cols_a, @@ -246,7 +246,7 @@ template >, std::is_same>>>> -void gemv(const raft::handle_t& handle, +void gemv(raft::device_resources const& handle, raft::device_matrix_view A, raft::device_vector_view x, raft::device_vector_view y, diff --git a/cpp/include/raft/linalg/lstsq.cuh b/cpp/include/raft/linalg/lstsq.cuh index 7654812886..b36a9eba96 100644 --- a/cpp/include/raft/linalg/lstsq.cuh +++ b/cpp/include/raft/linalg/lstsq.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,7 +18,7 @@ #pragma once -#include +#include #include namespace raft { namespace linalg { @@ -37,7 +37,7 @@ namespace linalg { * @param[in] stream cuda stream for ordering operations */ template -void lstsqSvdQR(const raft::handle_t& handle, +void lstsqSvdQR(raft::device_resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -62,7 +62,7 @@ void lstsqSvdQR(const raft::handle_t& handle, * @param[in] stream cuda stream for ordering operations */ template -void lstsqSvdJacobi(const raft::handle_t& handle, +void lstsqSvdJacobi(raft::device_resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -78,7 +78,7 @@ void lstsqSvdJacobi(const raft::handle_t& handle, * (`w = (A^T A)^-1 A^T b`) */ template -void lstsqEig(const raft::handle_t& handle, +void lstsqEig(raft::device_resources const& handle, const math_t* A, const int n_rows, const int n_cols, @@ -104,7 +104,7 @@ void lstsqEig(const raft::handle_t& handle, * @param[in] stream cuda stream for ordering operations */ template -void lstsqQR(const raft::handle_t& handle, +void lstsqQR(raft::device_resources const& handle, math_t* A, const int n_rows, const int n_cols, @@ -125,7 +125,7 @@ void lstsqQR(const raft::handle_t& handle, * Via SVD decomposition of `A = U S Vt`. * * @tparam ValueType the data-type of input/output - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[inout] A input raft::device_matrix_view * Warning: the content of this matrix is modified. * @param[inout] b input target raft::device_vector_view @@ -133,7 +133,7 @@ void lstsqQR(const raft::handle_t& handle, * @param[out] w output coefficient raft::device_vector_view */ template -void lstsq_svd_qr(const raft::handle_t& handle, +void lstsq_svd_qr(raft::device_resources const& handle, raft::device_matrix_view A, raft::device_vector_view b, raft::device_vector_view w) @@ -155,7 +155,7 @@ void lstsq_svd_qr(const raft::handle_t& handle, * Via SVD decomposition of `A = U S V^T` using Jacobi iterations. * * @tparam ValueType the data-type of input/output - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[inout] A input raft::device_matrix_view * Warning: the content of this matrix is modified. * @param[inout] b input target raft::device_vector_view @@ -163,7 +163,7 @@ void lstsq_svd_qr(const raft::handle_t& handle, * @param[out] w output coefficient raft::device_vector_view */ template -void lstsq_svd_jacobi(const raft::handle_t& handle, +void lstsq_svd_jacobi(raft::device_resources const& handle, raft::device_matrix_view A, raft::device_vector_view b, raft::device_vector_view w) @@ -186,7 +186,7 @@ void lstsq_svd_jacobi(const raft::handle_t& handle, * (`w = (A^T A)^-1 A^T b`) * * @tparam ValueType the data-type of input/output - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[inout] A input raft::device_matrix_view * Warning: the content of this matrix is modified by the cuSOLVER routines. 
* @param[inout] b input target raft::device_vector_view @@ -194,7 +194,7 @@ void lstsq_svd_jacobi(const raft::handle_t& handle, * @param[out] w output coefficient raft::device_vector_view */ template -void lstsq_eig(const raft::handle_t& handle, +void lstsq_eig(raft::device_resources const& handle, raft::device_matrix_view A, raft::device_vector_view b, raft::device_vector_view w) @@ -217,7 +217,7 @@ void lstsq_eig(const raft::handle_t& handle, * (triangular system of equations `Rw = Q^T b`) * * @tparam ValueType the data-type of input/output - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[inout] A input raft::device_matrix_view * Warning: the content of this matrix is modified. * @param[inout] b input target raft::device_vector_view @@ -225,7 +225,7 @@ void lstsq_eig(const raft::handle_t& handle, * @param[out] w output coefficient raft::device_vector_view */ template -void lstsq_qr(const raft::handle_t& handle, +void lstsq_qr(raft::device_resources const& handle, raft::device_matrix_view A, raft::device_vector_view b, raft::device_vector_view w) diff --git a/cpp/include/raft/linalg/map.cuh b/cpp/include/raft/linalg/map.cuh index ad35cc5880..2b9e6c80a0 100644 --- a/cpp/include/raft/linalg/map.cuh +++ b/cpp/include/raft/linalg/map.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,9 @@ #include "detail/map.cuh" #include +#include #include +#include namespace raft { namespace linalg { @@ -65,7 +67,7 @@ void map_k( * @tparam TPB threads-per-block in the final kernel launched * @tparam OutType data-type of result of type raft::device_mdspan * @tparam Args additional parameters - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in the input of type raft::device_mdspan * @param[out] out the output of the map operation of type raft::device_mdspan * @param[in] map the device-lambda @@ -78,7 +80,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void map(const raft::handle_t& handle, InType in, OutType out, MapOp map, Args... args) +void map(raft::device_resources const& handle, InType in, OutType out, MapOp map, Args... args) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -96,9 +98,43 @@ void map(const raft::handle_t& handle, InType in, OutType out, MapOp map, Args.. } } +/** + * @brief Perform an element-wise unary operation on the input offset into the output array + * + * Usage example: + * @code{.cpp} + * #include + * #include + * #include + * #include + * ... 
+ * raft::device_resources handle; + * auto squares = raft::make_device_vector(handle, n); + * raft::linalg::map_offset(handle, squares.view(), raft::sq_op()); + * @endcode + * + * @tparam OutType Output mdspan type + * @tparam MapOp The unary operation type with signature `OutT func(const IdxT& idx);` + * @param[in] handle The raft handle + * @param[out] out Output array + * @param[in] op The unary operation + */ +template > +void map_offset(const raft::device_resources& handle, OutType out, MapOp op) +{ + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); + + using out_value_t = typename OutType::value_type; + + thrust::tabulate( + handle.get_thrust_policy(), out.data_handle(), out.data_handle() + out.size(), op); +} + /** @} */ // end of map } // namespace linalg }; // namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/linalg/map_reduce.cuh b/cpp/include/raft/linalg/map_reduce.cuh index 4158d35bca..b89f3bdd54 100644 --- a/cpp/include/raft/linalg/map_reduce.cuh +++ b/cpp/include/raft/linalg/map_reduce.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -75,7 +75,7 @@ void mapReduce(OutType* out, * @tparam OutValueType the data-type of the output * @tparam ScalarIdxType index type of scalar * @tparam Args additional parameters - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in the input of type raft::device_vector_view * @param[in] neutral The neutral element of the reduction operation. For example: * 0 for sum, 1 for multiply, +Inf for Min, -Inf for Max @@ -91,7 +91,7 @@ template -void map_reduce(const raft::handle_t& handle, +void map_reduce(raft::device_resources const& handle, raft::device_vector_view in, raft::device_scalar_view out, OutValueType neutral, diff --git a/cpp/include/raft/linalg/matrix_vector.cuh b/cpp/include/raft/linalg/matrix_vector.cuh index 5529ded16f..fa24ea28b7 100644 --- a/cpp/include/raft/linalg/matrix_vector.cuh +++ b/cpp/include/raft/linalg/matrix_vector.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
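Since map_offset above is a thin wrapper over thrust::tabulate, the two calls in this sketch are equivalent; it mirrors the squares example from the new doc comment, with the include paths assumed.
@code{.cpp}
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/operators.hpp>
#include <raft/linalg/map.cuh>
#include <thrust/tabulate.h>

raft::device_resources handle;
int n = 10;
auto squares = raft::make_device_vector<int, int>(handle, n);

// squares[i] = i * i, computed purely from each element's offset
raft::linalg::map_offset(handle, squares.view(), raft::sq_op());

// ...which expands to this direct thrust call:
thrust::tabulate(handle.get_thrust_policy(),
                 squares.view().data_handle(),
                 squares.view().data_handle() + n,
                 raft::sq_op());
@endcode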
@@ -37,7 +37,7 @@ namespace raft::linalg { * the rows of the matrix or columns using enum class raft::linalg::Apply */ template -void binary_mult_skip_zero(const raft::handle_t& handle, +void binary_mult_skip_zero(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply) @@ -70,7 +70,7 @@ void binary_mult_skip_zero(const raft::handle_t& handle, * the rows of the matrix or columns using enum class raft::linalg::Apply */ template -void binary_div(const raft::handle_t& handle, +void binary_div(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply) @@ -105,7 +105,7 @@ void binary_div(const raft::handle_t& handle, * value if false */ template -void binary_div_skip_zero(const raft::handle_t& handle, +void binary_div_skip_zero(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply, @@ -140,7 +140,7 @@ void binary_div_skip_zero(const raft::handle_t& handle, * the rows of the matrix or columns using enum class raft::linalg::Apply */ template -void binary_add(const raft::handle_t& handle, +void binary_add(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply) @@ -173,7 +173,7 @@ void binary_add(const raft::handle_t& handle, * the rows of the matrix or columns using enum class raft::linalg::Apply */ template -void binary_sub(const raft::handle_t& handle, +void binary_sub(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view vec, Apply apply) diff --git a/cpp/include/raft/linalg/matrix_vector_op.cuh b/cpp/include/raft/linalg/matrix_vector_op.cuh index 8b5163a714..59b2ca5ee5 100644 --- a/cpp/include/raft/linalg/matrix_vector_op.cuh +++ b/cpp/include/raft/linalg/matrix_vector_op.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -122,7 +122,7 @@ void matrixVectorOp(MatT* out, * @tparam LayoutPolicy the layout of input and output (raft::row_major or raft::col_major) * @tparam Lambda a device function which represents a binary operator * @tparam IndexType Integer used for addressing - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] matrix input raft::matrix_view * @param[in] vec vector raft::vector_view * @param[out] out output raft::matrix_view @@ -135,7 +135,7 @@ template -void matrix_vector_op(const raft::handle_t& handle, +void matrix_vector_op(raft::device_resources const& handle, raft::device_matrix_view matrix, raft::device_vector_view vec, raft::device_matrix_view out, @@ -182,7 +182,7 @@ void matrix_vector_op(const raft::handle_t& handle, * @tparam LayoutPolicy the layout of input and output (raft::row_major or raft::col_major) * @tparam Lambda a device function which represents a binary operator * @tparam IndexType Integer used for addressing - * @param handle raft::handle_t + * @param handle raft::device_resources * @param matrix input raft::matrix_view * @param vec1 the first vector raft::vector_view * @param vec2 the second vector raft::vector_view @@ -197,7 +197,7 @@ template -void matrix_vector_op(const raft::handle_t& handle, +void matrix_vector_op(raft::device_resources const& handle, raft::device_matrix_view matrix, raft::device_vector_view vec1, raft::device_vector_view vec2, diff --git a/cpp/include/raft/linalg/mean_squared_error.cuh b/cpp/include/raft/linalg/mean_squared_error.cuh index a3360ae35a..62f4896d01 100644 --- a/cpp/include/raft/linalg/mean_squared_error.cuh +++ b/cpp/include/raft/linalg/mean_squared_error.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -53,14 +53,14 @@ void meanSquaredError( * @tparam IndexType Input/Output index type * @tparam OutValueType Output data-type * @tparam TPB threads-per-block - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] A input raft::device_vector_view * @param[in] B input raft::device_vector_view * @param[out] out the output mean squared error value of type raft::device_scalar_view * @param[in] weight weight to apply to every term in the mean squared error calculation */ template -void mean_squared_error(const raft::handle_t& handle, +void mean_squared_error(raft::device_resources const& handle, raft::device_vector_view A, raft::device_vector_view B, raft::device_scalar_view out, diff --git a/cpp/include/raft/linalg/multiply.cuh b/cpp/include/raft/linalg/multiply.cuh index 119cf667d1..574b88c63d 100644 --- a/cpp/include/raft/linalg/multiply.cuh +++ b/cpp/include/raft/linalg/multiply.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -56,7 +56,7 @@ void multiplyScalar(out_t* out, const in_t* in, in_t scalar, IdxType len, cudaSt * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in the input buffer * @param[out] out the output buffer * @param[in] scalar the scalar used in the operations @@ -68,7 +68,7 @@ template , typename = raft::enable_if_output_device_mdspan> void multiply_scalar( - const raft::handle_t& handle, + raft::device_resources const& handle, InType in, OutType out, raft::host_scalar_view scalar) diff --git a/cpp/include/raft/linalg/norm.cuh b/cpp/include/raft/linalg/norm.cuh index b64b128fa2..8bc6720b4e 100644 --- a/cpp/include/raft/linalg/norm.cuh +++ b/cpp/include/raft/linalg/norm.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -99,7 +99,7 @@ void colNorm(Type* dots, * @tparam LayoutPolicy the layout of input (raft::row_major or raft::col_major) * @tparam IdxType Integer type used to for addressing * @tparam Lambda device final lambda - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in the input raft::device_matrix_view * @param[out] out the output raft::device_vector_view * @param[in] type the type of norm to be applied @@ -111,7 +111,7 @@ template -void norm(const raft::handle_t& handle, +void norm(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_vector_view out, NormType type, diff --git a/cpp/include/raft/linalg/normalize.cuh b/cpp/include/raft/linalg/normalize.cuh index bf6ef5a570..027ebb16e8 100644 --- a/cpp/include/raft/linalg/normalize.cuh +++ b/cpp/include/raft/linalg/normalize.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
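A sketch of the updated mdspan `norm` entry point (the `NormType` and `Apply` arguments and the sizes are illustrative assumptions):

```cpp
#include <raft/core/device_resources.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/linalg/norm.cuh>

void row_l2_norm_example(raft::device_resources const& handle)
{
  auto in  = raft::make_device_matrix<float, int, raft::row_major>(handle, 100, 16);
  auto out = raft::make_device_vector<float, int>(handle, 100);
  // One L2 norm per row of `in`.
  raft::linalg::norm(handle, in.view(), out.view(),
                     raft::linalg::L2Norm, raft::linalg::Apply::ALONG_ROWS);
}
```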
@@ -37,7 +37,7 @@ namespace linalg { * @tparam MainLambda Type of main_op * @tparam ReduceLambda Type of reduce_op * @tparam FinalLambda Type of fin_op - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in the input raft::device_matrix_view * @param[out] out the output raft::device_matrix_view * @param[in] init Initialization value, i.e identity element for the reduction operation @@ -52,7 +52,7 @@ template -void row_normalize(const raft::handle_t& handle, +void row_normalize(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, ElementType init, @@ -85,14 +85,14 @@ void row_normalize(const raft::handle_t& handle, * * @tparam ElementType Input/Output data type * @tparam IndexType Integer type used to for addressing - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in the input raft::device_matrix_view * @param[out] out the output raft::device_matrix_view * @param[in] norm_type the type of norm to be applied * @param[in] eps If the norm is below eps, the row is considered zero and no division is applied */ template -void row_normalize(const raft::handle_t& handle, +void row_normalize(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, NormType norm_type, diff --git a/cpp/include/raft/linalg/power.cuh b/cpp/include/raft/linalg/power.cuh index 59c2cdf314..1fdfcb3780 100644 --- a/cpp/include/raft/linalg/power.cuh +++ b/cpp/include/raft/linalg/power.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -73,7 +73,7 @@ void power(out_t* out, const in_t* in1, const in_t* in2, IdxType len, cudaStream * @brief Elementwise power operation on the input buffers * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in1 First Input * @param[in] in2 Second Input * @param[out] out Output @@ -82,7 +82,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void power(const raft::handle_t& handle, InType in1, InType in2, OutType out) +void power(raft::device_resources const& handle, InType in1, InType in2, OutType out) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -113,7 +113,7 @@ void power(const raft::handle_t& handle, InType in1, InType in2, OutType out) * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[out] out Output * @param[in] scalar raft::host_scalar_view @@ -124,7 +124,7 @@ template , typename = raft::enable_if_output_device_mdspan> void power_scalar( - const raft::handle_t& handle, + raft::device_resources const& handle, InType in, OutType out, const raft::host_scalar_view scalar) diff --git a/cpp/include/raft/linalg/qr.cuh b/cpp/include/raft/linalg/qr.cuh index 7c5c0ea628..8e58af63c1 100644 --- a/cpp/include/raft/linalg/qr.cuh +++ b/cpp/include/raft/linalg/qr.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. 
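For the `row_normalize` overloads above, a minimal usage sketch (the epsilon value and shapes are assumptions, not taken from the patch):

```cpp
#include <raft/core/device_resources.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/linalg/normalize.cuh>

void l2_normalize_rows_example(raft::device_resources const& handle)
{
  auto in  = raft::make_device_matrix<float, int, raft::row_major>(handle, 256, 32);
  auto out = raft::make_device_matrix<float, int, raft::row_major>(handle, 256, 32);
  // Rows whose norm falls below eps are treated as zero: no division is applied.
  raft::linalg::row_normalize(handle, in.view(), out.view(),
                              raft::linalg::L2Norm, 1e-8f);
}
```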
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ namespace linalg { * @param stream cuda stream */ template -void qrGetQ(const raft::handle_t& handle, +void qrGetQ(raft::device_resources const& handle, const math_t* M, math_t* Q, int n_rows, @@ -54,7 +54,7 @@ void qrGetQ(const raft::handle_t& handle, * @param stream cuda stream */ template -void qrGetQR(const raft::handle_t& handle, +void qrGetQR(raft::device_resources const& handle, math_t* M, math_t* Q, math_t* R, @@ -72,12 +72,12 @@ void qrGetQR(const raft::handle_t& handle, /** * @brief Compute the QR decomposition of matrix M and return only the Q matrix. - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M Input raft::device_matrix_view * @param[out] Q Output raft::device_matrix_view */ template -void qr_get_q(const raft::handle_t& handle, +void qr_get_q(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_matrix_view Q) { @@ -88,13 +88,13 @@ void qr_get_q(const raft::handle_t& handle, /** * @brief Compute the QR decomposition of matrix M and return both the Q and R matrices. - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M Input raft::device_matrix_view * @param[out] Q Output raft::device_matrix_view * @param[out] R Output raft::device_matrix_view */ template -void qr_get_qr(const raft::handle_t& handle, +void qr_get_qr(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_matrix_view Q, raft::device_matrix_view R) diff --git a/cpp/include/raft/linalg/reduce.cuh b/cpp/include/raft/linalg/reduce.cuh index 3eb8196408..ae5457c44f 100644 --- a/cpp/include/raft/linalg/reduce.cuh +++ b/cpp/include/raft/linalg/reduce.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -105,7 +105,7 @@ void reduce(OutType* dots, * @tparam FinalLambda the final lambda applied before STG (eg: Sqrt for L2 norm) * It must be a 'callable' supporting the following input and output: *
OutType (*FinalLambda)(OutType);
- * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] data Input of type raft::device_matrix_view * @param[out] dots Output of type raft::device_matrix_view * @param[in] init initial value to use for the reduction @@ -122,7 +122,7 @@ template -void reduce(const raft::handle_t& handle, +void reduce(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutElementType init, diff --git a/cpp/include/raft/linalg/reduce_cols_by_key.cuh b/cpp/include/raft/linalg/reduce_cols_by_key.cuh index 7b0ad2f984..2b744d8134 100644 --- a/cpp/include/raft/linalg/reduce_cols_by_key.cuh +++ b/cpp/include/raft/linalg/reduce_cols_by_key.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ #include "detail/reduce_cols_by_key.cuh" #include -#include +#include namespace raft { namespace linalg { @@ -69,7 +69,7 @@ void reduce_cols_by_key(const T* data, * @tparam ElementType the input data type (as well as the output reduced matrix) * @tparam KeyType data type of the keys * @tparam IndexType indexing arithmetic type - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] data the input data (dim = nrows x ncols). This is assumed to be in * row-major layout of type raft::device_matrix_view * @param[in] keys keys raft::device_vector_view (len = ncols). It is assumed that each key in this @@ -84,7 +84,7 @@ void reduce_cols_by_key(const T* data, */ template void reduce_cols_by_key( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view keys, raft::device_matrix_view out, diff --git a/cpp/include/raft/linalg/reduce_rows_by_key.cuh b/cpp/include/raft/linalg/reduce_rows_by_key.cuh index 1dabd92087..484b60238b 100644 --- a/cpp/include/raft/linalg/reduce_rows_by_key.cuh +++ b/cpp/include/raft/linalg/reduce_rows_by_key.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ #include "detail/reduce_rows_by_key.cuh" #include -#include +#include namespace raft { namespace linalg { @@ -136,7 +136,7 @@ void reduce_rows_by_key(const DataIteratorT d_A, * @tparam KeyType data-type of keys * @tparam WeightType data-type of weights * @tparam IndexType index type - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] d_A Input raft::device_mdspan (ncols * nrows) * @param[in] d_keys Keys for each row raft::device_vector_view (1 x nrows) * @param[out] d_sums Row sums by key raft::device_matrix_view (ncols x d_keys) @@ -148,7 +148,7 @@ void reduce_rows_by_key(const DataIteratorT d_A, */ template void reduce_rows_by_key( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view d_A, raft::device_vector_view d_keys, raft::device_matrix_view d_sums, diff --git a/cpp/include/raft/linalg/rsvd.cuh b/cpp/include/raft/linalg/rsvd.cuh index 6f0315642b..eb94547f13 100644 --- a/cpp/include/raft/linalg/rsvd.cuh +++ b/cpp/include/raft/linalg/rsvd.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. 
+ * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -47,7 +47,7 @@ namespace linalg { * @param stream cuda stream */ template -void rsvdFixedRank(const raft::handle_t& handle, +void rsvdFixedRank(raft::device_resources const& handle, math_t* M, int n_rows, int n_cols, @@ -104,7 +104,7 @@ void rsvdFixedRank(const raft::handle_t& handle, * @param stream cuda stream */ template -void rsvdPerc(const raft::handle_t& handle, +void rsvdPerc(raft::device_resources const& handle, math_t* M, int n_rows, int n_cols, @@ -154,7 +154,7 @@ void rsvdPerc(const raft::handle_t& handle, * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples @@ -164,7 +164,7 @@ void rsvdPerc(const raft::handle_t& handle, * raft::col_major */ template -void rsvd_fixed_rank(const raft::handle_t& handle, +void rsvd_fixed_rank(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, IndexType p, @@ -228,7 +228,7 @@ void rsvd_fixed_rank(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples @@ -239,7 +239,7 @@ void rsvd_fixed_rank(Args... args) */ template void rsvd_fixed_rank_symmetric( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, IndexType p, @@ -303,7 +303,7 @@ void rsvd_fixed_rank_symmetric(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples @@ -315,7 +315,7 @@ void rsvd_fixed_rank_symmetric(Args... args) * raft::col_major */ template -void rsvd_fixed_rank_jacobi(const raft::handle_t& handle, +void rsvd_fixed_rank_jacobi(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, IndexType p, @@ -381,7 +381,7 @@ void rsvd_fixed_rank_jacobi(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] p no. of upsamples @@ -394,7 +394,7 @@ void rsvd_fixed_rank_jacobi(Args... args) */ template void rsvd_fixed_rank_symmetric_jacobi( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, IndexType p, @@ -460,7 +460,7 @@ void rsvd_fixed_rank_symmetric_jacobi(Args... 
args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed @@ -471,7 +471,7 @@ void rsvd_fixed_rank_symmetric_jacobi(Args... args) * raft::col_major */ template -void rsvd_perc(const raft::handle_t& handle, +void rsvd_perc(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, ValueType PC_perc, @@ -536,7 +536,7 @@ void rsvd_perc(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed @@ -547,7 +547,7 @@ void rsvd_perc(Args... args) * raft::col_major */ template -void rsvd_perc_symmetric(const raft::handle_t& handle, +void rsvd_perc_symmetric(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, ValueType PC_perc, @@ -612,7 +612,7 @@ void rsvd_perc_symmetric(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed @@ -625,7 +625,7 @@ void rsvd_perc_symmetric(Args... args) * raft::col_major */ template -void rsvd_perc_jacobi(const raft::handle_t& handle, +void rsvd_perc_jacobi(raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, ValueType PC_perc, @@ -692,7 +692,7 @@ void rsvd_perc_jacobi(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] M input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S_vec singular values raft::device_vector_view of shape (K) * @param[in] PC_perc percentage of singular values to be computed @@ -706,7 +706,7 @@ void rsvd_perc_jacobi(Args... args) */ template void rsvd_perc_symmetric_jacobi( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view M, raft::device_vector_view S_vec, ValueType PC_perc, diff --git a/cpp/include/raft/linalg/sqrt.cuh b/cpp/include/raft/linalg/sqrt.cuh index ad6cad2eb2..55e661897d 100644 --- a/cpp/include/raft/linalg/sqrt.cuh +++ b/cpp/include/raft/linalg/sqrt.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
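Stepping back to the `reduce` overload updated a few hunks above, here is a minimal sum-of-squares-per-row sketch, assuming the `raft::sq_op`/`raft::add_op` functors from `raft/core/operators.hpp` (shapes and the argument order are illustrative of the signature shown in that hunk):

```cpp
#include <raft/core/device_resources.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/operators.hpp>
#include <raft/linalg/reduce.cuh>

void sum_of_squares_example(raft::device_resources const& handle)
{
  auto data = raft::make_device_matrix<float, int, raft::row_major>(handle, 64, 128);
  auto dots = raft::make_device_vector<float, int>(handle, 64);
  // Reduce each row of `data` into one entry of `dots`: dots[i] = sum_j data(i,j)^2
  raft::linalg::reduce(handle, data.view(), dots.view(),
                       0.0f,                             // init value
                       raft::linalg::Apply::ALONG_ROWS,  // reduce along rows
                       false,                            // not in-place
                       raft::sq_op{},                    // applied to each element
                       raft::add_op{});                  // combines elements
}
```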
@@ -51,7 +51,7 @@ void sqrt(out_t* out, const in_t* in, IdxType len, cudaStream_t stream) * @brief Elementwise sqrt operation * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[out] out Output */ @@ -59,7 +59,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void sqrt(const raft::handle_t& handle, InType in, OutType out) +void sqrt(raft::device_resources const& handle, InType in, OutType out) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; diff --git a/cpp/include/raft/linalg/strided_reduction.cuh b/cpp/include/raft/linalg/strided_reduction.cuh index d9c26910e7..d282a2e1fa 100644 --- a/cpp/include/raft/linalg/strided_reduction.cuh +++ b/cpp/include/raft/linalg/strided_reduction.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -112,7 +112,7 @@ void stridedReduction(OutType* dots, * @tparam FinalLambda the final lambda applied before STG (eg: Sqrt for L2 norm) * It must be a 'callable' supporting the following input and output: *
OutType (*FinalLambda)(OutType);
- * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] data Input of type raft::device_matrix_view * @param[out] dots Output of type raft::device_matrix_view * @param[in] init initial value to use for the reduction @@ -128,7 +128,7 @@ template -void strided_reduction(const raft::handle_t& handle, +void strided_reduction(raft::device_resources const& handle, raft::device_matrix_view data, raft::device_vector_view dots, OutValueType init, diff --git a/cpp/include/raft/linalg/subtract.cuh b/cpp/include/raft/linalg/subtract.cuh index e6f2fa8724..da995b7a2a 100644 --- a/cpp/include/raft/linalg/subtract.cuh +++ b/cpp/include/raft/linalg/subtract.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -97,7 +97,7 @@ void subtractDevScalar(math_t* outDev, * @brief Elementwise subtraction operation on the input buffers * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan - * @param handle raft::handle_t + * @param handle raft::device_resources * @param[in] in1 First Input * @param[in] in2 Second Input * @param[out] out Output @@ -106,7 +106,7 @@ template , typename = raft::enable_if_output_device_mdspan> -void subtract(const raft::handle_t& handle, InType in1, InType in2, OutType out) +void subtract(raft::device_resources const& handle, InType in1, InType in2, OutType out) { using in_value_t = typename InType::value_type; using out_value_t = typename OutType::value_type; @@ -137,7 +137,7 @@ void subtract(const raft::handle_t& handle, InType in1, InType in2, OutType out) * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[out] out Output * @param[in] scalar raft::device_scalar_view @@ -148,7 +148,7 @@ template , typename = raft::enable_if_output_device_mdspan> void subtract_scalar( - const raft::handle_t& handle, + raft::device_resources const& handle, InType in, OutType out, raft::device_scalar_view scalar) @@ -182,7 +182,7 @@ void subtract_scalar( * @tparam InType Input Type raft::device_mdspan * @tparam OutType Output Type raft::device_mdspan * @tparam ScalarIdxType Index Type of scalar - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in Input * @param[out] out Output * @param[in] scalar raft::host_scalar_view @@ -193,7 +193,7 @@ template , typename = raft::enable_if_output_device_mdspan> void subtract_scalar( - const raft::handle_t& handle, + raft::device_resources const& handle, InType in, OutType out, raft::host_scalar_view scalar) diff --git a/cpp/include/raft/linalg/svd.cuh b/cpp/include/raft/linalg/svd.cuh index 2c1b5a5cb7..eb51093240 100644 --- a/cpp/include/raft/linalg/svd.cuh +++ b/cpp/include/raft/linalg/svd.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
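A sketch of the device-scalar `subtract_scalar` overload above (names and sizes are illustrative assumptions):

```cpp
#include <raft/core/device_resources.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/linalg/subtract.cuh>

void shift_example(raft::device_resources const& handle)
{
  auto in     = raft::make_device_vector<float, int>(handle, 1024);
  auto out    = raft::make_device_vector<float, int>(handle, 1024);
  auto scalar = raft::make_device_scalar<float>(handle, 0.5f);
  // out[i] = in[i] - *scalar, with the scalar kept in device memory.
  raft::linalg::subtract_scalar(handle, in.view(), out.view(), scalar.view());
}
```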
@@ -41,7 +41,7 @@ namespace linalg { * @param stream cuda stream */ template -void svdQR(const raft::handle_t& handle, +void svdQR(raft::device_resources const& handle, T* in, int n_rows, int n_cols, @@ -67,7 +67,7 @@ void svdQR(const raft::handle_t& handle, } template -void svdEig(const raft::handle_t& handle, +void svdEig(raft::device_resources const& handle, math_t* in, idx_t n_rows, idx_t n_cols, @@ -98,7 +98,7 @@ void svdEig(const raft::handle_t& handle, * @param stream cuda stream */ template -void svdJacobi(const raft::handle_t& handle, +void svdJacobi(raft::device_resources const& handle, math_t* in, int n_rows, int n_cols, @@ -139,7 +139,7 @@ void svdJacobi(const raft::handle_t& handle, * @param stream cuda stream */ template -void svdReconstruction(const raft::handle_t& handle, +void svdReconstruction(raft::device_resources const& handle, math_t* U, math_t* S, math_t* V, @@ -167,7 +167,7 @@ void svdReconstruction(const raft::handle_t& handle, * @param stream cuda stream */ template -bool evaluateSVDByL2Norm(const raft::handle_t& handle, +bool evaluateSVDByL2Norm(raft::device_resources const& handle, math_t* A_d, math_t* U, math_t* S_vec, @@ -195,7 +195,7 @@ bool evaluateSVDByL2Norm(const raft::handle_t& handle, * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] sing_vals singular values raft::device_vector_view of shape (K) * @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout @@ -204,7 +204,7 @@ bool evaluateSVDByL2Norm(const raft::handle_t& handle, * layout raft::col_major and dimensions (n, n) */ template -void svd_qr(const raft::handle_t& handle, +void svd_qr(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_vector_view sing_vals, UType&& U_in, @@ -258,7 +258,7 @@ void svd_qr(Args... args) * U_in * @tparam VType std::optional> @c * V_in - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] sing_vals singular values raft::device_vector_view of shape (K) * @param[out] U_in std::optional left singular values of raft::device_matrix_view with layout @@ -268,7 +268,7 @@ void svd_qr(Args... args) */ template void svd_qr_transpose_right_vec( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view in, raft::device_vector_view sing_vals, UType&& U_in, @@ -316,7 +316,7 @@ void svd_qr_transpose_right_vec(Args... args) /** * @brief singular value decomposition (SVD) on a column major * matrix using Eigen decomposition. A square symmetric covariance matrix is constructed for the SVD - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in input raft::device_matrix_view with layout raft::col_major of shape (M, N) * @param[out] S singular values raft::device_vector_view of shape (K) * @param[out] V right singular values of raft::device_matrix_view with layout @@ -326,7 +326,7 @@ void svd_qr_transpose_right_vec(Args... 
args) */ template void svd_eig( - const raft::handle_t& handle, + raft::device_resources const& handle, raft::device_matrix_view in, raft::device_vector_view S, raft::device_matrix_view V, @@ -352,7 +352,7 @@ void svd_eig( /** * @brief reconstruct a matrix using left and right singular vectors and * singular values - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] U left singular values of raft::device_matrix_view with layout * raft::col_major and dimensions (m, k) * @param[in] S singular values raft::device_vector_view of shape (k, k) @@ -361,7 +361,7 @@ void svd_eig( * @param[out] out output raft::device_matrix_view with layout raft::col_major of shape (m, n) */ template -void svd_reconstruction(const raft::handle_t& handle, +void svd_reconstruction(raft::device_resources const& handle, raft::device_matrix_view U, raft::device_vector_view S, raft::device_matrix_view V, diff --git a/cpp/include/raft/linalg/ternary_op.cuh b/cpp/include/raft/linalg/ternary_op.cuh index 10e91a0313..aa3859bc23 100644 --- a/cpp/include/raft/linalg/ternary_op.cuh +++ b/cpp/include/raft/linalg/ternary_op.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2018-2022, NVIDIA CORPORATION. + * Copyright (c) 2018-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,7 +22,7 @@ #include "detail/ternary_op.cuh" #include -#include +#include #include namespace raft { @@ -63,7 +63,7 @@ void ternaryOp(out_t* out, * @tparam InType Input Type raft::device_mdspan * @tparam Lambda the device-lambda performing the actual operation * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::handle_t + * @param[in] handle raft::device_resources * @param[in] in1 First input * @param[in] in2 Second input * @param[in] in3 Third input @@ -78,7 +78,7 @@ template , typename = raft::enable_if_output_device_mdspan> void ternary_op( - const raft::handle_t& handle, InType in1, InType in2, InType in3, OutType out, Lambda op) + raft::device_resources const& handle, InType in1, InType in2, InType in3, OutType out, Lambda op) { RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); RAFT_EXPECTS(raft::is_row_or_column_major(in1), "Input 1 must be contiguous"); diff --git a/cpp/include/raft/linalg/transpose.cuh b/cpp/include/raft/linalg/transpose.cuh index 608a87b489..a0f418b4f7 100644 --- a/cpp/include/raft/linalg/transpose.cuh +++ b/cpp/include/raft/linalg/transpose.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -34,7 +34,7 @@ namespace linalg { * @param stream: cuda stream */ template -void transpose(const raft::handle_t& handle, +void transpose(raft::device_resources const& handle, math_t* in, math_t* out, int n_rows, @@ -76,7 +76,7 @@ void transpose(math_t* inout, int n, cudaStream_t stream) * @param[out] out Output matrix, storage is pre-allocated by caller.
*/ template -auto transpose(handle_t const& handle, +auto transpose(raft::device_resources const& handle, raft::mdspan, LayoutPolicy, AccessorPolicy> in, raft::mdspan, LayoutPolicy, AccessorPolicy> out) -> std::enable_if_t, void> diff --git a/cpp/include/raft/linalg/unary_op.cuh b/cpp/include/raft/linalg/unary_op.cuh index a90bda06d5..ce102adfd2 100644 --- a/cpp/include/raft/linalg/unary_op.cuh +++ b/cpp/include/raft/linalg/unary_op.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,7 +21,7 @@ #include "detail/unary_op.cuh" #include -#include +#include #include namespace raft { @@ -30,17 +30,16 @@ namespace linalg { /** * @brief perform element-wise unary operation in the input array * @tparam InType input data-type - * @tparam Lambda the device-lambda performing the actual operation + * @tparam Lambda Device lambda performing the actual operation, with the signature + * `OutType func(const InType& val);` * @tparam OutType output data-type * @tparam IdxType Integer type used to for addressing * @tparam TPB threads-per-block in the final kernel launched - * @param out the output array - * @param in the input array - * @param len number of elements in the input array - * @param op the device-lambda - * @param stream cuda stream where to launch work - * @note Lambda must be a functor with the following signature: - * `OutType func(const InType& val);` + * @param[out] out Output array [on device], dim = [len] + * @param[in] in Input array [on device], dim = [len] + * @param[in] len Number of elements in the input array + * @param[in] op Device lambda + * @param[in] stream cuda stream where to launch work */ template @@ -81,23 +80,22 @@ void writeOnlyUnaryOp(OutType* out, IdxType len, Lambda op, cudaStream_t stream) */ /** - * @brief perform element-wise binary operation on the input arrays + * @brief Perform an element-wise unary operation into the output array * @tparam InType Input Type raft::device_mdspan - * @tparam Lambda the device-lambda performing the actual operation + * @tparam Lambda Device lambda performing the actual operation, with the signature + * `out_value_t func(const in_value_t& val);` * @tparam OutType Output Type raft::device_mdspan - * @param[in] handle raft::handle_t - * @param[in] in Input - * @param[out] out Output - * @param[in] op the device-lambda - * @note Lambda must be a functor with the following signature: - * `InType func(const InType& val);` + * @param[in] handle The raft handle + * @param[in] in Input + * @param[out] out Output + * @param[in] op Device lambda */ template , typename = raft::enable_if_output_device_mdspan> -void unary_op(const raft::handle_t& handle, InType in, OutType out, Lambda op) +void unary_op(raft::device_resources const& handle, InType in, OutType out, Lambda op) { RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); @@ -116,29 +114,32 @@ void unary_op(const raft::handle_t& handle, InType in, OutType out, Lambda op) } /** - * @brief perform element-wise binary operation on the input arrays - * This function does not read from the input - * @tparam InType Input Type raft::device_mdspan - * @tparam Lambda the device-lambda performing the actual operation - * @param[in] handle raft::handle_t - * @param[inout] in Input/Output - * 
@param[in] op the device-lambda - * @note Lambda must be a functor with the following signature: - * `InType func(const InType& val);` + * @brief Perform an element-wise unary operation on the input index into the output array + * + * @note This operation is deprecated. Please use map_offset in `raft/linalg/map.cuh` instead. + * + * @tparam OutType Output Type raft::device_mdspan + * @tparam Lambda Device lambda performing the actual operation, with the signature + * `void func(out_value_t* out_location, index_t idx);` + * @param[in] handle The raft handle + * @param[out] out Output + * @param[in] op Device lambda */ -template > -void write_only_unary_op(const raft::handle_t& handle, InType in, Lambda op) +template > +void write_only_unary_op(const raft::device_resources& handle, OutType out, Lambda op) { - RAFT_EXPECTS(raft::is_row_or_column_major(in), "Input must be contiguous"); + RAFT_EXPECTS(raft::is_row_or_column_major(out), "Output must be contiguous"); - using in_value_t = typename InType::value_type; + using out_value_t = typename OutType::value_type; - if (in.size() <= std::numeric_limits::max()) { - writeOnlyUnaryOp( - in.data_handle(), in.size(), op, handle.get_stream()); + if (out.size() <= std::numeric_limits::max()) { + writeOnlyUnaryOp( + out.data_handle(), out.size(), op, handle.get_stream()); } else { - writeOnlyUnaryOp( - in.data_handle(), in.size(), op, handle.get_stream()); + writeOnlyUnaryOp( + out.data_handle(), out.size(), op, handle.get_stream()); } } @@ -147,4 +148,4 @@ void write_only_unary_op(const raft::handle_t& handle, InType in, Lambda op) }; // end namespace linalg }; // end namespace raft -#endif \ No newline at end of file +#endif diff --git a/cpp/include/raft/matrix/argmax.cuh b/cpp/include/raft/matrix/argmax.cuh index a614f7043f..433c161079 100644 --- a/cpp/include/raft/matrix/argmax.cuh +++ b/cpp/include/raft/matrix/argmax.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ namespace raft::matrix { * @param[out] out: output vector of size n_rows */ template -void argmax(const raft::handle_t& handle, +void argmax(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_vector_view out) { diff --git a/cpp/include/raft/matrix/argmin.cuh b/cpp/include/raft/matrix/argmin.cuh index ca7b0252d2..31ef0c1c1b 100644 --- a/cpp/include/raft/matrix/argmin.cuh +++ b/cpp/include/raft/matrix/argmin.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -33,7 +33,7 @@ namespace raft::matrix { * @param[out] out: output vector of size n_rows */ template -void argmin(const raft::handle_t& handle, +void argmin(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_vector_view out) { diff --git a/cpp/include/raft/matrix/col_wise_sort.cuh b/cpp/include/raft/matrix/col_wise_sort.cuh index 662f62d865..a4daf097e5 100644 --- a/cpp/include/raft/matrix/col_wise_sort.cuh +++ b/cpp/include/raft/matrix/col_wise_sort.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2022, NVIDIA CORPORATION. + * Copyright (c) 2019-2023, NVIDIA CORPORATION. 
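A sketch of the row-wise `argmax` shown above after the handle rename (sizes are illustrative; `argmin` is called the same way):

```cpp
#include <raft/core/device_resources.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/matrix/argmax.cuh>

void rowwise_argmax_example(raft::device_resources const& handle)
{
  auto in  = raft::make_device_matrix<float, int, raft::row_major>(handle, 32, 10);
  auto out = raft::make_device_vector<int, int>(handle, 32);
  // out[i] = index of the maximum value in row i of `in`.
  raft::matrix::argmax(handle, in.view(), out.view());
}
```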
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -71,7 +71,7 @@ void sort_cols_per_row(const InType* in, * @param[out] sorted_keys_opt: std::optional, output matrix for sorted keys (input) */ template -void sort_cols_per_row(const raft::handle_t& handle, +void sort_cols_per_row(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, sorted_keys_t&& sorted_keys_opt) diff --git a/cpp/include/raft/matrix/copy.cuh b/cpp/include/raft/matrix/copy.cuh index 0727fac246..42d2562e5e 100644 --- a/cpp/include/raft/matrix/copy.cuh +++ b/cpp/include/raft/matrix/copy.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ namespace raft::matrix { * @param[in] indices of the rows to be copied */ template -void copy_rows(const raft::handle_t& handle, +void copy_rows(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out, raft::device_vector_view indices) @@ -65,7 +65,7 @@ void copy_rows(const raft::handle_t& handle, * @param[out] out: output matrix */ template -void copy(const raft::handle_t& handle, +void copy(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { @@ -84,7 +84,7 @@ void copy(const raft::handle_t& handle, * @param out: output matrix */ template -void trunc_zero_origin(const raft::handle_t& handle, +void trunc_zero_origin(raft::device_resources const& handle, raft::device_matrix_view in, raft::device_matrix_view out) { diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index c006f69e47..f6dc60bf85 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
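For the `copy_rows` overload above, a minimal sketch of gathering a subset of rows through an index vector (row counts and index values are assumptions):

```cpp
#include <raft/core/device_resources.hpp>
#include <raft/core/device_mdarray.hpp>
#include <raft/matrix/copy.cuh>

void select_rows_example(raft::device_resources const& handle)
{
  auto in      = raft::make_device_matrix<float, int, raft::row_major>(handle, 100, 8);
  auto out     = raft::make_device_matrix<float, int, raft::row_major>(handle, 10, 8);
  auto indices = raft::make_device_vector<int, int>(handle, 10);  // fill with row ids
  // out[i] = in[indices[i]] for each of the 10 requested rows.
  raft::matrix::copy_rows(handle, in.view(), out.view(), indices.view());
}
```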
@@ -17,41 +17,63 @@ #pragma once #include +#include namespace raft { namespace matrix { namespace detail { -// gatherKernel conditionally copies rows from the source matrix 'in' into the destination matrix -// 'out' according to a map (or a transformed map) -template +struct gather_policy { + static constexpr int n_threads = tpb; + static constexpr int work_per_thread = wpt; + static constexpr int stride = tpb * wpt; +}; + +/** Conditionally copies rows from the source matrix 'in' into the destination matrix + * 'out' according to a map (or a transformed map) */ +template -__global__ void gatherKernel(const MatrixIteratorT in, - IndexT D, - IndexT N, - MapIteratorT map, - StencilIteratorT stencil, - MatrixIteratorT out, - PredicateOp pred_op, - MapTransformOp transform_op) + typename OutputIteratorT, + typename IndexT> +__global__ void gather_kernel(const InputIteratorT in, + IndexT D, + IndexT len, + const MapIteratorT map, + StencilIteratorT stencil, + OutputIteratorT out, + PredicateOp pred_op, + MapTransformOp transform_op) { typedef typename std::iterator_traits::value_type MapValueT; typedef typename std::iterator_traits::value_type StencilValueT; - IndexT outRowStart = blockIdx.x * D; - MapValueT map_val = map[blockIdx.x]; - StencilValueT stencil_val = stencil[blockIdx.x]; +#pragma unroll + for (IndexT wid = 0; wid < Policy::work_per_thread; wid++) { + IndexT tid = threadIdx.x + (Policy::work_per_thread * static_cast(blockIdx.x) + wid) * + Policy::n_threads; + if (tid < len) { + IndexT i_dst = tid / D; + IndexT j = tid % D; + + MapValueT map_val = map[i_dst]; + StencilValueT stencil_val = stencil[i_dst]; - bool predicate = pred_op(stencil_val); - if (predicate) { - IndexT inRowStart = transform_op(map_val) * D; - for (int i = threadIdx.x; i < D; i += TPB) { - out[outRowStart + i] = in[inRowStart + i]; + bool predicate = pred_op(stencil_val); + if (predicate) { + IndexT i_src = transform_op(map_val); + out[tid] = in[i_src * D + j]; + } } } } @@ -60,7 +82,7 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @brief gather conditionally copies rows from a source matrix into a destination matrix according * to a transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -69,7 +91,10 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. 
* * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -83,18 +108,20 @@ __global__ void gatherKernel(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gatherImpl(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename MapTransformOp, + typename OutputIteratorT, + typename IndexT> +void gatherImpl(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, MapTransformOp transform_op, cudaStream_t stream) @@ -102,9 +129,6 @@ void gatherImpl(const MatrixIteratorT in, // skip in case of 0 length input if (map_length <= 0 || N <= 0 || D <= 0) return; - // signed integer type for indexing or global offsets - typedef int IndexT; - // map value type typedef typename std::iterator_traits::value_type MapValueT; @@ -121,38 +145,26 @@ void gatherImpl(const MatrixIteratorT in, static_assert((std::is_convertible::value), "UnaryPredicateOp's result type must be convertible to bool type"); - if (D <= 32) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); - } else if (D <= 64) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); - } else if (D <= 128) { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + IndexT len = map_length * D; + constexpr int TPB = 128; + const int n_sm = raft::getMultiProcessorCount(); + // The following empirical heuristics enforce that we keep a good balance between having enough + // blocks and enough work per thread. + if (len < static_cast(32 * TPB * n_sm)) { + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); + } else if (len < static_cast(32 * 4 * TPB * n_sm)) { + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); } else { - gatherKernel - <<>>(in, D, N, map, stencil, out, pred_op, transform_op); + using Policy = gather_policy; + IndexT n_blocks = raft::ceildiv(map_length * D, static_cast(Policy::stride)); + gather_kernel<<>>( + in, D, len, map, stencil, out, pred_op, transform_op); } RAFT_CUDA_TRY(cudaPeekAtLastError()); } @@ -160,10 +172,13 @@ void gatherImpl(const MatrixIteratorT in, /** * @brief gather copies rows from a source matrix into a destination matrix according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. 
* * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -174,13 +189,13 @@ void gatherImpl(const MatrixIteratorT in, * @param out Pointer to the output matrix (assumed to be row-major) * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, cudaStream_t stream) { typedef typename std::iterator_traits::value_type MapValueT; @@ -192,12 +207,15 @@ void gather(const MatrixIteratorT in, * @brief gather copies rows from a source matrix into a destination matrix according to a * transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. * * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -209,13 +227,17 @@ void gather(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gather(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, - int map_length, - MatrixIteratorT out, +template +void gather(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, + IndexT map_length, + OutputIteratorT out, MapTransformOp transform_op, cudaStream_t stream) { @@ -227,7 +249,7 @@ void gather(const MatrixIteratorT in, * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -235,6 +257,9 @@ void gather(const MatrixIteratorT in, * simple pointer type). * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. 
* * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -247,17 +272,19 @@ void gather(const MatrixIteratorT in, * @param pred_op Predicate to apply to the stencil values * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename UnaryPredicateOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, cudaStream_t stream) { @@ -269,7 +296,7 @@ void gather_if(const MatrixIteratorT in, * @brief gather_if conditionally copies rows from a source matrix into a destination matrix * according to a transformed map. * - * @tparam MatrixIteratorT Random-access iterator type, for reading input matrix (may be a + * @tparam InputIteratorT Random-access iterator type, for reading input matrix (may be a * simple pointer type). * @tparam MapIteratorT Random-access iterator type, for reading input map (may be a simple * pointer type). @@ -278,7 +305,10 @@ void gather_if(const MatrixIteratorT in, * @tparam UnaryPredicateOp Unary lambda expression or operator type, UnaryPredicateOp's result * type must be convertible to bool type. * @tparam MapTransformOp Unary lambda expression or operator type, MapTransformOp's result - * type must be convertible to IndexT (= int) type. + * type must be convertible to IndexT type. + * @tparam OutputIteratorT Random-access iterator type, for writing output matrix (may be a + * simple pointer type). + * @tparam IndexT Index type. * * @param in Pointer to the input matrix (assumed to be row-major) * @param D Leading dimension of the input matrix 'in', which in-case of row-major @@ -292,18 +322,20 @@ void gather_if(const MatrixIteratorT in, * @param transform_op The transformation operation, transforms the map values to IndexT * @param stream CUDA stream to launch kernels within */ -template -void gather_if(const MatrixIteratorT in, - int D, - int N, - MapIteratorT map, + typename MapTransformOp, + typename OutputIteratorT, + typename IndexT> +void gather_if(const InputIteratorT in, + IndexT D, + IndexT N, + const MapIteratorT map, StencilIteratorT stencil, - int map_length, - MatrixIteratorT out, + IndexT map_length, + OutputIteratorT out, UnaryPredicateOp pred_op, MapTransformOp transform_op, cudaStream_t stream) diff --git a/cpp/include/raft/matrix/detail/linewise_op.cuh b/cpp/include/raft/matrix/detail/linewise_op.cuh index 605726bea6..ef8f0e88c1 100644 --- a/cpp/include/raft/matrix/detail/linewise_op.cuh +++ b/cpp/include/raft/matrix/detail/linewise_op.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
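To make the refactored gather API above concrete, a call sketch of the pointer-based `detail::gather` with the new deduced `IndexT` (buffer contents and sizes are illustrative; the buffers would be filled before use):

```cpp
#include <raft/core/device_resources.hpp>
#include <raft/matrix/detail/gather.cuh>
#include <rmm/device_uvector.hpp>

void gather_rows_example(raft::device_resources const& handle)
{
  auto stream = handle.get_stream();
  int64_t n_rows = 1000, dim = 64, map_length = 128;
  rmm::device_uvector<float> in(n_rows * dim, stream);    // row-major [n_rows, dim]
  rmm::device_uvector<int64_t> map(map_length, stream);   // row indices to copy
  rmm::device_uvector<float> out(map_length * dim, stream);
  // out[i, :] = in[map[i], :] for each i < map_length; IndexT deduced as int64_t.
  raft::matrix::detail::gather(
    in.data(), dim, n_rows, map.data(), map_length, out.data(), stream);
}
```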
@@ -796,7 +796,8 @@ struct MatrixLinewiseOp { "layout for in and out must be either padded row or col major"); // also statically assert padded matrix alignment == 2^i*VecBytes - assert(raft::Pow2::areSameAlignOffsets(in, out)); + RAFT_EXPECTS(raft::Pow2::areSameAlignOffsets(in.data_handle(), out.data_handle()), + "The matrix views in and out do not have correct alignment"); if (alongLines) return matrixLinewiseVecRowsSpan +#include #include #include @@ -87,10 +87,10 @@ void seqRoot(math_t* in, if (a < math_t(0)) { return math_t(0); } else { - return sqrt(a * scalar); + return raft::sqrt(a * scalar); } } else { - return sqrt(a * scalar); + return raft::sqrt(a * scalar); } }, stream); @@ -194,7 +194,7 @@ void setValue(math_t* out, const math_t* in, math_t scalar, int len, cudaStream_ template void ratio( - const raft::handle_t& handle, math_t* src, math_t* dest, IdxType len, cudaStream_t stream) + raft::device_resources const& handle, math_t* src, math_t* dest, IdxType len, cudaStream_t stream) { auto d_src = src; auto d_dest = dest; @@ -278,7 +278,7 @@ void matrixVectorBinaryDivSkipZero(Type* data, rowMajor, bcastAlongRows, [] __device__(Type a, Type b) { - if (raft::myAbs(b) < Type(1e-10)) + if (raft::abs(b) < Type(1e-10)) return Type(0); else return a / b; @@ -294,7 +294,7 @@ void matrixVectorBinaryDivSkipZero(Type* data, rowMajor, bcastAlongRows, [] __device__(Type a, Type b) { - if (raft::myAbs(b) < Type(1e-10)) + if (raft::abs(b) < Type(1e-10)) return a; else return a / b; diff --git a/cpp/include/raft/matrix/detail/matrix.cuh b/cpp/include/raft/matrix/detail/matrix.cuh index 17a40be5d6..ef3a873d90 100644 --- a/cpp/include/raft/matrix/detail/matrix.cuh +++ b/cpp/include/raft/matrix/detail/matrix.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include @@ -299,7 +299,7 @@ void getDiagonalInverseMatrix(m_t* in, idx_t len, cudaStream_t stream) } template -m_t getL2Norm(const raft::handle_t& handle, const m_t* in, idx_t size, cudaStream_t stream) +m_t getL2Norm(raft::device_resources const& handle, const m_t* in, idx_t size, cudaStream_t stream) { cublasHandle_t cublasH = handle.get_cublas_handle(); m_t normval = 0; diff --git a/cpp/include/raft/matrix/detail/print.hpp b/cpp/include/raft/matrix/detail/print.hpp index fc3d14861c..814c6a0b4b 100644 --- a/cpp/include/raft/matrix/detail/print.hpp +++ b/cpp/include/raft/matrix/detail/print.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ #include #include #include -#include +#include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/topk.cuh b/cpp/include/raft/matrix/detail/select_k.cuh similarity index 59% rename from cpp/include/raft/spatial/knn/detail/topk.cuh rename to cpp/include/raft/matrix/detail/select_k.cuh index f4dcb53088..ac1ba3dfa3 100644 --- a/cpp/include/raft/spatial/knn/detail/topk.cuh +++ b/cpp/include/raft/matrix/detail/select_k.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION.
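The `math.cuh` hunks above swap the legacy `raft::myAbs` and unqualified `sqrt` calls for the `raft::abs`/`raft::sqrt` wrappers. A self-contained sketch of using those wrappers from device code (the kernel itself is illustrative, not from the patch):

```cpp
#include <raft/core/math.hpp>

// The raft::abs / raft::sqrt wrappers dispatch to the appropriate
// host or device overload, so the same spelling works in both contexts.
template <typename T>
__global__ void abs_sqrt_kernel(T* out, const T* in, int len)
{
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < len) { out[i] = raft::sqrt(raft::abs(in[i])); }
}
```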
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,34 +16,34 @@ #pragma once -#include "topk/radix_topk.cuh" -#include "topk/warpsort_topk.cuh" +#include "select_radix.cuh" +#include "select_warpsort.cuh" #include #include #include -namespace raft::spatial::knn::detail { +namespace raft::matrix::detail { /** * Select k smallest or largest key/values from each row in the input data. * - * If you think of the input data `in_keys` as a row-major matrix with len columns and - * batch_size rows, then this function selects k smallest/largest values in each row and fills - * in the row-major matrix `out` of size (batch_size, k). + * If you think of the input data `in_val` as a row-major matrix with `len` columns and + * `batch_size` rows, then this function selects `k` smallest/largest values in each row and fills + * in the row-major matrix `out_val` of size (batch_size, k). * * @tparam T * the type of the keys (what is being compared). * @tparam IdxT * the index type (what is being selected together with the keys). * - * @param[in] in + * @param[in] in_val * contiguous device array of inputs of size (len * batch_size); * these are compared and selected. * @param[in] in_idx * contiguous device array of inputs of size (len * batch_size); - * typically, these are indices of the corresponding in_keys. + * typically, these are indices of the corresponding in_val. * @param batch_size * number of input rows, i.e. the batch size. * @param len @@ -51,12 +51,12 @@ namespace raft::spatial::knn::detail { * Invariant: len >= k. * @param k * the number of outputs to select in each input row. - * @param[out] out + * @param[out] out_val * contiguous device array of outputs of size (k * batch_size); - * the k smallest/largest values from each row of the `in_keys`. + * the k smallest/largest values from each row of the `in_val`. * @param[out] out_idx * contiguous device array of outputs of size (k * batch_size); - * the payload selected together with `out`. + * the payload selected together with `out_val`. * @param select_min * whether to select k smallest (true) or largest (false) keys. * @param stream @@ -64,28 +64,28 @@ namespace raft::spatial::knn::detail { * memory pool here to avoid memory allocations within the call). */ template -void select_topk(const T* in, - const IdxT* in_idx, - size_t batch_size, - size_t len, - int k, - T* out, - IdxT* out_idx, - bool select_min, - rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) +void select_k(const T* in_val, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out_val, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) { common::nvtx::range fun_scope( - "matrix::select_topk(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); + "matrix::select_k(batch_size = %zu, len = %zu, k = %d)", batch_size, len, k); // TODO (achirkin): investigate the trade-off for a wider variety of inputs. const bool radix_faster = batch_size >= 64 && len >= 102400 && k >= 128; - if (k <= raft::spatial::knn::detail::topk::kMaxCapacity && !radix_faster) { - topk::warp_sort_topk( - in, in_idx, batch_size, len, k, out, out_idx, select_min, stream, mr); + if (k <= select::warpsort::kMaxCapacity && !radix_faster) { + select::warpsort::select_k( + in_val, in_idx, batch_size, len, k, out_val, out_idx, select_min, stream, mr); } else { - topk::radix_topk= 4 ? 
diff --git a/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh
similarity index 87%
rename from cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh
rename to cpp/include/raft/matrix/detail/select_radix.cuh
index 9c0f20b706..de19e63a4c 100644
--- a/cpp/include/raft/spatial/knn/detail/topk/radix_topk.cuh
+++ b/cpp/include/raft/matrix/detail/select_radix.cuh
@@ -17,6 +17,7 @@
 #pragma once
 #include
+#include
 #include
 #include
 #include
@@ -27,29 +28,29 @@
 #include
 #include
-#include
+#include
 #include
-namespace raft::spatial::knn::detail::topk {
+namespace raft::matrix::detail::select::radix {
 constexpr int ITEM_PER_THREAD      = 32;
 constexpr int VECTORIZED_READ_SIZE = 16;
 template
-__host__ __device__ constexpr int calc_num_buckets()
+_RAFT_HOST_DEVICE constexpr int calc_num_buckets()
 {
   return 1 << BitsPerPass;
 }
 template
-__host__ __device__ constexpr int calc_num_passes()
+_RAFT_HOST_DEVICE constexpr int calc_num_passes()
 {
   return ceildiv(sizeof(T) * 8, BitsPerPass);
 }
 // Minimum reasonable block size for the given radix size.
 template
-__host__ __device__ constexpr int calc_min_block_size()
+_RAFT_HOST_DEVICE constexpr int calc_min_block_size()
 {
   return 1 << std::max(BitsPerPass - 4, Pow2::Log2 + 1);
 }
@@ -62,7 +63,7 @@ __host__ __device__ constexpr int calc_min_block_size()
 * NB: Use pass=-1 for calc_mask().
 */
 template
-__device__ constexpr int calc_start_bit(int pass)
+_RAFT_DEVICE constexpr int calc_start_bit(int pass)
 {
   int start_bit = static_cast(sizeof(T) * 8) - (pass + 1) * BitsPerPass;
   if (start_bit < 0) { start_bit = 0; }
@@ -70,7 +71,7 @@ __device__ constexpr int calc_start_bit(int pass)
 }
 template
-__device__ constexpr unsigned calc_mask(int pass)
+_RAFT_DEVICE constexpr unsigned calc_mask(int pass)
 {
   static_assert(BitsPerPass <= 31);
   int num_bits = calc_start_bit(pass - 1) - calc_start_bit(pass);
@@ -82,7 +83,7 @@ __device__ constexpr unsigned calc_mask(int pass)
 * as of integers.
 */
 template
-__device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater)
+_RAFT_DEVICE typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater)
 {
   auto bits = reinterpret_cast::UnsignedBits&>(key);
   bits      = cub::Traits::TwiddleIn(bits);
@@ -91,7 +92,7 @@ __device__ typename cub::Traits::UnsignedBits twiddle_in(T key, bool greater)
 }
 template
-__device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater)
+_RAFT_DEVICE int calc_bucket(T x, int start_bit, unsigned mask, bool greater)
 {
   static_assert(BitsPerPass <= sizeof(int) * 8 - 1);  // so return type can be int
   return (twiddle_in(x, greater) >> start_bit) & mask;
@@ -112,7 +113,7 @@ __device__ int calc_bucket(T x, int start_bit, unsigned mask, bool greater)
 * @param f the lambda taking two arguments (T x, IdxT idx)
 */
 template
-__device__ void vectorized_process(const T* in, IdxT len, Func f)
+_RAFT_DEVICE void vectorized_process(const T* in, IdxT len, Func f)
 {
   const IdxT stride = blockDim.x * gridDim.x;
   const int tid     = blockIdx.x * blockDim.x + threadIdx.x;
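The radix geometry follows directly from the two constexpr helpers above: 2^BitsPerPass buckets per pass, and ceil(8 * sizeof(T) / BitsPerPass) passes over the key bits. A compile-time spot check (illustrative; assumes the helpers are visible with their usual template parameters):

// 11 bits per pass => 2048 buckets; a 32-bit key needs ceil(32/11) = 3 passes,
// while a 64-bit key at 8 bits per pass needs ceil(64/8) = 8 passes.
static_assert(calc_num_buckets<11>() == 2048);
static_assert(calc_num_passes<float, 11>() == 3);
static_assert(calc_num_passes<double, 8>() == 8);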
@@ -167,18 +168,18 @@ struct Counter {
 * (see steps 4-1 in `radix_kernel` description).
 */
 template
-__device__ void filter_and_histogram(const T* in_buf,
-                                     const IdxT* in_idx_buf,
-                                     T* out_buf,
-                                     IdxT* out_idx_buf,
-                                     T* out,
-                                     IdxT* out_idx,
-                                     IdxT len,
-                                     Counter* counter,
-                                     IdxT* histogram,
-                                     bool greater,
-                                     int pass,
-                                     int k)
+_RAFT_DEVICE void filter_and_histogram(const T* in_buf,
+                                       const IdxT* in_idx_buf,
+                                       T* out_buf,
+                                       IdxT* out_idx_buf,
+                                       T* out,
+                                       IdxT* out_idx,
+                                       IdxT len,
+                                       Counter* counter,
+                                       IdxT* histogram,
+                                       bool greater,
+                                       int pass,
+                                       int k)
 {
   constexpr int num_buckets = calc_num_buckets();
   __shared__ IdxT histogram_smem[num_buckets];
@@ -260,10 +261,10 @@ __device__ void filter_and_histogram(const T* in_buf,
 * (step 2 in `radix_kernel` description)
 */
 template
-__device__ void scan(volatile IdxT* histogram,
-                     const int start,
-                     const int num_buckets,
-                     const IdxT current)
+_RAFT_DEVICE void scan(volatile IdxT* histogram,
+                       const int start,
+                       const int num_buckets,
+                       const IdxT current)
 {
   typedef cub::BlockScan BlockScan;
   __shared__ typename BlockScan::TempStorage temp_storage;
@@ -284,7 +285,7 @@ __device__ void scan(volatile IdxT* histogram,
 * (steps 2-3 in `radix_kernel` description)
 */
 template
-__device__ void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k)
+_RAFT_DEVICE void choose_bucket(Counter* counter, IdxT* histogram, const IdxT k)
 {
   constexpr int num_buckets = calc_num_buckets();
   int index                 = threadIdx.x;
@@ -547,21 +548,21 @@ inline dim3 get_optimal_grid_size(size_t req_batch_size, size_t len)
 * memory pool here to avoid memory allocations within the call).
 */
 template
-void radix_topk(const T* in,
-                const IdxT* in_idx,
-                size_t batch_size,
-                size_t len,
-                int k,
-                T* out,
-                IdxT* out_idx,
-                bool select_min,
-                rmm::cuda_stream_view stream,
-                rmm::mr::device_memory_resource* mr = nullptr)
+void select_k(const T* in,
+              const IdxT* in_idx,
+              size_t batch_size,
+              size_t len,
+              int k,
+              T* out,
+              IdxT* out_idx,
+              bool select_min,
+              rmm::cuda_stream_view stream,
+              rmm::mr::device_memory_resource* mr = nullptr)
 {
   // reduce the block size if the input length is too small.
   if constexpr (BlockSize > calc_min_block_size()) {
     if (BlockSize * ITEM_PER_THREAD > len) {
-      return radix_topk(
+      return select_k(
         in, in_idx, batch_size, len, k, out, out_idx, select_min, stream);
     }
   }
@@ -573,23 +574,33 @@ void radix_topk(const T* in,
   dim3 blocks           = get_optimal_grid_size(batch_size, len);
   size_t max_chunk_size = blocks.y;
-  auto pool_guard = raft::get_pool_memory_resource(
-    mr,
-    max_chunk_size * (sizeof(Counter)                   // counters
-                      + sizeof(IdxT) * (num_buckets + 2)  // histograms and IdxT bufs
-                      + sizeof(T) * 2                     // T bufs
-                      ));
+  size_t req_aux = max_chunk_size * (sizeof(Counter) + num_buckets * sizeof(IdxT));
+  size_t req_buf = max_chunk_size * len * 2 * (sizeof(T) + sizeof(IdxT));
+  size_t mem_req = req_aux + req_buf;
+  size_t mem_free, mem_total;
+  RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total));
+  std::optional managed_memory;
+  rmm::mr::device_memory_resource* mr_buf = nullptr;
+  if (mem_req > mem_free) {
+    // if there's not enough memory for buffers on the device, resort to the managed memory.
+    mem_req = req_aux;
+    managed_memory.emplace();
+    mr_buf = &managed_memory.value();
+  }
+
+  auto pool_guard = raft::get_pool_memory_resource(mr, mem_req);
   if (pool_guard) {
-    RAFT_LOG_DEBUG("radix_topk: using pool memory resource with initial size %zu bytes",
+    RAFT_LOG_DEBUG("radix::select_k: using pool memory resource with initial size %zu bytes",
                    pool_guard->pool_size());
   }
+  if (mr_buf == nullptr) { mr_buf = mr; }
   rmm::device_uvector> counters(max_chunk_size, stream, mr);
-  rmm::device_uvector histograms(num_buckets * max_chunk_size, stream, mr);
-  rmm::device_uvector buf1(len * max_chunk_size, stream, mr);
-  rmm::device_uvector idx_buf1(len * max_chunk_size, stream, mr);
-  rmm::device_uvector buf2(len * max_chunk_size, stream, mr);
-  rmm::device_uvector idx_buf2(len * max_chunk_size, stream, mr);
+  rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr);
+  rmm::device_uvector buf1(max_chunk_size * len, stream, mr_buf);
+  rmm::device_uvector idx_buf1(max_chunk_size * len, stream, mr_buf);
+  rmm::device_uvector buf2(max_chunk_size * len, stream, mr_buf);
+  rmm::device_uvector idx_buf2(max_chunk_size * len, stream, mr_buf);
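The new allocation strategy above boils down to: keep the small auxiliary arrays (counters, histograms) in device memory, and push the large candidate buffers to CUDA managed memory only when they would not fit. A condensed sketch of that decision, with an illustrative helper name that is not part of the PR:

#include <optional>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>

// Pick the memory resource for the large scratch buffers: plain device memory
// when the estimate fits into free memory, otherwise unified (managed) memory,
// trading bandwidth for the ability to oversubscribe.
inline rmm::mr::device_memory_resource* choose_buf_resource(
  size_t req_bytes, std::optional<rmm::mr::managed_memory_resource>& managed)
{
  size_t mem_free = 0, mem_total = 0;
  RAFT_CUDA_TRY(cudaMemGetInfo(&mem_free, &mem_total));
  if (req_bytes > mem_free) {
    managed.emplace();
    return &managed.value();
  }
  return rmm::mr::get_current_device_resource();
}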
   for (size_t offset = 0; offset < batch_size; offset += max_chunk_size) {
     blocks.y = std::min(max_chunk_size, batch_size - offset);
@@ -646,4 +657,4 @@ void radix_topk(const T* in,
   }
 }
-}  // namespace raft::spatial::knn::detail::topk
+}  // namespace raft::matrix::detail::select::radix
diff --git a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh
similarity index 71%
rename from cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
rename to cpp/include/raft/matrix/detail/select_warpsort.cuh
index cbe9f36e97..d362b73792 100644
--- a/cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
+++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2023, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -16,10 +16,11 @@
 #pragma once
-#include "bitonic_sort.cuh"
-
+#include
 #include
+#include
 #include
+#include
 #include
 #include
@@ -31,12 +32,12 @@
 /*
  Three APIs of different scopes are provided:
-   1. host function: warp_sort_topk()
+   1. host function: select_k()
    2. block-wide API: class block_sort
    3. warp-wide API: several implementations of warp_sort_*

-  1. warp_sort_topk()
+  1. select_k()
    (see the docstring)

   2. class block_sort
@@ -74,7 +75,7 @@
  These two classes can be regarded as fixed-size priority queues for a warp.
  Usage is similar to class block_sort. No shared memory is needed.

- The host function (warp_sort_topk) uses a heuristic to choose between these two classes for
+ The host function (select_k) uses a heuristic to choose between these two classes for
  sorting, warp_sort_immediate being chosen when the number of inputs per warp is somewhat small
  (see the usage of LaunchThreshold::len_factor_for_choosing).
@@ -94,7 +95,7 @@
 }
 */
-namespace raft::spatial::knn::detail::topk {
+namespace raft::matrix::detail::select::warpsort {
 static constexpr int kMaxCapacity = 256;
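For orientation, a hypothetical kernel using the block-wide API after this refactor. The template parameter order, Capacity = 256, and the kDummy padding of the tail are assumptions for illustration (the PR's real kernel is block_kernel, further below):

template <typename T, typename IdxT>
__global__ void block_select_k(const T* in, const IdxT* in_idx, IdxT len, int k,
                               T* out, IdxT* out_idx)
{
  extern __shared__ __align__(256) uint8_t smem_buf[];
  using namespace raft::matrix::detail::select::warpsort;
  using queue_base = warp_sort_immediate<256, true, T, IdxT>;
  // the constructor forwards to init_blockwide(); warp_sort_immediate needs no extra state
  block_sort<warp_sort_immediate, 256, true, T, IdxT> queue(k);
  // pad the tail with the sentinel so all lanes of a warp call add() together
  IdxT len_padded = ((len + blockDim.x - 1) / blockDim.x) * blockDim.x;
  for (IdxT i = threadIdx.x; i < len_padded; i += blockDim.x) {
    queue.add(i < len ? in[i] : queue_base::kDummy,
              (i < len && in_idx != nullptr) ? in_idx[i] : i);
  }
  queue.done(smem_buf);  // done() now takes the scratch for the cross-warp merge
  queue.store(out + blockIdx.x * k, out_idx + blockIdx.x * k);
}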
 namespace {
 /** Whether 'left` should indeed be on the left w.r.t. `right`. */
 template
-__device__ __forceinline__ auto is_ordered(T left, T right) -> bool
+_RAFT_DEVICE _RAFT_FORCEINLINE auto is_ordered(T left, T right) -> bool
 {
   if constexpr (Ascending) { return left < right; }
   if constexpr (!Ascending) { return left > right; }
 }
-constexpr auto calc_capacity(int k) -> int
-{
-  int capacity = isPo2(k) ? k : (1 << (log2(k) + 1));
-  return capacity;
-}
-
 }  // namespace
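The removed file-local calc_capacity() is superseded by the shared bound_by_power_of_two() helper (used later in launch_setup); both round k up to the next power of two and leave it unchanged when it already is one. A constexpr stand-in capturing that behaviour, for illustration only:

constexpr int next_pow2(int k)  // illustrative stand-in, not the RAFT helper
{
  int c = 1;
  while (c < k) { c <<= 1; }
  return c;
}
static_assert(next_pow2(1) == 1 && next_pow2(5) == 8 && next_pow2(256) == 256);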
 /**
 *
@@ -134,7 +129,7 @@ constexpr auto calc_capacity(int k) -> int
 */
 template
 class warp_sort {
-  static_assert(isPo2(Capacity));
+  static_assert(is_a_power_of_two(Capacity));
   static_assert(std::is_default_constructible_v);
 public:
   /** The number of elements to select. */
   const int k;
+  /** Extra memory required per-block for keeping the state (shared or global). */
+  constexpr static auto mem_required(uint32_t block_size) -> size_t { return 0; }
+
   /**
   * Construct the warp_sort empty queue.
   *
   * @param k
   * number of elements to select.
   */
-  __device__ warp_sort(int k) : k(k)
+  _RAFT_DEVICE warp_sort(int k) : k(k)
   {
 #pragma unroll
     for (int i = 0; i < kMaxArrLen; i++) {
@@ -182,7 +180,7 @@
   * It serves as a conditional; when `false` the function does nothing.
   * We need it to ensure threads within a full warp don't diverge calling `bitonic::merge()`.
   */
-  __device__ void load_sorted(const T* in, const IdxT* in_idx, bool do_merge = true)
+  _RAFT_DEVICE void load_sorted(const T* in, const IdxT* in_idx, bool do_merge = true)
   {
     if (do_merge) {
       int idx = Pow2::mod(laneId()) ^ Pow2::Mask;
@@ -198,7 +196,7 @@
       }
     }
     if (kWarpWidth < WarpSize || do_merge) {
-      topk::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_);
+      util::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_);
     }
   }
@@ -211,14 +209,23 @@
   * @param[out] out_idx
   * device pointer to a contiguous array, unique per-subwarp of size `kWarpWidth`
   * (length: k <= kWarpWidth * kMaxArrLen).
+  * @param valF (optional) postprocess values (T -> OutT)
+  * @param idxF (optional) postprocess indices (IdxT -> OutIdxT)
   */
-  __device__ void store(T* out, IdxT* out_idx) const
+  template
+  _RAFT_DEVICE void store(OutT* out,
+                          OutIdxT* out_idx,
+                          ValF valF = raft::identity_op{},
+                          IdxF idxF = raft::identity_op{}) const
   {
     int idx = Pow2::mod(laneId());
 #pragma unroll kMaxArrLen
     for (int i = 0; i < kMaxArrLen && idx < k; i++, idx += kWarpWidth) {
-      out[idx]     = val_arr_[i];
-      out_idx[idx] = idx_arr_[i];
+      out[idx]     = valF(val_arr_[i]);
+      out_idx[idx] = idxF(idx_arr_[i]);
     }
   }
@@ -245,8 +252,8 @@
   * the associated indices of the elements in the same format as `keys_in`.
   */
   template
-  __device__ __forceinline__ void merge_in(const T* __restrict__ keys_in,
-                                           const IdxT* __restrict__ ids_in)
+  _RAFT_DEVICE _RAFT_FORCEINLINE void merge_in(const T* __restrict__ keys_in,
+                                               const IdxT* __restrict__ ids_in)
   {
 #pragma unroll
     for (int i = std::min(kMaxArrLen, PerThreadSizeIn); i > 0; i--) {
@@ -257,7 +264,7 @@
       idx_arr_[kMaxArrLen - i] = ids_in[PerThreadSizeIn - i];
     }
   }
-    topk::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_);
+    util::bitonic(Ascending, kWarpWidth).merge(val_arr_, idx_arr_);
   }
 };
@@ -275,8 +282,9 @@ class warp_sort_filtered : public warp_sort {
   using warp_sort::kDummy;
   using warp_sort::kWarpWidth;
   using warp_sort::k;
+  using warp_sort::mem_required;
-  __device__ warp_sort_filtered(int k, T limit)
+  explicit _RAFT_DEVICE warp_sort_filtered(int k, T limit = kDummy)
     : warp_sort(k), buf_len_(0), k_th_(limit)
   {
 #pragma unroll
@@ -286,12 +294,14 @@
     }
   }
-  __device__ __forceinline__ explicit warp_sort_filtered(int k)
-    : warp_sort_filtered(k, kDummy)
+  _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k,
+                                                            uint8_t* = nullptr,
+                                                            T limit = kDummy)
   {
+    return warp_sort_filtered{k, limit};
   }
-  __device__ void add(T val, IdxT idx)
+  _RAFT_DEVICE void add(T val, IdxT idx)
   {
     // comparing for k_th should reduce the total amount of updates:
     // `false` means the input value is surely not in the top-k values.
@@ -309,22 +319,22 @@
     if (do_add) { add_to_buf_(val, idx); }
   }
-  __device__ void done()
+  _RAFT_DEVICE void done()
   {
     if (any(buf_len_ != 0)) { merge_buf_(); }
   }
 private:
-  __device__ __forceinline__ void set_k_th_()
+  _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_()
   {
     // NB on using srcLane: it's ok if it is outside the warp size / width;
     // the modulo op will be done inside the __shfl_sync.
     k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth);
   }
-  __device__ __forceinline__ void merge_buf_()
+  _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_()
   {
-    topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_);
+    util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_);
     this->merge_in(val_buf_, idx_buf_);
     buf_len_ = 0;
     set_k_th_();  // contains warp sync
@@ -334,7 +344,7 @@
     }
   }
-  __device__ __forceinline__ void add_to_buf_(T val, IdxT idx)
+  _RAFT_DEVICE _RAFT_FORCEINLINE void add_to_buf_(T val, IdxT idx)
   {
     // NB: the loop is used here to ensure the constant indexing,
     // to not force the buffers spill into the local memory.
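warp_sort_filtered's core trick is easy to see in a scalar setting: once k candidates are held, a single comparison against the current k-th value rejects most inputs before any sorting work happens. A host-side toy analogue (illustrative only; assumes in.size() >= k):

#include <queue>
#include <vector>

std::vector<float> top_k_min(const std::vector<float>& in, size_t k)
{
  // max-heap of the k smallest seen so far; top() plays the role of k_th_
  std::priority_queue<float> worst_on_top(in.begin(), in.begin() + k);
  for (size_t i = k; i < in.size(); ++i) {
    if (in[i] < worst_on_top.top()) {  // the cheap filter; most inputs stop here
      worst_on_top.pop();
      worst_on_top.push(in[i]);
    }
  }
  std::vector<float> out;
  for (; !worst_on_top.empty(); worst_on_top.pop()) { out.push_back(worst_on_top.top()); }
  return out;  // the k smallest values, largest first
}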
@@ -373,8 +383,9 @@ class warp_sort_distributed : public warp_sort {
   using warp_sort::kDummy;
   using warp_sort::kWarpWidth;
   using warp_sort::k;
+  using warp_sort::mem_required;
-  __device__ warp_sort_distributed(int k, T limit)
+  explicit _RAFT_DEVICE warp_sort_distributed(int k, T limit = kDummy)
     : warp_sort(k),
       buf_val_(kDummy),
       buf_idx_(IdxT{}),
@@ -383,12 +394,14 @@
   {
   }
-  __device__ __forceinline__ explicit warp_sort_distributed(int k)
-    : warp_sort_distributed(k, kDummy)
+  _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k,
+                                                            uint8_t* = nullptr,
+                                                            T limit = kDummy)
   {
+    return warp_sort_distributed{k, limit};
   }
-  __device__ void add(T val, IdxT idx)
+  _RAFT_DEVICE void add(T val, IdxT idx)
   {
     // mask tells which lanes in the warp have valid items to be added
     uint32_t mask = ballot(is_ordered(val, k_th_));
@@ -428,7 +441,7 @@
     }
   }
-  __device__ void done()
+  _RAFT_DEVICE void done()
   {
     if (buf_len_ != 0) {
       merge_buf_();
@@ -437,16 +450,16 @@
   }
 private:
-  __device__ __forceinline__ void set_k_th_()
+  _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_()
   {
     // NB on using srcLane: it's ok if it is outside the warp size / width;
     // the modulo op will be done inside the __shfl_sync.
     k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth);
   }
-  __device__ __forceinline__ void merge_buf_()
+  _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_()
   {
-    topk::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val_, buf_idx_);
+    util::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val_, buf_idx_);
     this->merge_in<1>(&buf_val_, &buf_idx_);
     set_k_th_();  // contains warp sync
     buf_val_ = kDummy;
@@ -463,6 +476,117 @@
   T k_th_;
 };
+/**
+ * The same as `warp_sort_distributed`, but keeps the temporary value and index buffers
+ * in the given external pointers (normally, a shared memory pointer should be passed in).
+ */
+template
+class warp_sort_distributed_ext : public warp_sort {
+ public:
+  using warp_sort::kDummy;
+  using warp_sort::kWarpWidth;
+  using warp_sort::k;
+
+  constexpr static auto mem_required(uint32_t block_size) -> size_t
+  {
+    return (sizeof(T) + sizeof(IdxT)) * block_size;
+  }
+
+  _RAFT_DEVICE warp_sort_distributed_ext(int k, T* val_buf, IdxT* idx_buf, T limit = kDummy)
+    : warp_sort(k),
+      val_buf_(val_buf),
+      idx_buf_(idx_buf),
+      buf_len_(0),
+      k_th_(limit)
+  {
+    val_buf_[laneId()] = kDummy;
+  }
+
+  _RAFT_DEVICE static auto init_blockwide(int k, uint8_t* shmem, T limit = kDummy)
+  {
+    T* val_buf    = nullptr;
+    IdxT* idx_buf = nullptr;
+    if constexpr (alignof(T) >= alignof(IdxT)) {
+      val_buf = reinterpret_cast(shmem);
+      idx_buf = reinterpret_cast(val_buf + blockDim.x);
+    } else {
+      idx_buf = reinterpret_cast(shmem);
+      val_buf = reinterpret_cast(idx_buf + blockDim.x);
+    }
+    auto warp_offset = Pow2::roundDown(threadIdx.x);
+    val_buf += warp_offset;
+    idx_buf += warp_offset;
+    return warp_sort_distributed_ext{k, val_buf, idx_buf, limit};
+  }
+
+  _RAFT_DEVICE void add(T val, IdxT idx)
+  {
+    bool do_add = is_ordered(val, k_th_);
+    // mask tells which lanes in the warp have valid items to be added
+    uint32_t mask = ballot(do_add);
+    if (mask == 0) { return; }
+    // where to put the element in the tmp buffer
+    int dst_ix = buf_len_ + __popc(mask & ((1u << laneId()) - 1u));
+    // put all elements, which fit into the current tmp buffer
+    if (do_add && dst_ix < WarpSize) {
+      val_buf_[dst_ix] = val;
+      idx_buf_[dst_ix] = idx;
+      do_add           = false;
+    }
+    // Total number of elements to be added
+    buf_len_ += __popc(mask);
+    // If the buffer is still not full, we can return
+    if (buf_len_ < WarpSize) { return; }
+    // Otherwise, merge the warp tmp buffer into the queue
+    merge_buf_();  // implies warp sync
+    buf_len_ -= WarpSize;
+    // save the inputs that couldn't fit before the merge
+    if (do_add) {
+      dst_ix -= WarpSize;
+      val_buf_[dst_ix] = val;
+      idx_buf_[dst_ix] = idx;
+    }
+  }
+
+  _RAFT_DEVICE void done()
+  {
+    if (buf_len_ != 0) {
+      merge_buf_();
+      buf_len_ = 0;
+    }
+    __syncthreads();
+  }
+
+ private:
+  _RAFT_DEVICE _RAFT_FORCEINLINE void set_k_th_()
+  {
+    // NB on using srcLane: it's ok if it is outside the warp size / width;
+    // the modulo op will be done inside the __shfl_sync.
+    k_th_ = shfl(val_arr_[kMaxArrLen - 1], k - 1, kWarpWidth);
+  }
+
+  _RAFT_DEVICE _RAFT_FORCEINLINE void merge_buf_()
+  {
+    __syncwarp();  // make sure the threads are aware of the data written by others
+    T buf_val          = val_buf_[laneId()];
+    IdxT buf_idx       = idx_buf_[laneId()];
+    val_buf_[laneId()] = kDummy;
+    util::bitonic<1>(!Ascending, kWarpWidth).sort(buf_val, buf_idx);
+    this->merge_in<1>(&buf_val, &buf_idx);
+    set_k_th_();  // contains warp sync
+  }
+
+  using warp_sort::kMaxArrLen;
+  using warp_sort::val_arr_;
+  using warp_sort::idx_arr_;
+
+  T* val_buf_;
+  IdxT* idx_buf_;
+  uint32_t buf_len_;  // 0 <= buf_len_ < WarpSize
+
+  T k_th_;
+};
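The slot computation in add() above is the standard warp-compaction idiom. Written with raw CUDA intrinsics instead of raft's ballot()/laneId() wrappers, it reads:

// Each lane votes whether it has an item; a lane's destination slot is the
// number of voting lanes with a lower lane id, offset by the current fill level.
__device__ int compaction_slot(bool has_item, int fill_level)
{
  unsigned mask = __ballot_sync(0xffffffffu, has_item);
  unsigned lane = threadIdx.x & 31u;
  return fill_level + __popc(mask & ((1u << lane) - 1u));
}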
+/**
 * This version of warp_sort adds every input element into the intermediate sorting
 * buffer, and thus does the sorting step every `Capacity` input elements.
@@ -475,8 +599,10 @@ class warp_sort_immediate : public warp_sort {
   using warp_sort::kDummy;
   using warp_sort::kWarpWidth;
   using warp_sort::k;
+  using warp_sort::mem_required;
-  __device__ warp_sort_immediate(int k) : warp_sort(k), buf_len_(0)
+  explicit _RAFT_DEVICE warp_sort_immediate(int k)
+    : warp_sort(k), buf_len_(0)
   {
 #pragma unroll
     for (int i = 0; i < kMaxArrLen; i++) {
@@ -485,7 +611,12 @@
     }
   }
-  __device__ void add(T val, IdxT idx)
+  _RAFT_DEVICE _RAFT_FORCEINLINE static auto init_blockwide(int k, uint8_t* = nullptr)
+  {
+    return warp_sort_immediate{k};
+  }
+
+  _RAFT_DEVICE void add(T val, IdxT idx)
   {
     // NB: the loop is used here to ensure the constant indexing,
     // to not force the buffers spill into the local memory.
@@ -499,7 +630,7 @@
     ++buf_len_;
     if (buf_len_ == kMaxArrLen) {
-      topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_);
+      util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_);
       this->merge_in(val_buf_, idx_buf_);
 #pragma unroll
       for (int i = 0; i < kMaxArrLen; i++) {
@@ -509,10 +640,10 @@
     }
   }
-  __device__ void done()
+  _RAFT_DEVICE void done()
   {
     if (buf_len_ != 0) {
-      topk::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_);
+      util::bitonic(!Ascending, kWarpWidth).sort(val_buf_, idx_buf_);
       this->merge_in(val_buf_, idx_buf_);
     }
   }
@@ -544,15 +675,11 @@ class block_sort {
   using queue_t = WarpSortWarpWide;
   template
-  __device__ block_sort(int k, uint8_t* smem_buf, Args... args) : queue_(k, args...)
+  _RAFT_DEVICE block_sort(int k, Args... args) : queue_(queue_t::init_blockwide(k, args...))
   {
-    val_smem_             = reinterpret_cast(smem_buf);
-    const int num_of_warp = subwarp_align::div(blockDim.x);
-    idx_smem_             = reinterpret_cast(
-      smem_buf + Pow2<256>::roundUp(ceildiv(num_of_warp, 2) * sizeof(T) * k));
   }
-  __device__ void add(T val, IdxT idx) { queue_.add(val, idx); }
+  _RAFT_DEVICE void add(T val, IdxT idx) { queue_.add(val, idx); }
 /**
  * At the point of calling this function, the warp-level queues consumed all input
@@ -560,22 +687,26 @@
  *
  * Here we tree-merge the results using the shared memory and block sync.
  */
-  __device__ void done()
+  _RAFT_DEVICE void done(uint8_t* smem_buf)
   {
     queue_.done();
+    int nwarps    = subwarp_align::div(blockDim.x);
+    auto val_smem = reinterpret_cast(smem_buf);
+    auto idx_smem = reinterpret_cast(
+      smem_buf + Pow2<256>::roundUp(ceildiv(nwarps, 2) * sizeof(T) * queue_.k));
+
     const int warp_id = subwarp_align::div(threadIdx.x);
     // NB: there is no need for the second __synchthreads between .load_sorted and .store:
     // we shift the pointers every iteration, such that individual warps either access the same
     // locations or do not overlap with any of the other warps. The access patterns within warps
     // are different for the two functions, but .load_sorted implies warp sync at the end, so
     // there is no need for __syncwarp either.
-    for (int shift_mask = ~0, nwarps = subwarp_align::div(blockDim.x), split = (nwarps + 1) >> 1;
-         nwarps > 1;
+    for (int shift_mask = ~0, split = (nwarps + 1) >> 1; nwarps > 1;
          nwarps = split, split = (nwarps + 1) >> 1) {
       if (warp_id < nwarps && warp_id >= split) {
         int dst_warp_shift = (warp_id - (split & shift_mask)) * queue_.k;
-        queue_.store(val_smem_ + dst_warp_shift, idx_smem_ + dst_warp_shift);
+        queue_.store(val_smem + dst_warp_shift, idx_smem + dst_warp_shift);
       }
       __syncthreads();
@@ -585,22 +716,27 @@
       // The last argument serves as a condition for loading
       // -- to make sure threads within a full warp do not diverge on `bitonic::merge()`
       queue_.load_sorted(
-        val_smem_ + src_warp_shift, idx_smem_ + src_warp_shift, warp_id < nwarps - split);
+        val_smem + src_warp_shift, idx_smem + src_warp_shift, warp_id < nwarps - split);
     }
   }
 /** Save the content by the pointer location. */
-  __device__ void store(T* out, IdxT* out_idx) const
+  template
+  _RAFT_DEVICE void store(OutT* out,
+                          OutIdxT* out_idx,
+                          ValF valF = raft::identity_op{},
+                          IdxF idxF = raft::identity_op{}) const
   {
-    if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx); }
+    if (threadIdx.x < subwarp_align::Value) { queue_.store(out, out_idx, valF, idxF); }
   }
 private:
   using subwarp_align = Pow2;
   queue_t queue_;
-  T* val_smem_;
-  IdxT* idx_smem_;
 };
 /**
@@ -618,7 +754,10 @@ __launch_bounds__(256) __global__
   void block_kernel(const T* in, const IdxT* in_idx, IdxT len, int k, T* out, IdxT* out_idx)
 {
   extern __shared__ __align__(256) uint8_t smem_buf_bytes[];
-  block_sort queue(k, smem_buf_bytes);
+  using bq_t         = block_sort;
+  uint8_t* warp_smem = bq_t::queue_t::mem_required(blockDim.x) > 0 ? smem_buf_bytes : nullptr;
+  bq_t queue(k, warp_smem);
+
   in += blockIdx.y * len;
   if (in_idx != nullptr) { in_idx += blockIdx.y * len; }
@@ -629,7 +768,7 @@
               (i < len && in_idx != nullptr) ? __ldcs(in_idx + i) : i);
   }
-  queue.done();
+  queue.done(smem_buf_bytes);
   const int block_id = blockIdx.x + gridDim.x * blockIdx.y;
   queue.store(out + block_id * k, out_idx + block_id * k);
 }
@@ -656,7 +795,7 @@ struct launch_setup {
     int* min_grid_size,
     int block_size_limit = 0)
   {
-    const int capacity = calc_capacity(k);
+    const int capacity = bound_by_power_of_two(k);
     if constexpr (Capacity > 1) {
       if (capacity < Capacity) {
         return launch_setup::calc_optimal_params(
@@ -689,7 +828,7 @@ struct launch_setup {
     IdxT* out_idx,
     rmm::cuda_stream_view stream)
   {
-    const int capacity = calc_capacity(k);
+    const int capacity = bound_by_power_of_two(k);
     if constexpr (Capacity > 1) {
       if (capacity < Capacity) {
         return launch_setup::kernel(k,
@@ -740,6 +879,18 @@ struct LaunchThreshold {
   static constexpr int len_factor_for_single_block = 32;
 };
+template <>
+struct LaunchThreshold {
+  static constexpr int len_factor_for_multi_block  = 2;
+  static constexpr int len_factor_for_single_block = 32;
+};
+
+template <>
+struct LaunchThreshold {
+  static constexpr int len_factor_for_multi_block  = 2;
+  static constexpr int len_factor_for_single_block = 32;
+};
+
 template <>
 struct LaunchThreshold {
   static constexpr int len_factor_for_choosing     = 4;
@@ -751,7 +902,7 @@ template
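Since one shared-memory buffer now serves both phases of block_kernel (queue state during add(), tree-merge scratch during done()), a launch has to size it for the larger of the two. A sketch of that computation, mirroring the offsets done() computes; calc_merge_smem is a hypothetical helper written for this example and assumes the queue width equals a full 32-lane warp:

template <typename T, typename IdxT>
size_t calc_merge_smem(int block_dim, int k)
{
  int nwarps  = block_dim / 32;
  // values first (rounded up to the 256-byte alignment of the buffer), then indices
  size_t vals = raft::Pow2<256>::roundUp(raft::ceildiv(nwarps, 2) * sizeof(T) * k);
  return vals + raft::ceildiv(nwarps, 2) * sizeof(IdxT) * k;
}

// at launch (sketch):
//   size_t smem = std::max(queue_t::mem_required(block_dim),
//                          calc_merge_smem<T, IdxT>(block_dim, k));
//   block_kernel<<<grid, block_dim, smem, stream>>>(in, in_idx, len, k, out, out_idx);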