diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index 52d550d375..e37df52a4f 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -66,19 +66,17 @@ jobs: run_script: "ci/build_docs.sh" wheel-build-pylibraft: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} sha: ${{ inputs.sha }} date: ${{ inputs.date }} - package-name: pylibraft - package-dir: python/pylibraft - skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + script: ci/build_wheel_pylibraft.sh wheel-publish-pylibraft: needs: wheel-build-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-publish.yaml@branch-23.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} @@ -88,19 +86,17 @@ jobs: wheel-build-raft-dask: needs: wheel-publish-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} sha: ${{ inputs.sha }} date: ${{ inputs.date }} - package-name: raft_dask - package-dir: python/raft-dask - skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + script: ci/build_wheel_raft_dask.sh wheel-publish-raft-dask: needs: wheel-build-raft-dask secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-publish.yml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-publish.yaml@branch-23.10 with: build_type: ${{ inputs.build_type || 'branch' }} branch: ${{ inputs.branch }} diff --git a/.github/workflows/pr.yaml b/.github/workflows/pr.yaml index 75d65f9175..4437e0dc85 100644 --- a/.github/workflows/pr.yaml +++ b/.github/workflows/pr.yaml @@ -67,40 +67,28 @@ jobs: wheel-build-pylibraft: needs: checks secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.10 with: build_type: pull-request - package-name: pylibraft - package-dir: python/pylibraft - skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + script: ci/build_wheel_pylibraft.sh wheel-tests-pylibraft: needs: wheel-build-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.10 with: build_type: pull-request - package-name: pylibraft - test-unittest: "python -m pytest ./python/pylibraft/pylibraft/test" - test-smoketest: "python ./ci/wheel_smoke_test_pylibraft.py" + script: ci/test_wheel_pylibraft.sh wheel-build-raft-dask: needs: wheel-tests-pylibraft secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-build.yml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-build.yaml@branch-23.10 with: build_type: pull-request - package-name: raft_dask - package-dir: python/raft-dask - before-wheel: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft && python -m pip install --no-deps ./local-pylibraft/pylibraft*.whl" - skbuild-configure-options: "-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + script: "ci/build_wheel_raft_dask.sh" wheel-tests-raft-dask: needs: wheel-build-raft-dask secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.10 with: build_type: pull-request - package-name: raft_dask - # Always want to test against latest dask/distributed. - test-before-amd64: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.10" - test-before-arm64: "RAPIDS_PY_WHEEL_NAME=pylibraft_${{ '${PIP_CU_VERSION}' }} rapids-download-wheels-from-s3 ./local-pylibraft-dep && pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl && pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.10" - test-unittest: "python -m pytest ./python/raft-dask/raft_dask/test" - test-smoketest: "python ./ci/wheel_smoke_test_raft_dask.py" + script: ci/test_wheel_raft_dask.sh diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index aa1838e7bc..a80d5ff0cf 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -32,23 +32,19 @@ jobs: sha: ${{ inputs.sha }} wheel-tests-pylibraft: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.10 with: build_type: nightly branch: ${{ inputs.branch }} date: ${{ inputs.date }} sha: ${{ inputs.sha }} - package-name: pylibraft - test-unittest: "python -m pytest -v ./python/pylibraft/pylibraft/test" + script: ci/test_wheel_pylibraft.sh wheel-tests-raft-dask: secrets: inherit - uses: rapidsai/shared-action-workflows/.github/workflows/wheels-manylinux-test.yml@branch-23.10 + uses: rapidsai/shared-action-workflows/.github/workflows/wheels-test.yaml@branch-23.10 with: build_type: nightly branch: ${{ inputs.branch }} date: ${{ inputs.date }} sha: ${{ inputs.sha }} - package-name: raft_dask - test-before-amd64: "pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.10" - test-before-arm64: "pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.10" - test-unittest: "python -m pytest -v ./python/raft-dask/raft_dask/test" + script: ci/test_wheel_raft_dask.sh diff --git a/ci/build_wheel.sh b/ci/build_wheel.sh new file mode 100755 index 0000000000..a9f7f64294 --- /dev/null +++ b/ci/build_wheel.sh @@ -0,0 +1,30 @@ +#!/bin/bash +# Copyright (c) 2023, NVIDIA CORPORATION. + +set -euo pipefail + +package_name=$1 +package_dir=$2 + +source rapids-configure-sccache +source rapids-date-string + +# Use gha-tools rapids-pip-wheel-version to generate wheel version then +# update the necessary files +version_override="$(rapids-pip-wheel-version ${RAPIDS_DATE_STRING})" + +RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" + +ci/release/apply_wheel_modifications.sh ${version_override} "-${RAPIDS_PY_CUDA_SUFFIX}" +echo "The package name and/or version was modified in the package source. The git diff is:" +git diff + +cd "${package_dir}" + +# Hardcode the output dir +python -m pip wheel . -w dist -vvv --no-deps --disable-pip-version-check + +mkdir -p final_dist +python -m auditwheel repair -w final_dist dist/* + +RAPIDS_PY_WHEEL_NAME="${package_name}_${RAPIDS_PY_CUDA_SUFFIX}" rapids-upload-wheels-to-s3 final_dist diff --git a/ci/build_wheel_pylibraft.sh b/ci/build_wheel_pylibraft.sh new file mode 100755 index 0000000000..f17f038675 --- /dev/null +++ b/ci/build_wheel_pylibraft.sh @@ -0,0 +1,9 @@ +#!/bin/bash +# Copyright (c) 2023, NVIDIA CORPORATION. + +set -euo pipefail + +# Set up skbuild options. Enable sccache in skbuild config options +export SKBUILD_CONFIGURE_OPTIONS="-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + +ci/build_wheel.sh pylibraft python/pylibraft diff --git a/ci/build_wheel_raft_dask.sh b/ci/build_wheel_raft_dask.sh new file mode 100755 index 0000000000..f0204d45c0 --- /dev/null +++ b/ci/build_wheel_raft_dask.sh @@ -0,0 +1,14 @@ +#!/bin/bash +# Copyright (c) 2023, NVIDIA CORPORATION. + +set -euo pipefail + +# Set up skbuild options. Enable sccache in skbuild config options +export SKBUILD_CONFIGURE_OPTIONS="-DRAFT_BUILD_WHEELS=ON -DDETECT_CONDA_ENV=OFF -DFIND_RAFT_CPP=OFF" + +RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" + +RAPIDS_PY_WHEEL_NAME=pylibraft_${RAPIDS_PY_CUDA_SUFFIX} rapids-download-wheels-from-s3 ./local-pylibraft +python -m pip install --no-deps ./local-pylibraft/pylibraft*.whl + +ci/build_wheel.sh raft_dask python/raft-dask diff --git a/ci/release/update-version.sh b/ci/release/update-version.sh index ef935ba518..6a7e319f5d 100755 --- a/ci/release/update-version.sh +++ b/ci/release/update-version.sh @@ -51,6 +51,9 @@ sed_runner "s/__version__ = .*/__version__ = \"${NEXT_FULL_TAG}\"/g" python/raft sed_runner "s/^version = .*/version = \"${NEXT_FULL_TAG}\"/g" python/pylibraft/pyproject.toml sed_runner "s/^version = .*/version = \"${NEXT_FULL_TAG}\"/g" python/raft-dask/pyproject.toml +# Wheel testing script +sed_runner "s/branch-.*/branch-${NEXT_SHORT_TAG}/g" ci/test_wheel_raft_dask.sh + # Docs update sed_runner 's/version = .*/version = '"'${NEXT_SHORT_TAG}'"'/g' docs/source/conf.py sed_runner 's/release = .*/release = '"'${NEXT_FULL_TAG}'"'/g' docs/source/conf.py diff --git a/ci/test_cpp.sh b/ci/test_cpp.sh index e32697a68a..9c487be156 100755 --- a/ci/test_cpp.sh +++ b/ci/test_cpp.sh @@ -36,12 +36,7 @@ trap "EXITCODE=1" ERR set +e # Run libraft gtests from libraft-tests package -rapids-logger "Run gtests" -for gt in "$CONDA_PREFIX"/bin/gtests/libraft/* ; do - test_name=$(basename ${gt}) - echo "Running gtest $test_name" - ${gt} --gtest_output=xml:${RAPIDS_TESTS_DIR} -done +ctest -j8 --output-on-failure rapids-logger "Test script exiting with value: $EXITCODE" exit ${EXITCODE} diff --git a/ci/test_wheel_pylibraft.sh b/ci/test_wheel_pylibraft.sh new file mode 100755 index 0000000000..d990a0e6c2 --- /dev/null +++ b/ci/test_wheel_pylibraft.sh @@ -0,0 +1,18 @@ +#!/bin/bash +# Copyright (c) 2023, NVIDIA CORPORATION. + +set -euo pipefail + +mkdir -p ./dist +RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" +RAPIDS_PY_WHEEL_NAME="pylibraft_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 ./dist + +# echo to expand wildcard before adding `[extra]` requires for pip +python -m pip install $(echo ./dist/pylibraft*.whl)[test] + +# Run smoke tests for aarch64 pull requests +if [[ "$(arch)" == "aarch64" && "${RAPIDS_BUILD_TYPE}" == "pull-request" ]]; then + python ./ci/wheel_smoke_test_pylibraft.py +else + python -m pytest ./python/pylibraft/pylibraft/test +fi diff --git a/ci/test_wheel_raft_dask.sh b/ci/test_wheel_raft_dask.sh new file mode 100755 index 0000000000..676d642de9 --- /dev/null +++ b/ci/test_wheel_raft_dask.sh @@ -0,0 +1,25 @@ +#!/bin/bash +# Copyright (c) 2023, NVIDIA CORPORATION. + +set -euo pipefail + +mkdir -p ./dist +RAPIDS_PY_CUDA_SUFFIX="$(rapids-wheel-ctk-name-gen ${RAPIDS_CUDA_VERSION})" +RAPIDS_PY_WHEEL_NAME="raft_dask_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 ./dist + +# Download the pylibraft built in the previous step +RAPIDS_PY_WHEEL_NAME="pylibraft_${RAPIDS_PY_CUDA_SUFFIX}" rapids-download-wheels-from-s3 ./local-pylibraft-dep +python -m pip install --no-deps ./local-pylibraft-dep/pylibraft*.whl + +# Always install latest dask for testing +python -m pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.10 + +# echo to expand wildcard before adding `[extra]` requires for pip +python -m pip install $(echo ./dist/raft_dask*.whl)[test] + +# Run smoke tests for aarch64 pull requests +if [[ "$(arch)" == "aarch64" && "${RAPIDS_BUILD_TYPE}" == "pull-request" ]]; then + python ./ci/wheel_smoke_test_raft_dask.py +else + python -m pytest ./python/raft-dask/raft_dask/test +fi diff --git a/cpp/bench/ann/src/common/conf.cpp b/cpp/bench/ann/src/common/conf.cpp index f690f68783..d180f37973 100644 --- a/cpp/bench/ann/src/common/conf.cpp +++ b/cpp/bench/ann/src/common/conf.cpp @@ -78,7 +78,7 @@ void Configuration::parse_dataset_(const nlohmann::json& conf) } else if (!filename.compare(filename.size() - 5, 5, "i8bin")) { dataset_conf_.dtype = "int8"; } else { - log_error("Could not determine data type of the dataset"); + log_error("Could not determine data type of the dataset %s", filename.c_str()); } } } diff --git a/cpp/bench/ann/src/common/dataset.h b/cpp/bench/ann/src/common/dataset.h index 46dd66d649..ae05cd02a1 100644 --- a/cpp/bench/ann/src/common/dataset.h +++ b/cpp/bench/ann/src/common/dataset.h @@ -14,21 +14,27 @@ * limitations under the License. */ #pragma once + +#include + +#ifndef CPU_ONLY #include +#include +#else +typedef uint16_t half; +#endif + #include #include #include #include -#include #include #include #include #include #include -#include - namespace raft::bench::ann { // http://big-ann-benchmarks.com/index.html: @@ -46,13 +52,17 @@ class BinFile { const std::string& mode, uint32_t subset_first_row = 0, uint32_t subset_size = 0); - ~BinFile() { fclose(fp_); } + ~BinFile() + { + if (fp_) { fclose(fp_); } + } BinFile(const BinFile&) = delete; BinFile& operator=(const BinFile&) = delete; - void get_shape(size_t* nrows, int* ndims) + void get_shape(size_t* nrows, int* ndims) const { assert(read_mode_); + if (!fp_) { open_file_(); } *nrows = nrows_; *ndims = ndims_; } @@ -60,6 +70,7 @@ class BinFile { void read(T* data) const { assert(read_mode_); + if (!fp_) { open_file_(); } size_t total = static_cast(nrows_) * ndims_; if (fread(data, sizeof(T), total, fp_) != total) { throw std::runtime_error("fread() BinFile " + file_ + " failed"); @@ -69,6 +80,7 @@ class BinFile { void write(const T* data, uint32_t nrows, uint32_t ndims) { assert(!read_mode_); + if (!fp_) { open_file_(); } if (fwrite(&nrows, sizeof(uint32_t), 1, fp_) != 1) { throw std::runtime_error("fwrite() BinFile " + file_ + " failed"); } @@ -82,34 +94,41 @@ class BinFile { } } - void* map() const + T* map() const { assert(read_mode_); - int fid = fileno(fp_); - auto mmap_ptr = mmap(NULL, file_size_, PROT_READ, MAP_PRIVATE, fid, 0); - if (mmap_ptr == MAP_FAILED) { + if (!fp_) { open_file_(); } + int fid = fileno(fp_); + mapped_ptr_ = mmap(nullptr, file_size_, PROT_READ, MAP_PRIVATE, fid, 0); + if (mapped_ptr_ == MAP_FAILED) { throw std::runtime_error("mmap error: Value of errno " + std::to_string(errno) + ", " + std::string(strerror(errno))); } - return mmap_ptr; + return reinterpret_cast(reinterpret_cast(mapped_ptr_) + 2 * sizeof(uint32_t) + + subset_first_row_ * ndims_ * sizeof(T)); } - void unmap(void* data) const + void unmap() const { - if (munmap(data, file_size_) == -1) { + if (munmap(mapped_ptr_, file_size_) == -1) { throw std::runtime_error("munmap error: " + std::string(strerror(errno))); } } private: void check_suffix_(); + void open_file_() const; std::string file_; - FILE* fp_; bool read_mode_; - uint32_t nrows_; - uint32_t ndims_; - size_t file_size_; + uint32_t subset_first_row_; + uint32_t subset_size_; + + mutable FILE* fp_; + mutable uint32_t nrows_; + mutable uint32_t ndims_; + mutable size_t file_size_; + mutable void* mapped_ptr_; }; template @@ -117,23 +136,32 @@ BinFile::BinFile(const std::string& file, const std::string& mode, uint32_t subset_first_row, uint32_t subset_size) - : file_(file) + : file_(file), + read_mode_(mode == "r"), + subset_first_row_(subset_first_row), + subset_size_(subset_size), + fp_(nullptr) { check_suffix_(); - if (mode == "r") { - read_mode_ = true; - } else if (mode == "w") { - read_mode_ = false; - if (subset_first_row != 0) { - throw std::runtime_error("subset_first_row should be zero for write mode"); + if (!read_mode_) { + if (mode == "w") { + if (subset_first_row != 0) { + throw std::runtime_error("subset_first_row should be zero for write mode"); + } + if (subset_size != 0) { + throw std::runtime_error("subset_size should be zero for write mode"); + } + } else { + throw std::runtime_error("BinFile's mode must be either 'r' or 'w': " + file_); } - if (subset_size != 0) { throw std::runtime_error("subset_size should be zero for write mode"); } - } else { - throw std::runtime_error("BinFile's mode must be either 'r' or 'w': " + file_); } +} - fp_ = fopen(file_.c_str(), mode.c_str()); +template +void BinFile::open_file_() const +{ + fp_ = fopen(file_.c_str(), read_mode_ ? "r" : "w"); if (!fp_) { throw std::runtime_error("open BinFile failed: " + file_); } if (read_mode_) { @@ -156,24 +184,24 @@ BinFile::BinFile(const std::string& file, std::to_string(file_size_)); } - if (subset_first_row >= nrows_) { - throw std::runtime_error(file_ + ": subset_first_row (" + std::to_string(subset_first_row) + + if (subset_first_row_ >= nrows_) { + throw std::runtime_error(file_ + ": subset_first_row (" + std::to_string(subset_first_row_) + ") >= nrows (" + std::to_string(nrows_) + ")"); } - if (subset_first_row + subset_size > nrows_) { - throw std::runtime_error(file_ + ": subset_first_row (" + std::to_string(subset_first_row) + - ") + subset_size (" + std::to_string(subset_size) + ") > nrows (" + + if (subset_first_row_ + subset_size_ > nrows_) { + throw std::runtime_error(file_ + ": subset_first_row (" + std::to_string(subset_first_row_) + + ") + subset_size (" + std::to_string(subset_size_) + ") > nrows (" + std::to_string(nrows_) + ")"); } - if (subset_first_row) { + if (subset_first_row_) { static_assert(sizeof(long) == 8, "fseek() don't support 64-bit offset"); - if (fseek(fp_, sizeof(T) * subset_first_row * ndims_, SEEK_CUR) == -1) { + if (fseek(fp_, sizeof(T) * subset_first_row_ * ndims_, SEEK_CUR) == -1) { throw std::runtime_error(file_ + ": fseek failed"); } - nrows_ -= subset_first_row; + nrows_ -= subset_first_row_; } - if (subset_size) { nrows_ = subset_size; } + if (subset_size_) { nrows_ = subset_size_; } } } @@ -225,9 +253,9 @@ class Dataset { std::string name() const { return name_; } std::string distance() const { return distance_; } - int dim() const { return dim_; } - size_t base_set_size() const { return base_set_size_; } - size_t query_set_size() const { return query_set_size_; } + virtual int dim() const = 0; + virtual size_t base_set_size() const = 0; + virtual size_t query_set_size() const = 0; // load data lazily, so don't pay the overhead of reading unneeded set // e.g. don't load base set when searching @@ -254,9 +282,6 @@ class Dataset { std::string name_; std::string distance_; - int dim_; - size_t base_set_size_; - size_t query_set_size_; mutable T* base_set_ = nullptr; mutable T* query_set_ = nullptr; @@ -270,31 +295,37 @@ Dataset::~Dataset() { delete[] base_set_; delete[] query_set_; - if (d_base_set_) { RAFT_CUDA_TRY_NO_THROW(cudaFree(d_base_set_)); } - if (d_query_set_) { RAFT_CUDA_TRY_NO_THROW(cudaFree(d_query_set_)); } +#ifndef CPU_ONLY + if (d_base_set_) { cudaFree(d_base_set_); } + if (d_query_set_) { cudaFree(d_query_set_); } +#endif } template const T* Dataset::base_set_on_gpu() const { +#ifndef CPU_ONLY if (!d_base_set_) { base_set(); - RAFT_CUDA_TRY(cudaMalloc((void**)&d_base_set_, base_set_size_ * dim_ * sizeof(T))); + RAFT_CUDA_TRY(cudaMalloc((void**)&d_base_set_, base_set_size() * dim() * sizeof(T))); RAFT_CUDA_TRY(cudaMemcpy( - d_base_set_, base_set_, base_set_size_ * dim_ * sizeof(T), cudaMemcpyHostToDevice)); + d_base_set_, base_set_, base_set_size() * dim() * sizeof(T), cudaMemcpyHostToDevice)); } +#endif return d_base_set_; } template const T* Dataset::query_set_on_gpu() const { +#ifndef CPU_ONLY if (!d_query_set_) { query_set(); - RAFT_CUDA_TRY(cudaMalloc((void**)&d_query_set_, query_set_size_ * dim_ * sizeof(T))); + RAFT_CUDA_TRY(cudaMalloc((void**)&d_query_set_, query_set_size() * dim() * sizeof(T))); RAFT_CUDA_TRY(cudaMemcpy( - d_query_set_, query_set_, query_set_size_ * dim_ * sizeof(T), cudaMemcpyHostToDevice)); + d_query_set_, query_set_, query_set_size() * dim() * sizeof(T), cudaMemcpyHostToDevice)); } +#endif return d_query_set_; } @@ -316,24 +347,24 @@ class BinDataset : public Dataset { const std::string& distance); ~BinDataset() { - if (this->mapped_base_set_) { - base_file_.unmap(reinterpret_cast(this->mapped_base_set_) - subset_offset_); - } + if (this->mapped_base_set_) { base_file_.unmap(); } } + int dim() const override; + size_t base_set_size() const override; + size_t query_set_size() const override; + private: void load_base_set_() const override; void load_query_set_() const override; void map_base_set_() const override; - using Dataset::dim_; - using Dataset::base_set_size_; - using Dataset::query_set_size_; + mutable int dim_ = 0; + mutable size_t base_set_size_ = 0; + mutable size_t query_set_size_ = 0; BinFile base_file_; BinFile query_file_; - - size_t subset_offset_; }; template @@ -345,37 +376,71 @@ BinDataset::BinDataset(const std::string& name, const std::string& distance) : Dataset(name, distance), base_file_(base_file, "r", subset_first_row, subset_size), - query_file_(query_file, "r"), - subset_offset_(2 * sizeof(uint32_t) + subset_first_row * dim_ * sizeof(T)) + query_file_(query_file, "r") +{ +} + +template +int BinDataset::dim() const +{ + if (dim_ > 0) { return dim_; } + if (base_set_size() > 0) { return dim_; } + if (query_set_size() > 0) { return dim_; } + return dim_; +} + +template +size_t BinDataset::query_set_size() const { - base_file_.get_shape(&base_set_size_, &dim_); - int query_dim; - query_file_.get_shape(&query_set_size_, &query_dim); - if (query_dim != dim_) { + if (query_set_size_ > 0) { return query_set_size_; } + int dim; + query_file_.get_shape(&query_set_size_, &dim); + if (query_set_size_ == 0) { throw std::runtime_error("Zero query set size"); } + if (dim == 0) { throw std::runtime_error("Zero query set dim"); } + if (dim_ == 0) { + dim_ = dim; + } else if (dim_ != dim) { throw std::runtime_error("base set dim (" + std::to_string(dim_) + ") != query set dim (" + - std::to_string(query_dim)); + std::to_string(dim)); + } + return query_set_size_; +} + +template +size_t BinDataset::base_set_size() const +{ + if (base_set_size_ > 0) { return base_set_size_; } + int dim; + base_file_.get_shape(&base_set_size_, &dim); + if (base_set_size_ == 0) { throw std::runtime_error("Zero base set size"); } + if (dim == 0) { throw std::runtime_error("Zero base set dim"); } + if (dim_ == 0) { + dim_ = dim; + } else if (dim_ != dim) { + throw std::runtime_error("base set dim (" + std::to_string(dim) + ") != query set dim (" + + std::to_string(dim_)); } + return base_set_size_; } template void BinDataset::load_base_set_() const { - this->base_set_ = new T[base_set_size_ * dim_]; + this->base_set_ = new T[base_set_size() * dim()]; base_file_.read(this->base_set_); } template void BinDataset::load_query_set_() const { - this->query_set_ = new T[query_set_size_ * dim_]; + this->query_set_ = new T[query_set_size() * dim()]; query_file_.read(this->query_set_); } template void BinDataset::map_base_set_() const { - char* original_map_ptr = static_cast(base_file_.map()); - this->mapped_base_set_ = reinterpret_cast(original_map_ptr + subset_offset_); + this->mapped_base_set_ = base_file_.map(); } } // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index 22204c2b61..b43f52eb5c 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -132,13 +132,18 @@ void parse_build_param(const nlohmann::json& conf, param.graph_degree = conf.at("index_dim"); param.intermediate_graph_degree = param.graph_degree * 2; } + if (conf.contains("intermediate_graph_degree")) { + param.intermediate_graph_degree = conf.at("intermediate_graph_degree"); + } } template void parse_search_param(const nlohmann::json& conf, typename raft::bench::ann::RaftCagra::SearchParam& param) { - param.itopk_size = conf.at("itopk"); + if (conf.contains("itopk")) { param.p.itopk_size = conf.at("itopk"); } + if (conf.contains("search_width")) { param.p.num_parents = conf.at("search_width"); } + if (conf.contains("max_iterations")) { param.p.max_iterations = conf.at("max_iterations"); } } #endif diff --git a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h index 399fd6a0a8..e898a13636 100644 --- a/cpp/bench/ann/src/raft/raft_cagra_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_cagra_wrapper.h @@ -47,10 +47,10 @@ class RaftCagra : public ANN { using typename ANN::AnnSearchParam; struct SearchParam : public AnnSearchParam { - unsigned itopk_size; + raft::neighbors::experimental::cagra::search_params p; }; - using BuildParam = raft::neighbors::experimental::cagra::index_params; + using BuildParam = raft::neighbors::cagra::index_params; RaftCagra(Metric metric, int dim, const BuildParam& param); @@ -71,7 +71,7 @@ class RaftCagra : public ANN { AlgoProperty get_property() const override { AlgoProperty property; - property.dataset_memory_type = MemoryType::Device; + property.dataset_memory_type = MemoryType::Host; property.query_memory_type = MemoryType::Device; property.need_dataset_when_search = true; return property; @@ -82,8 +82,8 @@ class RaftCagra : public ANN { private: raft::device_resources handle_; BuildParam index_params_; - raft::neighbors::experimental::cagra::search_params search_params_; - std::optional> index_; + raft::neighbors::cagra::search_params search_params_; + std::optional> index_; int device_; int dimension_; rmm::mr::pool_memory_resource mr_; @@ -104,28 +104,36 @@ RaftCagra::RaftCagra(Metric metric, int dim, const BuildParam& param) template void RaftCagra::build(const T* dataset, size_t nrow, cudaStream_t) { - auto dataset_view = raft::make_device_matrix_view(dataset, IdxT(nrow), dimension_); - index_.emplace(raft::neighbors::experimental::cagra::build(handle_, index_params_, dataset_view)); + if (get_property().dataset_memory_type == MemoryType::Host) { + auto dataset_view = raft::make_host_matrix_view(dataset, IdxT(nrow), dimension_); + index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view)); + } else { + auto dataset_view = + raft::make_device_matrix_view(dataset, IdxT(nrow), dimension_); + index_.emplace(raft::neighbors::cagra::build(handle_, index_params_, dataset_view)); + } return; } template void RaftCagra::set_search_param(const AnnSearchParam& param) { + auto search_param = dynamic_cast(param); + search_params_ = search_param.p; return; } template void RaftCagra::save(const std::string& file) const { - raft::neighbors::experimental::cagra::serialize(handle_, file, *index_); + raft::neighbors::cagra::serialize(handle_, file, *index_); return; } template void RaftCagra::load(const std::string& file) { - index_ = raft::neighbors::experimental::cagra::deserialize(handle_, file); + index_ = raft::neighbors::cagra::deserialize(handle_, file); return; } @@ -146,11 +154,8 @@ void RaftCagra::search( auto neighbors_view = raft::make_device_matrix_view(neighbors_IdxT, batch_size, k); auto distances_view = raft::make_device_matrix_view(distances, batch_size, k); - raft::neighbors::experimental::cagra::search_params search_params; - search_params.max_queries = batch_size; - search_params.itopk_size = search_params_.max_queries; - raft::neighbors::experimental::cagra::search( - handle_, search_params, *index_, queries_view, neighbors_view, distances_view); + raft::neighbors::cagra::search( + handle_, search_params_, *index_, queries_view, neighbors_view, distances_view); if (!std::is_same::value) { raft::linalg::unaryOp(neighbors, diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index 5c7d1d1eae..7d791e6d29 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -73,7 +73,7 @@ class RaftIvfPQ : public ANN { AlgoProperty property; property.dataset_memory_type = MemoryType::Host; property.query_memory_type = MemoryType::Device; - property.need_dataset_when_search = true; // actually it is only used during refinement + property.need_dataset_when_search = refine_ratio_ > 1.0; return property; } void save(const std::string& file) const override; diff --git a/cpp/bench/prims/neighbors/cagra_bench.cuh b/cpp/bench/prims/neighbors/cagra_bench.cuh index c361dc82dc..dbc6645f92 100644 --- a/cpp/bench/prims/neighbors/cagra_bench.cuh +++ b/cpp/bench/prims/neighbors/cagra_bench.cuh @@ -74,13 +74,13 @@ struct CagraBench : public fixture { auto metric = raft::distance::DistanceType::L2Expanded; - index_.emplace(raft::neighbors::experimental::cagra::index( + index_.emplace(raft::neighbors::cagra::index( handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view()))); } void run_benchmark(::benchmark::State& state) override { - raft::neighbors::experimental::cagra::search_params search_params; + raft::neighbors::cagra::search_params search_params; search_params.max_queries = 1024; search_params.itopk_size = params_.itopk_size; search_params.team_size = 0; @@ -96,7 +96,7 @@ struct CagraBench : public fixture { auto queries_v = make_const_mdspan(queries_.view()); loop_on_state(state, [&]() { - raft::neighbors::experimental::cagra::search( + raft::neighbors::cagra::search( this->handle, search_params, *this->index_, queries_v, ind_v, dist_v); }); @@ -124,7 +124,7 @@ struct CagraBench : public fixture { private: const params params_; - std::optional> index_; + std::optional> index_; raft::device_matrix queries_; raft::device_matrix dataset_; raft::device_matrix knn_graph_; diff --git a/cpp/include/raft/cluster/detail/mst.cuh b/cpp/include/raft/cluster/detail/mst.cuh index c4dd74f255..a962d4b7c6 100644 --- a/cpp/include/raft/cluster/detail/mst.cuh +++ b/cpp/include/raft/cluster/detail/mst.cuh @@ -20,7 +20,7 @@ #include #include -#include +#include #include #include #include @@ -81,8 +81,20 @@ void connect_knn_graph( raft::sparse::COO connected_edges(stream); - raft::sparse::neighbors::connect_components( - handle, connected_edges, X, color, m, n, reduction_op); + // default row and column batch sizes are chosen for computing cross component nearest neighbors. + // Reference: PR #1445 + static constexpr size_t default_row_batch_size = 4096; + static constexpr size_t default_col_batch_size = 16; + + raft::sparse::neighbors::cross_component_nn(handle, + connected_edges, + X, + color, + m, + n, + reduction_op, + min(m, default_row_batch_size), + min(n, default_col_batch_size)); rmm::device_uvector indptr2(m + 1, stream); raft::sparse::convert::sorted_coo_to_csr( @@ -192,4 +204,4 @@ void build_sorted_mst( raft::copy_async(mst_weight, mst_coo.weights.data(), mst_coo.n_edges, stream); } -}; // namespace raft::cluster::detail +}; // namespace raft::cluster::detail \ No newline at end of file diff --git a/cpp/include/raft/cluster/detail/single_linkage.cuh b/cpp/include/raft/cluster/detail/single_linkage.cuh index ddd422a89b..848ca0357e 100644 --- a/cpp/include/raft/cluster/detail/single_linkage.cuh +++ b/cpp/include/raft/cluster/detail/single_linkage.cuh @@ -81,7 +81,7 @@ void single_linkage(raft::resources const& handle, * 2. Construct MST, sorted by weights */ rmm::device_uvector color(m, stream); - raft::sparse::neighbors::FixConnectivitiesRedOp op(color.data(), m); + raft::sparse::neighbors::FixConnectivitiesRedOp op(m); detail::build_sorted_mst(handle, X, indptr.data(), diff --git a/cpp/include/raft/comms/detail/std_comms.hpp b/cpp/include/raft/comms/detail/std_comms.hpp index 8b92ed48f7..de2a7d3415 100644 --- a/cpp/include/raft/comms/detail/std_comms.hpp +++ b/cpp/include/raft/comms/detail/std_comms.hpp @@ -28,6 +28,8 @@ #include +#include + #include #include @@ -138,50 +140,39 @@ class std_comms : public comms_iface { update_host(h_colors.data(), d_colors.data(), get_size(), stream_); update_host(h_keys.data(), d_keys.data(), get_size(), stream_); - RAFT_CUDA_TRY(cudaStreamSynchronize(stream_)); - - std::vector subcomm_ranks{}; - std::vector new_ucx_ptrs{}; + this->sync_stream(stream_); - for (int i = 0; i < get_size(); ++i) { - if (h_colors[i] == color) { - subcomm_ranks.push_back(i); - if (ucp_worker_ != nullptr && subcomms_ucp_) { new_ucx_ptrs.push_back((*ucp_eps_)[i]); } - } - } + ncclComm_t nccl_comm; + // Create a structure to allgather... ncclUniqueId id{}; - if (get_rank() == subcomm_ranks[0]) { // root of the new subcommunicator - RAFT_NCCL_TRY(ncclGetUniqueId(&id)); - std::vector requests(subcomm_ranks.size() - 1); - for (size_t i = 1; i < subcomm_ranks.size(); ++i) { - isend(&id, sizeof(ncclUniqueId), subcomm_ranks[i], color, requests.data() + (i - 1)); - } - waitall(requests.size(), requests.data()); - } else { - request_t request{}; - irecv(&id, sizeof(ncclUniqueId), subcomm_ranks[0], color, &request); - waitall(1, &request); - } - // FIXME: this seems unnecessary, do more testing and remove this - barrier(); + rmm::device_uvector d_nccl_ids(get_size(), stream_); - ncclComm_t nccl_comm; - RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm, subcomm_ranks.size(), id, key)); - - if (ucp_worker_ != nullptr && subcomms_ucp_) { - auto eps_sp = std::make_shared(new_ucx_ptrs.data()); - return std::unique_ptr(new std_comms(nccl_comm, - (ucp_worker_h)ucp_worker_, - eps_sp, - subcomm_ranks.size(), - key, - stream_, - subcomms_ucp_)); - } else { - return std::unique_ptr( - new std_comms(nccl_comm, subcomm_ranks.size(), key, stream_)); - } + if (key == 0) { RAFT_NCCL_TRY(ncclGetUniqueId(&id)); } + + update_device(d_nccl_ids.data() + get_rank(), &id, 1, stream_); + + allgather(d_nccl_ids.data() + get_rank(), + d_nccl_ids.data(), + sizeof(ncclUniqueId), + datatype_t::UINT8, + stream_); + + auto offset = + std::distance(thrust::make_zip_iterator(h_colors.begin(), h_keys.begin()), + std::find_if(thrust::make_zip_iterator(h_colors.begin(), h_keys.begin()), + thrust::make_zip_iterator(h_colors.end(), h_keys.end()), + [color](auto tuple) { return thrust::get<0>(tuple) == color; })); + + auto subcomm_size = std::count(h_colors.begin(), h_colors.end(), color); + + update_host(&id, d_nccl_ids.data() + offset, 1, stream_); + + this->sync_stream(stream_); + + RAFT_NCCL_TRY(ncclCommInitRank(&nccl_comm, subcomm_size, id, key)); + + return std::unique_ptr(new std_comms(nccl_comm, subcomm_size, key, stream_)); } void barrier() const diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh index 7bd30e5bc6..59fcf606c8 100644 --- a/cpp/include/raft/matrix/detail/gather.cuh +++ b/cpp/include/raft/matrix/detail/gather.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include @@ -135,16 +136,6 @@ void gatherImpl(const InputIteratorT in, // stencil value type typedef typename std::iterator_traits::value_type StencilValueT; - // return type of MapTransformOp, must be convertible to IndexT - typedef typename std::result_of::type MapTransformOpReturnT; - static_assert((std::is_convertible::value), - "MapTransformOp's result type must be convertible to signed integer"); - - // return type of UnaryPredicateOp, must be convertible to bool - typedef typename std::result_of::type PredicateOpReturnT; - static_assert((std::is_convertible::value), - "UnaryPredicateOp's result type must be convertible to bool type"); - IndexT len = map_length * D; constexpr int TPB = 128; const int n_sm = raft::getMultiProcessorCount(); @@ -343,6 +334,7 @@ void gather_if(const InputIteratorT in, typedef typename std::iterator_traits::value_type MapValueT; gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream); } + } // namespace detail } // namespace matrix } // namespace raft diff --git a/cpp/include/raft/matrix/detail/gather_inplace.cuh b/cpp/include/raft/matrix/detail/gather_inplace.cuh new file mode 100644 index 0000000000..cc510e068b --- /dev/null +++ b/cpp/include/raft/matrix/detail/gather_inplace.cuh @@ -0,0 +1,116 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include + +namespace raft { +namespace matrix { +namespace detail { + +template +void gatherInplaceImpl(raft::resources const& handle, + raft::device_matrix_view inout, + raft::device_vector_view map, + MapTransformOp transform_op, + IndexT batch_size) +{ + IndexT m = inout.extent(0); + IndexT n = inout.extent(1); + IndexT map_length = map.extent(0); + + // skip in case of 0 length input + if (map_length <= 0 || m <= 0 || n <= 0 || batch_size < 0) return; + + RAFT_EXPECTS(map_length <= m, "Length of map should be <= number of rows for inplace gather"); + + RAFT_EXPECTS(batch_size >= 0, "batch size should be >= 0"); + + // re-assign batch_size for default case + if (batch_size == 0 || batch_size > n) batch_size = n; + + auto exec_policy = resource::get_thrust_policy(handle); + + IndexT n_batches = raft::ceildiv(n, batch_size); + + auto scratch_space = raft::make_device_vector(handle, map_length * batch_size); + + for (IndexT bid = 0; bid < n_batches; bid++) { + IndexT batch_offset = bid * batch_size; + IndexT cols_per_batch = min(batch_size, n - batch_offset); + + auto gather_op = [inout = inout.data_handle(), + map = map.data_handle(), + transform_op, + batch_offset, + map_length, + cols_per_batch = raft::util::FastIntDiv(cols_per_batch), + n] __device__(auto idx) { + IndexT row = idx / cols_per_batch; + IndexT col = idx % cols_per_batch; + MapT map_val = map[row]; + + IndexT i_src = transform_op(map_val); + return inout[i_src * n + batch_offset + col]; + }; + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(scratch_space.data_handle(), map_length * cols_per_batch), + gather_op); + + auto copy_op = [inout = inout.data_handle(), + map = map.data_handle(), + scratch_space = scratch_space.data_handle(), + batch_offset, + map_length, + cols_per_batch = raft::util::FastIntDiv(cols_per_batch), + n] __device__(auto idx) { + IndexT row = idx / cols_per_batch; + IndexT col = idx % cols_per_batch; + inout[row * n + batch_offset + col] = scratch_space[idx]; + return; + }; + auto counting = thrust::make_counting_iterator(0); + thrust::for_each(exec_policy, counting, counting + map_length * cols_per_batch, copy_op); + } +} + +template +void gather(raft::resources const& handle, + raft::device_matrix_view inout, + raft::device_vector_view map, + MapTransformOp transform_op, + IndexT batch_size) +{ + gatherInplaceImpl(handle, inout, map, transform_op, batch_size); +} + +template +void gather(raft::resources const& handle, + raft::device_matrix_view inout, + raft::device_vector_view map, + IndexT batch_size) +{ + gatherInplaceImpl(handle, inout, map, raft::identity_op(), batch_size); +} + +} // namespace detail +} // namespace matrix +} // namespace raft \ No newline at end of file diff --git a/cpp/include/raft/matrix/detail/scatter_inplace.cuh b/cpp/include/raft/matrix/detail/scatter_inplace.cuh new file mode 100644 index 0000000000..3a57c5478b --- /dev/null +++ b/cpp/include/raft/matrix/detail/scatter_inplace.cuh @@ -0,0 +1,127 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +namespace raft { +namespace matrix { +namespace detail { + +/** + * @brief In-place scatter elements in a row-major matrix according to a + * map. The length of the map is equal to the number of rows. The + * map specifies the destination index for each row, i.e. in the + * resulting matrix, row map[i] is assigned to row i. For example, + * the matrix [[1, 2, 3], [4, 5, 6], [7, 8, 9]] with the map [2, 0, 1] will + * be transformed to [[4, 5, 6], [7, 8, 9], [1, 2, 3]]. Batching is done on + * columns and an additional scratch space of shape n_rows * cols_batch_size + * is created. For each batch, chunks of columns from each row are copied + * into the appropriate location in the scratch space and copied back to + * the corresponding locations in the input matrix. + * + * @tparam InputIteratorT + * @tparam MapIteratorT + * @tparam IndexT + * + * @param[inout] handle raft handle + * @param[inout] inout input matrix (n_rows * n_cols) + * @param[inout] map map containing the destination index for each row (n_rows) + * @param[inout] batch_size column batch size + */ + +template +void scatterInplaceImpl( + raft::resources const& handle, + raft::device_matrix_view inout, + raft::device_vector_view map, + IndexT batch_size) +{ + IndexT m = inout.extent(0); + IndexT n = inout.extent(1); + IndexT map_length = map.extent(0); + + // skip in case of 0 length input + if (map_length <= 0 || m <= 0 || n <= 0 || batch_size < 0) return; + + RAFT_EXPECTS(map_length == m, + "Length of map should be equal to number of rows for inplace scatter"); + + RAFT_EXPECTS(batch_size >= 0, "batch size should be >= 0"); + + // re-assign batch_size for default case + if (batch_size == 0 || batch_size > n) batch_size = n; + + auto exec_policy = resource::get_thrust_policy(handle); + + IndexT n_batches = raft::ceildiv(n, batch_size); + + auto scratch_space = raft::make_device_vector(handle, m * batch_size); + + for (IndexT bid = 0; bid < n_batches; bid++) { + IndexT batch_offset = bid * batch_size; + IndexT cols_per_batch = min(batch_size, n - batch_offset); + + auto copy_op = [inout = inout.data_handle(), + map = map.data_handle(), + batch_offset, + cols_per_batch = raft::util::FastIntDiv(cols_per_batch), + n] __device__(auto idx) { + IndexT row = idx / cols_per_batch; + IndexT col = idx % cols_per_batch; + return inout[row * n + batch_offset + col]; + }; + raft::linalg::map_offset( + handle, + raft::make_device_vector_view(scratch_space.data_handle(), m * cols_per_batch), + copy_op); + + auto scatter_op = [inout = inout.data_handle(), + map = map.data_handle(), + scratch_space = scratch_space.data_handle(), + batch_offset, + cols_per_batch = raft::util::FastIntDiv(cols_per_batch), + n] __device__(auto idx) { + IndexT row = idx / cols_per_batch; + IndexT col = idx % cols_per_batch; + IndexT map_val = map[row]; + + inout[map_val * n + batch_offset + col] = scratch_space[idx]; + return; + }; + auto counting = thrust::make_counting_iterator(0); + thrust::for_each(exec_policy, counting, counting + m * cols_per_batch, scatter_op); + } +} + +template +void scatter(raft::resources const& handle, + raft::device_matrix_view inout, + raft::device_vector_view map, + IndexT batch_size) +{ + scatterInplaceImpl(handle, inout, map, batch_size); +} + +} // end namespace detail +} // end namespace matrix +} // end namespace raft \ No newline at end of file diff --git a/cpp/include/raft/matrix/gather.cuh b/cpp/include/raft/matrix/gather.cuh index 89950c2e14..2fbbcfa2bb 100644 --- a/cpp/include/raft/matrix/gather.cuh +++ b/cpp/include/raft/matrix/gather.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include namespace raft::matrix { @@ -289,6 +290,46 @@ void gather_if(const raft::resources& handle, resource::get_cuda_stream(handle)); } +/** + * @brief In-place gather elements in a row-major matrix according to a + * map. The map specifies the new order in which rows of the input matrix are + * rearranged, i.e. for each output row, read the index in the input matrix + * from the map, apply a transformation to this input index if specified, and copy the row. + * map[i]. For example, the matrix [[1, 2, 3], [4, 5, 6], [7, 8, 9]] with the + * map [2, 0, 1] will be transformed to [[7, 8, 9], [1, 2, 3], [4, 5, 6]]. + * Batching is done on columns and an additional scratch space of + * shape n_rows * cols_batch_size is created. For each batch, chunks + * of columns from each row are copied into the appropriate location + * in the scratch space and copied back to the corresponding locations + * in the input matrix. + * + * @tparam matrix_t Matrix element type + * @tparam map_t Integer type of map elements + * @tparam map_xform_t Unary lambda expression or operator type. MapTransformOp's result type must + * be convertible to idx_t. + * @tparam idx_t Integer type used for indexing + * + * @param[in] handle raft handle + * @param[inout] inout input matrix (n_rows * n_cols) + * @param[in] map Pointer to the input sequence of gather locations + * @param[in] col_batch_size (optional) column batch size. Determines the shape of the scratch space + * (map_length, col_batch_size). When set to zero (default), no batching is done and an additional + * scratch space of shape (map_lengthm, n_cols) is created. + * @param[in] transform_op (optional) Transformation to apply to map values + */ +template +void gather(raft::resources const& handle, + raft::device_matrix_view inout, + raft::device_vector_view map, + idx_t col_batch_size = 0, + map_xform_t transform_op = raft::identity_op()) +{ + detail::gather(handle, inout, map, transform_op, col_batch_size); +} + /** @} */ // end of group matrix_gather } // namespace raft::matrix diff --git a/cpp/include/raft/matrix/scatter.cuh b/cpp/include/raft/matrix/scatter.cuh new file mode 100644 index 0000000000..cd2d76a863 --- /dev/null +++ b/cpp/include/raft/matrix/scatter.cuh @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +namespace raft::matrix { +/** + * @brief In-place scatter elements in a row-major matrix according to a + * map. The map specifies the new order in which rows of the input matrix are + * rearranged, i.e. read the destination index from the map, and copy the row. For example, + * the matrix [[1, 2, 3], [4, 5, 6], [7, 8, 9]] with the map [2, 0, 1] will + * be transformed to [[4, 5, 6], [7, 8, 9], [1, 2, 3]]. Batching is done on + * columns and an additional scratch space of shape n_rows * cols_batch_size + * is created. For each batch, chunks of columns from each row are copied + * into the appropriate location in the scratch space and copied back to + * the corresponding locations in the input matrix. + * Note: in-place scatter is not thread safe if the values in the map are not unique. + * Users must ensure that the map indices are unique and in the range [0, n_rows). + * + * @tparam matrix_t Matrix element type + * @tparam idx_t Integer type used for indexing + * + * @param[in] handle raft handle + * @param[inout] inout input matrix (n_rows * n_cols) + * @param[in] map Pointer to the input sequence of scatter locations. The length of the map should + * be equal to the number of rows in the input matrix. Map indices should be unique and in the range + * [0, n_rows). The map represents a complete permutation of indices. + * @param[in] col_batch_size (optional) column batch size. Determines the shape of the scratch space + * (n_rows, col_batch_size). When set to zero (default), no batching is done and an additional + * scratch space of shape (n_rows, n_cols) is created. + */ +template +void scatter(raft::resources const& handle, + raft::device_matrix_view inout, + raft::device_vector_view map, + idx_t col_batch_size = 0) +{ + detail::scatter(handle, inout, map, col_batch_size); +} + +} // namespace raft::matrix \ No newline at end of file diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 5934f6ef69..1fe55715b1 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -27,7 +27,7 @@ #include #include -namespace raft::neighbors::experimental::cagra { +namespace raft::neighbors::cagra { /** * @defgroup cagra CUDA ANN Graph-based nearest neighbor search @@ -91,7 +91,7 @@ void build_knn_graph(raft::resources const& res, auto dataset_internal = mdspan, row_major, accessor>( dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - detail::build_knn_graph( + cagra::detail::build_knn_graph( res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params); } @@ -149,7 +149,7 @@ void sort_knn_graph(raft::resources const& res, auto dataset_internal = mdspan, row_major, d_accessor>( dataset.data_handle(), dataset.extent(0), dataset.extent(1)); - detail::graph::sort_knn_graph(res, dataset_internal, knn_graph_internal); + cagra::detail::graph::sort_knn_graph(res, dataset_internal, knn_graph_internal); } /** @@ -188,7 +188,7 @@ void optimize(raft::resources const& res, knn_graph.extent(0), knn_graph.extent(1)); - detail::graph::optimize(res, knn_graph_internal, new_graph_internal); + cagra::detail::graph::optimize(res, knn_graph_internal, new_graph_internal); } /** @@ -312,9 +312,18 @@ void search(raft::resources const& res, auto distances_internal = raft::make_device_matrix_view( distances.data_handle(), distances.extent(0), distances.extent(1)); - detail::search_main( + cagra::detail::search_main( res, params, idx, queries_internal, neighbors_internal, distances_internal); } /** @} */ // end group cagra +} // namespace raft::neighbors::cagra + +// TODO: Remove deprecated experimental namespace in 23.12 release +namespace raft::neighbors::experimental::cagra { +using raft::neighbors::cagra::build; +using raft::neighbors::cagra::build_knn_graph; +using raft::neighbors::cagra::optimize; +using raft::neighbors::cagra::search; +using raft::neighbors::cagra::sort_knn_graph; } // namespace raft::neighbors::experimental::cagra diff --git a/cpp/include/raft/neighbors/cagra_serialize.cuh b/cpp/include/raft/neighbors/cagra_serialize.cuh index 8d1771a301..2242629409 100644 --- a/cpp/include/raft/neighbors/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_serialize.cuh @@ -18,7 +18,7 @@ #include "detail/cagra/cagra_serialize.cuh" -namespace raft::neighbors::experimental::cagra { +namespace raft::neighbors::cagra { /** * \defgroup cagra_serialize CAGRA Serialize @@ -110,7 +110,7 @@ void serialize(raft::resources const& handle, * @param[in] handle the raft handle * @param[in] is input stream * - * @return raft::neighbors::cagra::index + * @return raft::neighbors::experimental::cagra::index */ template index deserialize(raft::resources const& handle, std::istream& is) @@ -141,7 +141,7 @@ index deserialize(raft::resources const& handle, std::istream& is) * @param[in] handle the raft handle * @param[in] filename the name of the file that stores the index * - * @return raft::neighbors::cagra::index + * @return raft::neighbors::experimental::cagra::index */ template index deserialize(raft::resources const& handle, const std::string& filename) @@ -151,4 +151,11 @@ index deserialize(raft::resources const& handle, const std::string& fil /**@}*/ -} // namespace raft::neighbors::experimental::cagra +} // namespace raft::neighbors::cagra + +// TODO: Remove deprecated experimental namespace in 23.12 release +namespace raft::neighbors::experimental::cagra { +using raft::neighbors::cagra::deserialize; +using raft::neighbors::cagra::serialize; + +} // namespace raft::neighbors::experimental::cagra \ No newline at end of file diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 44375c01f0..16be004993 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -35,7 +35,7 @@ #include #include -namespace raft::neighbors::experimental::cagra { +namespace raft::neighbors::cagra { /** * @ingroup cagra * @{ @@ -347,4 +347,13 @@ struct index : ann::index { /** @} */ +} // namespace raft::neighbors::cagra + +// TODO: Remove deprecated experimental namespace in 23.12 release +namespace raft::neighbors::experimental::cagra { +using raft::neighbors::cagra::hash_mode; +using raft::neighbors::cagra::index; +using raft::neighbors::cagra::index_params; +using raft::neighbors::cagra::search_algo; +using raft::neighbors::cagra::search_params; } // namespace raft::neighbors::experimental::cagra diff --git a/cpp/include/raft/neighbors/detail/cagra/bitonic.hpp b/cpp/include/raft/neighbors/detail/cagra/bitonic.hpp index 45aff99421..9fca7f8ebd 100644 --- a/cpp/include/raft/neighbors/detail/cagra/bitonic.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/bitonic.hpp @@ -18,7 +18,7 @@ #include #include -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace bitonic { namespace detail { @@ -223,4 +223,4 @@ __device__ void warp_sort(K k[N], V v[N], const bool asc = true) } } // namespace bitonic -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 5c196471aa..2a6cedb54c 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -36,7 +36,7 @@ #include #include -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { template void build_knn_graph(raft::resources const& res, @@ -135,6 +135,9 @@ void build_knn_graph(raft::resources const& res, resource::get_cuda_stream(res), device_memory); + size_t next_report_offset = 0; + size_t d_report_offset = dataset.extent(0) / 100; // Report progress in 1% steps. + for (const auto& batch : vec_batches) { auto queries_view = raft::make_device_matrix_view( batch.data(), batch.size(), batch.row_width()); @@ -212,21 +215,26 @@ void build_knn_graph(raft::resources const& res, size_t num_queries_done = batch.offset() + batch.size(); const auto end_clock = std::chrono::system_clock::now(); - const auto time = - std::chrono::duration_cast(end_clock - start_clock).count() * 1e-6; - const auto throughput = num_queries_done / time; - RAFT_LOG_DEBUG( - "# Search %12lu / %12lu (%3.2f %%), %e queries/sec, %.2f minutes ETA, self included = " - "%3.2f %% \r", - num_queries_done, - dataset.extent(0), - num_queries_done / static_cast(dataset.extent(0)) * 100, - throughput, - (num_queries - num_queries_done) / throughput / 60, - static_cast(num_self_included) / num_queries_done * 100.); + if (batch.offset() > next_report_offset) { + next_report_offset += d_report_offset; + const auto time = + std::chrono::duration_cast(end_clock - start_clock).count() * + 1e-6; + const auto throughput = num_queries_done / time; + + RAFT_LOG_INFO( + "# Search %12lu / %12lu (%3.2f %%), %e queries/sec, %.2f minutes ETA, self included = " + "%3.2f %% \r", + num_queries_done, + dataset.extent(0), + num_queries_done / static_cast(dataset.extent(0)) * 100, + throughput, + (num_queries - num_queries_done) / throughput / 60, + static_cast(num_self_included) / num_queries_done * 100.); + } first = false; } if (!first) RAFT_LOG_DEBUG("# Finished building kNN graph"); } -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index 1561a3bb8d..05d8b20ebc 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -30,7 +30,7 @@ #include "search_plan.cuh" #include "search_single_cta.cuh" -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { /** * @brief Search ANN using the constructed index. @@ -133,4 +133,4 @@ void search_main(raft::resources const& res, } /** @} */ // end group cagra -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 7f708506a5..aab4709f5f 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -22,7 +22,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { // Serialization version 1. constexpr int serialization_version = 2; @@ -133,4 +133,4 @@ auto deserialize(raft::resources const& res, const std::string& filename) -> ind return index; } -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp index f67e110fc6..91e0d88e79 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -22,7 +22,7 @@ #include "utils.hpp" #include -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace device { // using LOAD_256BIT_T = ulonglong4; @@ -254,4 +254,4 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in } } // namespace device -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/device_common.hpp b/cpp/include/raft/neighbors/detail/cagra/device_common.hpp index f9c81f3d25..b1a2207a4e 100644 --- a/cpp/include/raft/neighbors/detail/cagra/device_common.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/device_common.hpp @@ -21,7 +21,7 @@ #include #include -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace device { // warpSize for compile time calculation @@ -49,4 +49,4 @@ _RAFT_DEVICE inline T swizzling(T x) } } // namespace device -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/factory.cuh b/cpp/include/raft/neighbors/detail/cagra/factory.cuh index 7d4cfee0b9..625040194b 100644 --- a/cpp/include/raft/neighbors/detail/cagra/factory.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/factory.cuh @@ -21,7 +21,7 @@ #include "search_plan.cuh" #include "search_single_cta.cuh" -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { template class factory { @@ -86,4 +86,4 @@ class factory { } } }; -}; // namespace raft::neighbors::experimental::cagra::detail +}; // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/fragment.hpp b/cpp/include/raft/neighbors/detail/cagra/fragment.hpp index c423ac12c2..e124b3fc8c 100644 --- a/cpp/include/raft/neighbors/detail/cagra/fragment.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/fragment.hpp @@ -20,7 +20,7 @@ #include #include -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace device { namespace detail { @@ -208,4 +208,4 @@ _RAFT_DEVICE void print_fragment(const device::fragment& a) } } // namespace device -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh index d915634df9..a47e719e02 100644 --- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh @@ -36,7 +36,7 @@ #include "utils.hpp" -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace graph { // unnamed namespace to avoid multiple definition error @@ -588,4 +588,4 @@ void optimize(raft::resources const& res, } } // namespace graph -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp b/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp index 5992aaaf1d..346bbeaa9e 100644 --- a/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/hashmap.hpp @@ -23,7 +23,7 @@ // #pragma GCC diagnostic push // #pragma GCC diagnostic ignored // #pragma GCC diagnostic pop -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace hashmap { _RAFT_HOST_DEVICE inline uint32_t get_size(const uint32_t bitlen) { return 1U << bitlen; } @@ -85,4 +85,4 @@ _RAFT_DEVICE inline uint32_t insert(IdxT* const table, const uint32_t bitlen, co } } // namespace hashmap -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index bf6a32eac8..24828a8c9f 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -41,7 +41,7 @@ #include #include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace multi_cta_search { template { }; } // namespace multi_cta_search -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh index 3ccd73d92c..4640091e69 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh @@ -17,7 +17,7 @@ #include // RAFT_EXPLICIT -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace multi_cta_search { #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -93,4 +93,4 @@ instantiate_kernel_selection(32, 512, uint8_t, uint32_t, float); #undef instantiate_kernel_selection } // namespace multi_cta_search -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index 43e3e83f59..7879d61007 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -40,7 +40,7 @@ #include #include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace multi_cta_search { // #define _CLK_BREAKDOWN @@ -517,4 +517,4 @@ void select_and_run( // raft::resources const& res, } } // namespace multi_cta_search -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index 033022aea1..a857d335aa 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -40,7 +40,7 @@ #include #include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace multi_kernel_search { template @@ -738,4 +738,4 @@ struct search : search_plan_impl { }; } // namespace multi_kernel_search -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index cbffd93caf..77e140ca3b 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -26,7 +26,7 @@ #include #include -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { struct search_plan_impl_base : public search_params { int64_t max_dim; @@ -324,4 +324,4 @@ struct search_plan_impl : public search_plan_impl_base { // }; /** @} */ // end group cagra -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index bad2039f8c..9fc97facda 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -42,7 +42,7 @@ #include #include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace single_cta_search { template { }; } // namespace single_cta_search -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh index b0130e45d4..f589fd4637 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh @@ -16,7 +16,7 @@ #pragma once #include // RAFT_EXPLICIT -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace single_cta_search { #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -98,4 +98,4 @@ instantiate_single_cta_select_and_run(32, 512, uint8_t, uint32_t, float); #undef instantiate_single_cta_select_and_run } // namespace single_cta_search -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index ca2166ab8d..df822c0113 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -41,7 +41,7 @@ #include #include // RAFT_CUDA_TRY_NOT_THROW is used TODO(tfeher): consider moving this to cuda_rt_essentials.hpp -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace single_cta_search { // #define _CLK_BREAKDOWN @@ -887,4 +887,4 @@ void select_and_run( // raft::resources const& res, RAFT_CUDA_TRY(cudaPeekAtLastError()); } } // namespace single_cta_search -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh index d151cc8ee7..a1b7f930d3 100644 --- a/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/topk_by_radix.cuh @@ -17,7 +17,7 @@ #include "topk_for_cagra/topk_core.cuh" -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace single_cta_search { template @@ -94,4 +94,4 @@ TOP_FUNC_PARTIAL_SPECIALIZATION(512); TOP_FUNC_PARTIAL_SPECIALIZATION(1024); } // namespace single_cta_search -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h index 2896dba1f3..92b9474047 100644 --- a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h +++ b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk.h @@ -18,7 +18,7 @@ #include #include -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { // size_t _cuann_find_topk_bufferSize(uint32_t topK, @@ -55,4 +55,4 @@ CUDA_DEVICE_HOST_FUNC inline size_t _cuann_aligned(size_t size, size_t unit = 12 if (size % unit) { size += unit - (size % unit); } return size; } -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh index 9faf57c0f5..dd73558f86 100644 --- a/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/topk_for_cagra/topk_core.cuh @@ -21,7 +21,7 @@ #include #include -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { using namespace cub; // @@ -927,4 +927,4 @@ inline void _cuann_find_topk(uint32_t topK, return; } -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/utils.hpp b/cpp/include/raft/neighbors/detail/cagra/utils.hpp index 934e84d4d5..22c7a60647 100644 --- a/cpp/include/raft/neighbors/detail/cagra/utils.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/utils.hpp @@ -22,7 +22,7 @@ #include #include -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace utils { template inline cudaDataType_t get_cuda_data_type(); @@ -150,4 +150,4 @@ struct gen_index_msb_1_mask { }; } // namespace utils -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/sparse/neighbors/connect_components.cuh b/cpp/include/raft/sparse/neighbors/cross_component_nn.cuh similarity index 65% rename from cpp/include/raft/sparse/neighbors/connect_components.cuh rename to cpp/include/raft/sparse/neighbors/cross_component_nn.cuh index fcc6ba349b..c94c6254c3 100644 --- a/cpp/include/raft/sparse/neighbors/connect_components.cuh +++ b/cpp/include/raft/sparse/neighbors/cross_component_nn.cuh @@ -19,7 +19,7 @@ #include #include #include -#include +#include namespace raft::sparse::neighbors { @@ -59,11 +59,20 @@ value_idx get_n_components(value_idx* colors, size_t n_rows, cudaStream_t stream * @param[in] orig_colors array containing component number for each row of X * @param[in] n_rows number of rows in X * @param[in] n_cols number of cols in X - * @param[in] reduction_op - * @param[in] metric + * @param[in] reduction_op reduction operation for computing nearest neighbors. The reduction + * operation must have `gather` and `scatter` functions defined + * @param[in] row_batch_size the batch size for computing nearest neighbors. This parameter controls + * the number of samples for which the nearest neighbors are computed at once. Therefore, it affects + * the memory consumption mainly by reducing the size of the adjacency matrix for masked nearest + * neighbors computation + * @param[in] col_batch_size the input data is sorted and 'unsorted' based on color. An additional + * scratch space buffer of shape (n_rows, col_batch_size) is created for this. Usually, this + * parameter affects the memory consumption more drastically than the row_batch_size with a marginal + * increase in compute time as the col_batch_size is reduced + * @param[in] metric distance metric */ template -void connect_components( +void cross_component_nn( raft::resources const& handle, raft::sparse::COO& out, const value_t* X, @@ -71,9 +80,20 @@ void connect_components( size_t n_rows, size_t n_cols, red_op reduction_op, + size_t row_batch_size = 0, + size_t col_batch_size = 0, raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded) { - detail::connect_components(handle, out, X, orig_colors, n_rows, n_cols, reduction_op, metric); + detail::cross_component_nn(handle, + out, + X, + orig_colors, + n_rows, + n_cols, + reduction_op, + row_batch_size, + col_batch_size, + metric); } }; // end namespace raft::sparse::neighbors \ No newline at end of file diff --git a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh b/cpp/include/raft/sparse/neighbors/detail/cross_component_nn.cuh similarity index 56% rename from cpp/include/raft/sparse/neighbors/detail/connect_components.cuh rename to cpp/include/raft/sparse/neighbors/detail/cross_component_nn.cuh index f089cbea83..3570be2b5c 100644 --- a/cpp/include/raft/sparse/neighbors/detail/connect_components.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/cross_component_nn.cuh @@ -15,25 +15,29 @@ */ #pragma once +#include #include +#include +#include +#include #include #include - #include -#include +#include #include +#include #include +#include +#include #include #include #include #include #include - +#include #include -#include -#include #include #include #include @@ -43,6 +47,9 @@ #include #include +#include +#include + #include #include @@ -50,26 +57,24 @@ namespace raft::sparse::neighbors::detail { /** - * Functor with reduction ops for performing fused 1-nn - * computation and guaranteeing only cross-component - * neighbors are considered. + * Base functor with reduction ops for performing masked 1-nn + * computation. * @tparam value_idx * @tparam value_t */ template struct FixConnectivitiesRedOp { - value_idx* colors; value_idx m; // default constructor for cutlass - DI FixConnectivitiesRedOp() : colors(0), m(0) {} + DI FixConnectivitiesRedOp() : m(0) {} - FixConnectivitiesRedOp(value_idx* colors_, value_idx m_) : colors(colors_), m(m_){}; + FixConnectivitiesRedOp(value_idx m_) : m(m_){}; typedef typename raft::KeyValuePair KVP; DI void operator()(value_idx rit, KVP* out, const KVP& other) const { - if (rit < m && other.value < out->value && colors[rit] != colors[other.key]) { + if (rit < m && other.value < out->value) { out->key = other.key; out->value = other.value; } @@ -77,7 +82,7 @@ struct FixConnectivitiesRedOp { DI KVP operator()(value_idx rit, const KVP& a, const KVP& b) const { - if (rit < m && a.value < b.value && colors[rit] != colors[a.key]) { + if (rit < m && a.value < b.value) { return a; } else return b; @@ -96,6 +101,13 @@ struct FixConnectivitiesRedOp { DI value_t get_value(KVP& out) const { return out.value; } DI value_t get_value(value_t& out) const { return out; } + + /** The gather and scatter ensure that operator() is still consistent after rearranging the data. + * TODO (tarang-jain): refactor cross_component_nn API to separate out the gather and scatter + * functions from the reduction op. Reference: https://github.com/rapidsai/raft/issues/1614 */ + void gather(const raft::resources& handle, value_idx* map) {} + + void scatter(const raft::resources& handle, value_idx* map) {} }; /** @@ -182,6 +194,7 @@ struct LookupColorOp { * the given array of components * @tparam value_idx * @tparam value_t + * @param[in] handle raft handle * @param[out] kvp mapping of closest neighbor vertex and distance for each vertex in the given * array of components * @param[out] nn_colors components of nearest neighbors for each vertex @@ -189,41 +202,141 @@ struct LookupColorOp { * @param[in] X original dense data * @param[in] n_rows number of rows in original dense data * @param[in] n_cols number of columns in original dense data - * @param[in] stream cuda stream for which to order cuda operations + * @param[in] row_batch_size row batch size for computing nearest neighbors + * @param[in] col_batch_size column batch size for sorting and 'unsorting' + * @param[in] reduction_op reduction operation for computing nearest neighbors */ template -void perform_1nn(raft::KeyValuePair* kvp, +void perform_1nn(raft::resources const& handle, + raft::KeyValuePair* kvp, value_idx* nn_colors, value_idx* colors, const value_t* X, size_t n_rows, size_t n_cols, - cudaStream_t stream, + size_t row_batch_size, + size_t col_batch_size, red_op reduction_op) { - rmm::device_uvector workspace(n_rows, stream); - rmm::device_uvector x_norm(n_rows, stream); - - raft::linalg::rowNorm(x_norm.data(), X, n_cols, n_rows, raft::linalg::L2Norm, true, stream); - - raft::distance::fusedL2NN, value_idx>( - kvp, - X, - X, - x_norm.data(), - x_norm.data(), - n_rows, - n_rows, - n_cols, - workspace.data(), - reduction_op, - reduction_op, - true, - true, - stream); + auto stream = resource::get_cuda_stream(handle); + auto exec_policy = resource::get_thrust_policy(handle); + + auto sort_plan = raft::make_device_vector(handle, (value_idx)n_rows); + raft::linalg::map_offset(handle, sort_plan.view(), [] __device__(value_idx idx) { return idx; }); + + thrust::sort_by_key( + resource::get_thrust_policy(handle), colors, colors + n_rows, sort_plan.data_handle()); + + // Modify the reduction operation based on the sort plan. + reduction_op.gather(handle, sort_plan.data_handle()); + + auto X_mutable_view = + raft::make_device_matrix_view(const_cast(X), n_rows, n_cols); + auto sort_plan_const_view = + raft::make_device_vector_view(sort_plan.data_handle(), n_rows); + raft::matrix::gather(handle, X_mutable_view, sort_plan_const_view, (value_idx)col_batch_size); + + // Get the number of unique components from the array of colors + value_idx n_components = get_n_components(colors, n_rows, stream); + + // colors_group_idxs is an array containing the *end* indices of each color + // component in colors. That is, the value of colors_group_idxs[j] indicates + // the start of color j + 1, i.e., it is the inclusive scan of the sizes of + // the color components. + auto colors_group_idxs = raft::make_device_vector(handle, n_components + 1); + raft::sparse::convert::sorted_coo_to_csr( + colors, n_rows, colors_group_idxs.data_handle(), n_components + 1, stream); + + auto group_idxs_view = raft::make_device_vector_view( + colors_group_idxs.data_handle() + 1, n_components); + + auto x_norm = raft::make_device_vector(handle, (value_idx)n_rows); + raft::linalg::rowNorm( + x_norm.data_handle(), X, n_cols, n_rows, raft::linalg::L2Norm, true, stream); + + auto adj = raft::make_device_matrix(handle, row_batch_size, n_components); + using OutT = raft::KeyValuePair; + using ParamT = raft::distance::masked_l2_nn_params; + + bool apply_sqrt = true; + bool init_out_buffer = true; + ParamT params{reduction_op, reduction_op, apply_sqrt, init_out_buffer}; + + auto X_full_view = raft::make_device_matrix_view(X, n_rows, n_cols); + + size_t n_batches = raft::ceildiv(n_rows, row_batch_size); + + for (size_t bid = 0; bid < n_batches; bid++) { + size_t batch_offset = bid * row_batch_size; + size_t rows_per_batch = min(row_batch_size, n_rows - batch_offset); + + auto X_batch_view = raft::make_device_matrix_view( + X + batch_offset * n_cols, rows_per_batch, n_cols); + + auto x_norm_batch_view = raft::make_device_vector_view( + x_norm.data_handle() + batch_offset, rows_per_batch); + + auto mask_op = [colors, + n_components = raft::util::FastIntDiv(n_components), + batch_offset] __device__(value_idx idx) { + value_idx row = idx / n_components; + value_idx col = idx % n_components; + return colors[batch_offset + row] != col; + }; + + auto adj_vector_view = raft::make_device_vector_view( + adj.data_handle(), rows_per_batch * n_components); + + raft::linalg::map_offset(handle, adj_vector_view, mask_op); + + auto adj_view = raft::make_device_matrix_view( + adj.data_handle(), rows_per_batch, n_components); + + auto kvp_view = + raft::make_device_vector_view, value_idx>( + kvp + batch_offset, rows_per_batch); + + raft::distance::masked_l2_nn(handle, + params, + X_batch_view, + X_full_view, + x_norm_batch_view, + x_norm.view(), + adj_view, + group_idxs_view, + kvp_view); + } + + // Transform the keys so that they correctly point to the unpermuted indices. + thrust::transform(exec_policy, + kvp, + kvp + n_rows, + kvp, + [sort_plan = sort_plan.data_handle()] __device__(OutT KVP) { + OutT res; + res.value = KVP.value; + res.key = sort_plan[KVP.key]; + return res; + }); + + // Undo permutation of the rows of X by scattering in place. + raft::matrix::scatter(handle, X_mutable_view, sort_plan_const_view, (value_idx)col_batch_size); + + // Undo permutation of the key-value pair and color vectors. This is not done + // inplace, so using two temporary vectors. + auto tmp_colors = raft::make_device_vector(handle, n_rows); + auto tmp_kvp = raft::make_device_vector(handle, n_rows); + + thrust::scatter(exec_policy, kvp, kvp + n_rows, sort_plan.data_handle(), tmp_kvp.data_handle()); + thrust::scatter( + exec_policy, colors, colors + n_rows, sort_plan.data_handle(), tmp_colors.data_handle()); + reduction_op.scatter(handle, sort_plan.data_handle()); + + raft::copy_async(colors, tmp_colors.data_handle(), n_rows, stream); + raft::copy_async(kvp, tmp_kvp.data_handle(), n_rows, stream); LookupColorOp extract_colors_op(colors); - thrust::transform(rmm::exec_policy(stream), kvp, kvp + n_rows, nn_colors, extract_colors_op); + thrust::transform(exec_policy, kvp, kvp + n_rows, nn_colors, extract_colors_op); } /** @@ -239,22 +352,22 @@ void perform_1nn(raft::KeyValuePair* kvp, * @param stream stream for which to order CUDA operations */ template -void sort_by_color(value_idx* colors, +void sort_by_color(raft::resources const& handle, + value_idx* colors, value_idx* nn_colors, raft::KeyValuePair* kvp, value_idx* src_indices, - size_t n_rows, - cudaStream_t stream) + size_t n_rows) { + auto exec_policy = resource::get_thrust_policy(handle); thrust::counting_iterator arg_sort_iter(0); - thrust::copy(rmm::exec_policy(stream), arg_sort_iter, arg_sort_iter + n_rows, src_indices); + thrust::copy(exec_policy, arg_sort_iter, arg_sort_iter + n_rows, src_indices); auto keys = thrust::make_zip_iterator( thrust::make_tuple(colors, nn_colors, (KeyValuePair*)kvp)); auto vals = thrust::make_zip_iterator(thrust::make_tuple(src_indices)); - // get all the colors in contiguous locations so we can map them to warps. - thrust::sort_by_key(rmm::exec_policy(stream), keys, keys + n_rows, vals, TupleComp()); + thrust::sort_by_key(exec_policy, keys, keys + n_rows, vals, TupleComp()); } template @@ -285,9 +398,7 @@ __global__ void min_components_by_color_kernel(value_idx* out_rows, * @tparam value_idx * @tparam value_t * @param[out] coo output edge list - * @param[in] out_indptr output indptr for ordering edge list - * @param[in] colors_indptr indptr of source components - * @param[in] colors_nn components of nearest neighbors to each source component + * @param[in] out_index output indptr for ordering edge list * @param[in] indices indices of source vertices for each component * @param[in] kvp indices and distances of each destination vertex for each component * @param[in] n_colors number of components @@ -324,12 +435,24 @@ void min_components_by_color(raft::sparse::COO& coo, * @param[out] out output edge list containing nearest cross-component * edges. * @param[in] X original (row-major) dense matrix for which knn graph should be constructed. - * @param[in] colors array containing component number for each row of X + * @param[in] orig_colors array containing component number for each row of X * @param[in] n_rows number of rows in X * @param[in] n_cols number of cols in X + * @param[in] reduction_op reduction operation for computing nearest neighbors. The reduction + * operation must have `gather` and `scatter` functions defined + * @param[in] row_batch_size the batch size for computing nearest neighbors. This parameter controls + * the number of samples for which the nearest neighbors are computed at once. Therefore, it affects + * the memory consumption mainly by reducing the size of the adjacency matrix for masked nearest + * neighbors computation. default 0 indicates that no batching is done + * @param[in] col_batch_size the input data is sorted and 'unsorted' based on color. An additional + * scratch space buffer of shape (n_rows, col_batch_size) is created for this. Usually, this + * parameter affects the memory consumption more drastically than the col_batch_size with a marginal + * increase in compute time as the col_batch_size is reduced. default 0 indicates that no batching + * is done + * @param[in] metric distance metric */ template -void connect_components( +void cross_component_nn( raft::resources const& handle, raft::sparse::COO& out, const value_t* X, @@ -337,6 +460,8 @@ void connect_components( size_t n_rows, size_t n_cols, red_op reduction_op, + size_t row_batch_size, + size_t col_batch_size, raft::distance::DistanceType metric = raft::distance::DistanceType::L2SqrtExpanded) { auto stream = resource::get_cuda_stream(handle); @@ -345,13 +470,16 @@ void connect_components( "Fixing connectivities for an unconnected k-NN graph only " "supports L2SqrtExpanded currently."); + if (row_batch_size == 0 || row_batch_size > n_rows) { row_batch_size = n_rows; } + + if (col_batch_size == 0 || col_batch_size > n_cols) { col_batch_size = n_cols; } + rmm::device_uvector colors(n_rows, stream); - raft::copy_async(colors.data(), orig_colors, n_rows, stream); // Normalize colors so they are drawn from a monotonically increasing set - raft::label::make_monotonic(colors.data(), colors.data(), n_rows, stream, true); - - value_idx n_components = get_n_components(colors.data(), n_rows, stream); + constexpr bool zero_based = true; + raft::label::make_monotonic( + colors.data(), const_cast(orig_colors), n_rows, stream, zero_based); /** * First compute 1-nn for all colors where the color of each data point @@ -361,13 +489,15 @@ void connect_components( rmm::device_uvector> temp_inds_dists(n_rows, stream); rmm::device_uvector src_indices(n_rows, stream); - perform_1nn(temp_inds_dists.data(), + perform_1nn(handle, + temp_inds_dists.data(), nn_colors.data(), colors.data(), X, n_rows, n_cols, - stream, + row_batch_size, + col_batch_size, reduction_op); /** @@ -376,7 +506,7 @@ void connect_components( // max_color + 1 = number of connected components // sort nn_colors by key w/ original colors sort_by_color( - colors.data(), nn_colors.data(), temp_inds_dists.data(), src_indices.data(), n_rows, stream); + handle, colors.data(), nn_colors.data(), temp_inds_dists.data(), src_indices.data(), n_rows); /** * Take the min for any duplicate colors diff --git a/cpp/include/raft/sparse/neighbors/detail/knn_graph.cuh b/cpp/include/raft/sparse/neighbors/detail/knn_graph.cuh index 61378d71d8..00c5317b5c 100644 --- a/cpp/include/raft/sparse/neighbors/detail/knn_graph.cuh +++ b/cpp/include/raft/sparse/neighbors/detail/knn_graph.cuh @@ -126,7 +126,6 @@ void knn_graph(raft::resources const& handle, // pass value_idx through to knn. rmm::device_uvector int64_indices(nnz, stream); - uint32_t knn_start = curTimeMillis(); raft::spatial::knn::brute_force_knn(handle, inputs, sizes, diff --git a/cpp/include/raft/sparse/selection/connect_components.cuh b/cpp/include/raft/sparse/selection/cross_component_nn.cuh similarity index 87% rename from cpp/include/raft/sparse/selection/connect_components.cuh rename to cpp/include/raft/sparse/selection/cross_component_nn.cuh index 9bc3f1553a..e115d6c061 100644 --- a/cpp/include/raft/sparse/selection/connect_components.cuh +++ b/cpp/include/raft/sparse/selection/cross_component_nn.cuh @@ -19,7 +19,7 @@ */ /** - * DISCLAIMER: this file is deprecated: use connect_components.cuh instead + * DISCLAIMER: this file is deprecated: use cross_component_nn.cuh instead */ #pragma once @@ -28,10 +28,10 @@ " is deprecated and will be removed in a future release." \ " Please use the sparse/spatial version instead.") -#include +#include namespace raft::linkage { -using raft::sparse::neighbors::connect_components; +using raft::sparse::neighbors::cross_component_nn; using raft::sparse::neighbors::FixConnectivitiesRedOp; using raft::sparse::neighbors::get_n_components; } // namespace raft::linkage \ No newline at end of file diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index 850b741dfd..1ce041d8da 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -466,7 +466,7 @@ struct batch_load_iterator { if (source_ == nullptr) { return; } if (needs_copy_) { if (size() > 0) { - RAFT_LOG_DEBUG("batch_load_iterator::copy(offset = %zu, size = %zu, row_width = %zu)", + RAFT_LOG_TRACE("batch_load_iterator::copy(offset = %zu, size = %zu, row_width = %zu)", size_t(offset()), size_t(size()), size_t(row_width())); diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py index 170c57c521..32bec82d38 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_00_generate.py @@ -40,7 +40,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \\ template void select_and_run( \\ @@ -73,7 +73,7 @@ trailer = """ #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::namespace multi_cta_search +} // namespace raft::neighbors::cagra::detail::namespace multi_cta_search """ mxdim_team = [(128, 8), (256, 16), (512, 32), (1024, 32)] diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu index 207028dcec..7536c8e9d5 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(32, 1024, float, uint32_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu index 4a5c0f106b..96b4dab650 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim128_t8.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(8, 128, float, uint32_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu index 93a9f41881..410ac55f66 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(16, 256, float, uint32_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu index fb321b2cf7..f80b8603d1 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(32, 512, float, uint32_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu index e73698460d..97f29da4c2 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim1024_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(32, 1024, float, uint64_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu index e51fdcbc62..959e36ed7e 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim128_t8.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(8, 128, float, uint64_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu index caa45b5395..4324df905b 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim256_t16.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(16, 256, float, uint64_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu index 67e54f0937..e1d1f8fa62 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_float_uint64_dim512_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(32, 512, float, uint64_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu index 2e929eb4f0..08fbdbcf5c 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim1024_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(32, 1024, int8_t, uint32_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu index d3e2e78250..e4015dfbd3 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(8, 128, int8_t, uint32_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu index 802edafdf2..22622b380a 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(16, 256, int8_t, uint32_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu index 96e91c475e..8fcd2008c0 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(32, 512, int8_t, uint32_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu index 6db346c67a..feb9b01819 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim1024_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(32, 1024, uint8_t, uint32_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu index 4b1c6c89f4..7fa2d4f1a8 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim128_t8.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(8, 128, uint8_t, uint32_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu index f978a9011a..8f278f9b4b 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim256_t16.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(16, 256, uint8_t, uint32_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu index 390330ec93..cec0753442 100644 --- a/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_multi_cta_uint8_uint32_dim512_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::multi_cta_search { +namespace raft::neighbors::cagra::detail::multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ template void select_and_run( \ @@ -58,4 +58,4 @@ instantiate_kernel_selection(32, 512, uint8_t, uint32_t, float); #undef instantiate_kernel_selection -} // namespace raft::neighbors::experimental::cagra::detail::multi_cta_search +} // namespace raft::neighbors::cagra::detail::multi_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py index b8f623d4c4..ba6ce82485 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_00_generate.py @@ -40,7 +40,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \\ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \\ @@ -75,7 +75,7 @@ trailer = """ #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search """ mxdim_team = [(128, 8), (256, 16), (512, 32), (1024, 32)] diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu index 523f2761fc..7474875bf9 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(32, 1024, float, uint32_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu index cb8b21bfe8..8efd7ade82 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim128_t8.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(8, 128, float, uint32_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu index f5ccfa7572..df88617904 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(16, 256, float, uint32_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu index 1d83979a88..fb9bf8b7d5 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(32, 512, float, uint32_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu index cd588e13ef..da49fa2f4b 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(32, 1024, float, uint64_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu index b47db68273..3c5e595329 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim128_t8.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(8, 128, float, uint64_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu index d875080345..a32d2f4516 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(16, 256, float, uint64_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu index 848e71a645..1efcbcc125 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(32, 512, float, uint64_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu index de7acb56fe..8e3b2ed6f6 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim1024_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(32, 1024, int8_t, uint32_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu index d0e90603e2..ad3db811ec 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(8, 128, int8_t, uint32_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu index 26764c5ad9..845ee65c33 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(16, 256, int8_t, uint32_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu index 6568ab6dba..c0a237140c 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(32, 512, int8_t, uint32_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu index 311f42c9a7..07e678bcb5 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(32, 1024, uint8_t, uint32_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu index 197aa71d7b..33a956e77d 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim128_t8.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(8, 128, uint8_t, uint32_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu index dfb47a1137..cfc4598404 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim256_t16.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(16, 256, uint8_t, uint32_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu index 1b874bcf9b..ee4897ff3f 100644 --- a/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu +++ b/cpp/src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu @@ -26,7 +26,7 @@ #include -namespace raft::neighbors::experimental::cagra::detail::single_cta_search { +namespace raft::neighbors::cagra::detail::single_cta_search { #define instantiate_single_cta_select_and_run( \ TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -60,4 +60,4 @@ instantiate_single_cta_select_and_run(32, 512, uint8_t, uint32_t, float); #undef instantiate_single_cta_search_kernel -} // namespace raft::neighbors::experimental::cagra::detail::single_cta_search +} // namespace raft::neighbors::cagra::detail::single_cta_search diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 77f571f705..efcd48cd1d 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -13,27 +13,38 @@ # ============================================================================= # ################################################################################################## -# * compiler function ----------------------------------------------------------------------------- +# enable testing ################################################################################ +# ################################################################################################## +enable_testing() +include(rapids-test) +rapids_test_init() function(ConfigureTest) set(options OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY) - set(oneValueArgs NAME) + set(oneValueArgs NAME GPUS PERCENT) set(multiValueArgs PATH TARGETS CONFIGURATIONS) - cmake_parse_arguments(ConfigureTest "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - - set(TEST_NAME ${ConfigureTest_NAME}) - - add_executable(${TEST_NAME} ${ConfigureTest_PATH}) + cmake_parse_arguments(_RAFT_TEST "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + if(NOT DEFINED _RAFT_TEST_GPUS AND NOT DEFINED _RAFT_TEST_PERCENT) + set(_RAFT_TEST_GPUS 1) + set(_RAFT_TEST_PERCENT 30) + endif() + if(NOT DEFINED _RAFT_TEST_GPUS) + set(_RAFT_TEST_GPUS 1) + endif() + if(NOT DEFINED _RAFT_TEST_PERCENT) + set(_RAFT_TEST_PERCENT 100) + endif() - message("TEST PATH: ${ConfigureTest_PATH}") + set(TEST_NAME ${_RAFT_TEST_NAME}) + add_executable(${TEST_NAME} ${_RAFT_TEST_PATH}) target_link_libraries( ${TEST_NAME} PRIVATE raft raft_internal - $<$:raft::compiled> + $<$:raft::compiled> GTest::gtest GTest::gtest_main Threads::Threads @@ -41,35 +52,31 @@ function(ConfigureTest) $ $ ) - - add_test(NAME ${TEST_NAME} COMMAND ${TEST_NAME}) - set_target_properties( ${TEST_NAME} - PROPERTIES # set target compile options + PROPERTIES RUNTIME_OUTPUT_DIRECTORY "$" INSTALL_RPATH "\$ORIGIN/../../../lib" CXX_STANDARD 17 CXX_STANDARD_REQUIRED ON CUDA_STANDARD 17 CUDA_STANDARD_REQUIRED ON ) - target_compile_options( ${TEST_NAME} PRIVATE "$<$:${RAFT_CXX_FLAGS}>" "$<$:${RAFT_CUDA_FLAGS}>" ) - - if(ConfigureTest_EXPLICIT_INSTANTIATE_ONLY) + if(_RAFT_TEST_EXPLICIT_INSTANTIATE_ONLY) target_compile_definitions(${TEST_NAME} PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") endif() target_include_directories(${TEST_NAME} PUBLIC "$") - install( - TARGETS ${TEST_NAME} - COMPONENT testing - DESTINATION bin/gtests/libraft - EXCLUDE_FROM_ALL + rapids_test_add( + NAME ${TEST_NAME} + COMMAND ${TEST_NAME} + GPUS ${_RAFT_TEST_GPUS} + PERCENT ${_RAFT_TEST_PERCENT} + INSTALL_COMPONENT_SET testing ) endfunction() @@ -90,7 +97,6 @@ if(BUILD_TESTS) test/cluster/cluster_solvers.cu test/cluster/linkage.cu test/cluster/kmeans_find_k.cu - OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) @@ -117,7 +123,6 @@ if(BUILD_TESTS) test/core/span.cu test/core/temporary_device_buffer.cu test/test.cpp - OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) @@ -147,7 +152,6 @@ if(BUILD_TESTS) test/distance/masked_nn_compress_to_bits.cu test/distance/fused_l2_nn.cu test/distance/gram.cu - OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) @@ -181,12 +185,10 @@ if(BUILD_TESTS) # * EXT_HEADERS_TEST_COMPILED_IMPLICIT: RAFT_COMPILED defined # * EXT_HEADERS_TEST_IMPLICIT: no macros defined. ConfigureTest( - NAME EXT_HEADERS_TEST_COMPILED_EXPLICIT PATH ${EXT_HEADER_TEST_SOURCES} OPTIONAL LIB + NAME EXT_HEADERS_TEST_COMPILED_EXPLICIT PATH ${EXT_HEADER_TEST_SOURCES} LIB EXPLICIT_INSTANTIATE_ONLY ) - ConfigureTest( - NAME EXT_HEADERS_TEST_COMPILED_IMPLICIT PATH ${EXT_HEADER_TEST_SOURCES} OPTIONAL LIB - ) + ConfigureTest(NAME EXT_HEADERS_TEST_COMPILED_IMPLICIT PATH ${EXT_HEADER_TEST_SOURCES} LIB) ConfigureTest(NAME EXT_HEADERS_TEST_IMPLICIT PATH ${EXT_HEADER_TEST_SOURCES}) ConfigureTest(NAME LABEL_TEST PATH test/label/label.cu test/label/merge_labels.cu) @@ -238,21 +240,26 @@ if(BUILD_TESTS) test/matrix/columnSort.cu test/matrix/diagonal.cu test/matrix/gather.cu + test/matrix/scatter.cu test/matrix/eye.cu test/matrix/linewise_op.cu test/matrix/math.cu test/matrix/matrix.cu test/matrix/norm.cu test/matrix/reverse.cu - test/matrix/select_k.cu test/matrix/slice.cu test/matrix/triangular.cu test/sparse/spectral_matrix.cu - OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) + ConfigureTest(NAME MATRIX_SELECT_TEST PATH test/matrix/select_k.cu LIB EXPLICIT_INSTANTIATE_ONLY) + + ConfigureTest( + NAME MATRIX_SELECT_LARGE_TEST PATH test/matrix/select_large_k.cu LIB EXPLICIT_INSTANTIATE_ONLY + ) + ConfigureTest( NAME RANDOM_TEST @@ -270,7 +277,7 @@ if(BUILD_TESTS) ConfigureTest( NAME SOLVERS_TEST PATH test/cluster/cluster_solvers_deprecated.cu test/linalg/eigen_solvers.cu - test/lap/lap.cu test/sparse/mst.cu OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY + test/lap/lap.cu test/sparse/mst.cu LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -296,17 +303,16 @@ if(BUILD_TESTS) ConfigureTest( NAME SPARSE_DIST_TEST PATH test/sparse/dist_coo_spmv.cu test/sparse/distance.cu - test/sparse/gram.cu OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY + test/sparse/gram.cu LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( NAME SPARSE_NEIGHBORS_TEST PATH - test/sparse/neighbors/connect_components.cu + test/sparse/neighbors/cross_component_nn.cu test/sparse/neighbors/brute_force.cu test/sparse/neighbors/knn_graph.cu - OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) @@ -315,6 +321,21 @@ if(BUILD_TESTS) NAME NEIGHBORS_TEST PATH + test/neighbors/knn.cu + test/neighbors/fused_l2_knn.cu + test/neighbors/tiled_knn.cu + test/neighbors/haversine.cu + test/neighbors/ball_cover.cu + test/neighbors/epsilon_neighborhood.cu + test/neighbors/refine.cu + LIB + EXPLICIT_INSTANTIATE_ONLY + ) + + ConfigureTest( + NAME + NEIGHBORS_ANN_CAGRA_TEST + PATH test/neighbors/ann_cagra/test_float_uint32_t.cu test/neighbors/ann_cagra/test_int8_t_uint32_t.cu test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu @@ -327,6 +348,18 @@ if(BUILD_TESTS) src/neighbors/detail/cagra/search_single_cta_float_uint64_dim256_t16.cu src/neighbors/detail/cagra/search_single_cta_float_uint64_dim512_t32.cu src/neighbors/detail/cagra/search_single_cta_float_uint64_dim1024_t32.cu + LIB + EXPLICIT_INSTANTIATE_ONLY + GPUS + 1 + PERCENT + 100 + ) + + ConfigureTest( + NAME + NEIGHBORS_ANN_IVF_TEST + PATH test/neighbors/ann_ivf_flat/test_float_int64_t.cu test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu @@ -335,17 +368,17 @@ if(BUILD_TESTS) test/neighbors/ann_ivf_pq/test_float_int64_t.cu test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu - test/neighbors/knn.cu - test/neighbors/fused_l2_knn.cu - test/neighbors/tiled_knn.cu - test/neighbors/haversine.cu - test/neighbors/ball_cover.cu - test/neighbors/epsilon_neighborhood.cu - test/neighbors/refine.cu - test/neighbors/selection.cu - OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY + GPUS + 1 + PERCENT + 100 + ) + + ConfigureTest( + NAME NEIGHBORS_SELECTION_TEST PATH test/neighbors/selection.cu LIB EXPLICIT_INSTANTIATE_ONLY + GPUS 1 PERCENT 50 ) ConfigureTest( @@ -377,7 +410,6 @@ if(BUILD_TESTS) test/stats/trustworthiness.cu test/stats/weighted_mean.cu test/stats/v_measure.cu - OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) @@ -395,3 +427,8 @@ if(BUILD_TESTS) test/util/reduction.cu ) endif() + +# ################################################################################################## +# Install tests #################################################################################### +# ################################################################################################## +rapids_test_install_relocatable(INSTALL_COMPONENT_SET testing DESTINATION bin/gtests/libraft) diff --git a/cpp/test/cluster/linkage.cu b/cpp/test/cluster/linkage.cu index e660dbef13..52ec2efe8e 100644 --- a/cpp/test/cluster/linkage.cu +++ b/cpp/test/cluster/linkage.cu @@ -14,9 +14,9 @@ * limitations under the License. */ -// XXX: We allow the instantiation of fused_l2_nn here: -// raft::linkage::FixConnectivitiesRedOp red_op(colors.data(), params.n_row); -// raft::linkage::connect_components( +// XXX: We allow the instantiation of masked_l2_nn here: +// raft::linkage::FixConnectivitiesRedOp red_op(params.n_row); +// raft::linkage::cross_component_nn( // handle, out_edges, data.data(), colors.data(), params.n_row, params.n_col, red_op); // // TODO: consider adding this to libraft.so or creating an instance in a diff --git a/cpp/test/distance/gram.cu b/cpp/test/distance/gram.cu index b3640a888a..d5fecd93c6 100644 --- a/cpp/test/distance/gram.cu +++ b/cpp/test/distance/gram.cu @@ -75,9 +75,14 @@ template class GramMatrixTest : public ::testing::TestWithParam { protected: GramMatrixTest() - : params(GetParam()), stream(0), x1(0, stream), x2(0, stream), gram(0, stream), gram_host(0) + : params(GetParam()), + handle(), + x1(0, resource::get_cuda_stream(handle)), + x2(0, resource::get_cuda_stream(handle)), + gram(0, resource::get_cuda_stream(handle)), + gram_host(0) { - RAFT_CUDA_TRY(cudaStreamCreate(&stream)); + auto stream = resource::get_cuda_stream(handle); if (params.ld1 == 0) { params.ld1 = params.is_row_major ? params.n_cols : params.n1; } if (params.ld2 == 0) { params.ld2 = params.is_row_major ? params.n_cols : params.n2; } @@ -99,7 +104,7 @@ class GramMatrixTest : public ::testing::TestWithParam { r.uniform(x2.data(), x2.size(), math_t(0), math_t(1), stream); } - ~GramMatrixTest() override { RAFT_CUDA_TRY_NO_THROW(cudaStreamDestroy(stream)); } + ~GramMatrixTest() override {} void runTest() { @@ -127,6 +132,7 @@ class GramMatrixTest : public ::testing::TestWithParam { (*kernel)(handle, x1_span, x2_span, out_span); + auto stream = resource::get_cuda_stream(handle); naiveGramMatrixKernel(params.n1, params.n2, params.n_cols, @@ -142,16 +148,16 @@ class GramMatrixTest : public ::testing::TestWithParam { handle); ASSERT_TRUE(raft::devArrMatchHost( - gram_host.data(), gram.data(), gram.size(), raft::CompareApprox(1e-6f))); + gram_host.data(), gram.data(), gram.size(), raft::CompareApprox(1e-6f), stream)); } - raft::resources handle; - cudaStream_t stream = 0; GramMatrixInputs params; + raft::resources handle; rmm::device_uvector x1; rmm::device_uvector x2; rmm::device_uvector gram; + std::vector gram_host; }; diff --git a/cpp/test/matrix/gather.cu b/cpp/test/matrix/gather.cu index cab96576d2..b1228f05ca 100644 --- a/cpp/test/matrix/gather.cu +++ b/cpp/test/matrix/gather.cu @@ -72,10 +72,16 @@ struct GatherInputs { IdxT nrows; IdxT ncols; IdxT map_length; + IdxT col_batch_size; unsigned long long int seed; }; -template +template class GatherTest : public ::testing::TestWithParam> { protected: GatherTest() @@ -97,6 +103,8 @@ class GatherTest : public ::testing::TestWithParam> { IdxT map_length = params.map_length; IdxT len = params.nrows * params.ncols; + if (map_length > params.nrows) map_length = params.nrows; + // input matrix setup d_in.resize(params.nrows * params.ncols, stream); h_in.resize(params.nrows * params.ncols); @@ -143,6 +151,8 @@ class GatherTest : public ::testing::TestWithParam> { auto in_view = raft::make_device_matrix_view( d_in.data(), params.nrows, params.ncols); + auto inout_view = raft::make_device_matrix_view( + d_in.data(), params.nrows, params.ncols); auto out_view = raft::make_device_matrix_view( d_out_act.data(), map_length, params.ncols); auto map_view = raft::make_device_vector_view(d_map.data(), map_length); @@ -154,12 +164,23 @@ class GatherTest : public ::testing::TestWithParam> { handle, in_view, out_view, map_view, stencil_view, pred_op, transform_op); } else if (Conditional) { raft::matrix::gather_if(handle, in_view, out_view, map_view, stencil_view, pred_op); + } else if (MapTransform && Inplace) { + raft::matrix::gather(handle, inout_view, map_view, params.col_batch_size, transform_op); } else if (MapTransform) { raft::matrix::gather(handle, in_view, map_view, out_view, transform_op); + } else if (Inplace) { + raft::matrix::gather(handle, inout_view, map_view, params.col_batch_size); } else { raft::matrix::gather(handle, in_view, map_view, out_view); } + if (Inplace) { + raft::copy_async(d_out_act.data(), + d_in.data(), + map_length * params.ncols, + raft::resource::get_cuda_stream(handle)); + } + resource::sync_stream(handle, stream); } @@ -173,39 +194,53 @@ class GatherTest : public ::testing::TestWithParam> { rmm::device_uvector d_map; }; -#define GATHER_TEST(test_type, test_name, test_inputs) \ - typedef RAFT_DEPAREN(test_type) test_name; \ - TEST_P(test_name, Result) \ - { \ - ASSERT_TRUE(devArrMatch(d_out_exp.data(), \ - d_out_act.data(), \ - params.map_length* params.ncols, \ - raft::Compare())); \ - } \ +#define GATHER_TEST(test_type, test_name, test_inputs) \ + typedef RAFT_DEPAREN(test_type) test_name; \ + TEST_P(test_name, Result) \ + { \ + ASSERT_TRUE( \ + devArrMatch(d_out_exp.data(), d_out_act.data(), d_out_exp.size(), raft::Compare())); \ + } \ INSTANTIATE_TEST_CASE_P(GatherTests, test_name, ::testing::ValuesIn(test_inputs)) -const std::vector> inputs_i32 = - raft::util::itertools::product>({25, 2000}, {6, 31, 129}, {11, 999}, {1234ULL}); +const std::vector> inputs_i32 = raft::util::itertools::product>( + {25, 2000}, {6, 31, 129}, {11, 999}, {2, 3, 6}, {1234ULL}); const std::vector> inputs_i64 = raft::util::itertools::product>( - {25, 2000}, {6, 31, 129}, {11, 999}, {1234ULL}); + {25, 2000}, {6, 31, 129}, {11, 999}, {2, 3, 6}, {1234ULL}); +const std::vector> inplace_inputs_i32 = + raft::util::itertools::product>( + {25, 2000}, {6, 31, 129}, {11, 999}, {0, 1, 2, 3, 6, 100}, {1234ULL}); +const std::vector> inplace_inputs_i64 = + raft::util::itertools::product>( + {25, 2000}, {6, 31, 129}, {11, 999}, {0, 1, 2, 3, 6, 100}, {1234ULL}); -GATHER_TEST((GatherTest), GatherTestFU32I32, inputs_i32); -GATHER_TEST((GatherTest), +GATHER_TEST((GatherTest), GatherTestFU32I32, inputs_i32); +GATHER_TEST((GatherTest), GatherTransformTestFU32I32, inputs_i32); -GATHER_TEST((GatherTest), GatherIfTestFU32I32, inputs_i32); -GATHER_TEST((GatherTest), +GATHER_TEST((GatherTest), + GatherIfTestFU32I32, + inputs_i32); +GATHER_TEST((GatherTest), GatherIfTransformTestFU32I32, inputs_i32); -GATHER_TEST((GatherTest), +GATHER_TEST((GatherTest), GatherIfTransformTestDU32I32, inputs_i32); -GATHER_TEST((GatherTest), +GATHER_TEST((GatherTest), GatherIfTransformTestFU32I64, inputs_i64); -GATHER_TEST((GatherTest), +GATHER_TEST((GatherTest), GatherIfTransformTestFI64I64, inputs_i64); - +GATHER_TEST((GatherTest), + GatherInplaceTestFU32I32, + inplace_inputs_i32); +GATHER_TEST((GatherTest), + GatherInplaceTestFU32I64, + inplace_inputs_i64); +GATHER_TEST((GatherTest), + GatherInplaceTestFI64I64, + inplace_inputs_i64); } // end namespace raft \ No newline at end of file diff --git a/cpp/test/matrix/scatter.cu b/cpp/test/matrix/scatter.cu new file mode 100644 index 0000000000..3a1a40086e --- /dev/null +++ b/cpp/test/matrix/scatter.cu @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +namespace raft { + +template +void naiveScatter( + InputIteratorT in, IdxT D, IdxT N, MapIteratorT map, IdxT map_length, OutputIteratorT out) +{ + for (IdxT outRow = 0; outRow < map_length; ++outRow) { + typename std::iterator_traits::value_type map_val = map[outRow]; + IdxT outRowStart = map_val * D; + IdxT inRowStart = outRow * D; + for (IdxT i = 0; i < D; ++i) { + out[outRowStart + i] = in[inRowStart + i]; + } + } +} + +template +struct ScatterInputs { + IdxT nrows; + IdxT ncols; + IdxT col_batch_size; + unsigned long long int seed; +}; + +template +class ScatterTest : public ::testing::TestWithParam> { + protected: + ScatterTest() + : stream(resource::get_cuda_stream(handle)), + params(::testing::TestWithParam>::GetParam()), + d_in(0, stream), + d_out_exp(0, stream), + d_map(0, stream) + { + } + + void SetUp() override + { + raft::random::RngState r(params.seed); + raft::random::RngState r_int(params.seed); + + IdxT len = params.nrows * params.ncols; + + // input matrix setup + d_in.resize(params.nrows * params.ncols, stream); + h_in.resize(params.nrows * params.ncols); + raft::random::uniform(handle, r, d_in.data(), len, MatrixT(-1.0), MatrixT(1.0)); + raft::update_host(h_in.data(), d_in.data(), len, stream); + + // map setup + d_map.resize(params.nrows, stream); + h_map.resize(params.nrows); + + auto exec_policy = raft::resource::get_thrust_policy(handle); + + thrust::counting_iterator permute_iter(0); + thrust::copy(exec_policy, permute_iter, permute_iter + params.nrows, d_map.data()); + + thrust::default_random_engine g; + thrust::shuffle(exec_policy, d_map.data(), d_map.data() + params.nrows, g); + + raft::update_host(h_map.data(), d_map.data(), params.nrows, stream); + resource::sync_stream(handle, stream); + + // expected and actual output matrix setup + h_out.resize(params.nrows * params.ncols); + d_out_exp.resize(params.nrows * params.ncols, stream); + + // launch scatter on the host and copy the results to device + naiveScatter(h_in.data(), params.ncols, params.nrows, h_map.data(), params.nrows, h_out.data()); + raft::update_device(d_out_exp.data(), h_out.data(), params.nrows * params.ncols, stream); + + auto inout_view = raft::make_device_matrix_view( + d_in.data(), params.nrows, params.ncols); + auto map_view = raft::make_device_vector_view(d_map.data(), params.nrows); + + raft::matrix::scatter(handle, inout_view, map_view, params.col_batch_size); + resource::sync_stream(handle, stream); + } + + protected: + raft::resources handle; + cudaStream_t stream = 0; + ScatterInputs params; + std::vector h_in, h_out; + std::vector h_map; + rmm::device_uvector d_in, d_out_exp; + rmm::device_uvector d_map; +}; + +#define SCATTER_TEST(test_type, test_name, test_inputs) \ + typedef RAFT_DEPAREN(test_type) test_name; \ + TEST_P(test_name, Result) \ + { \ + ASSERT_TRUE( \ + devArrMatch(d_in.data(), d_out_exp.data(), d_out_exp.size(), raft::Compare())); \ + } \ + INSTANTIATE_TEST_CASE_P(ScatterTests, test_name, ::testing::ValuesIn(test_inputs)) + +const std::vector> inputs_i32 = + raft::util::itertools::product>( + {25, 2000}, {6, 31, 129}, {0, 1, 2, 3, 6, 100}, {1234ULL}); +const std::vector> inputs_i64 = + raft::util::itertools::product>( + {25, 2000}, {6, 31, 129}, {0, 1, 2, 3, 6, 100}, {1234ULL}); + +SCATTER_TEST((ScatterTest), ScatterTestFI32, inputs_i32); +SCATTER_TEST((ScatterTest), ScatterTestFI64, inputs_i64); +} // end namespace raft \ No newline at end of file diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index 487b6d0bfd..63f020b420 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -13,357 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#include "../test_utils.cuh" -#include - -#include - -#include -#include -#include -#include - -#include - -#include -#include - -#include -#include +#include "select_k.cuh" namespace raft::matrix { -template -auto gen_simple_ids(uint32_t batch_size, uint32_t len) -> std::vector -{ - std::vector out(batch_size * len); - auto s = rmm::cuda_stream_default; - rmm::device_uvector out_d(out.size(), s); - sparse::iota_fill(out_d.data(), IdxT(batch_size), IdxT(len), s); - update_host(out.data(), out_d.data(), out.size(), s); - s.synchronize(); - return out; -} - -template -struct io_simple { - public: - bool not_supported = false; - - io_simple(const select::params& spec, - const std::vector& in_dists, - const std::vector& out_dists, - const std::vector& out_ids) - : in_dists_(in_dists), - in_ids_(gen_simple_ids(spec.batch_size, spec.len)), - out_dists_(out_dists), - out_ids_(out_ids) - { - } - - auto get_in_dists() -> std::vector& { return in_dists_; } - auto get_in_ids() -> std::vector& { return in_ids_; } - auto get_out_dists() -> std::vector& { return out_dists_; } - auto get_out_ids() -> std::vector& { return out_ids_; } - - private: - std::vector in_dists_; - std::vector in_ids_; - std::vector out_dists_; - std::vector out_ids_; -}; - -template -struct io_computed { - public: - bool not_supported = false; - - io_computed(const select::params& spec, - const select::Algo& algo, - const std::vector& in_dists, - const std::optional>& in_ids = std::nullopt) - : in_dists_(in_dists), - in_ids_(in_ids.value_or(gen_simple_ids(spec.batch_size, spec.len))), - out_dists_(spec.batch_size * spec.k), - out_ids_(spec.batch_size * spec.k) - { - // check if the size is supported by the algorithm - switch (algo) { - case select::Algo::kWarpAuto: - case select::Algo::kWarpImmediate: - case select::Algo::kWarpFiltered: - case select::Algo::kWarpDistributed: - case select::Algo::kWarpDistributedShm: { - if (spec.k > raft::matrix::detail::select::warpsort::kMaxCapacity) { - not_supported = true; - return; - } - } break; - default: break; - } - - resources handle{}; - auto stream = resource::get_cuda_stream(handle); - - rmm::device_uvector in_dists_d(in_dists_.size(), stream); - rmm::device_uvector in_ids_d(in_ids_.size(), stream); - rmm::device_uvector out_dists_d(out_dists_.size(), stream); - rmm::device_uvector out_ids_d(out_ids_.size(), stream); - - update_device(in_dists_d.data(), in_dists_.data(), in_dists_.size(), stream); - update_device(in_ids_d.data(), in_ids_.data(), in_ids_.size(), stream); - - select::select_k_impl(handle, - algo, - in_dists_d.data(), - spec.use_index_input ? in_ids_d.data() : nullptr, - spec.batch_size, - spec.len, - spec.k, - out_dists_d.data(), - out_ids_d.data(), - spec.select_min); - - update_host(out_dists_.data(), out_dists_d.data(), out_dists_.size(), stream); - update_host(out_ids_.data(), out_ids_d.data(), out_ids_.size(), stream); - - interruptible::synchronize(stream); - - auto p = topk_sort_permutation(out_dists_, out_ids_, spec.k, spec.select_min); - apply_permutation(out_dists_, p); - apply_permutation(out_ids_, p); - } - - auto get_in_dists() -> std::vector& { return in_dists_; } - auto get_in_ids() -> std::vector& { return in_ids_; } - auto get_out_dists() -> std::vector& { return out_dists_; } - auto get_out_ids() -> std::vector& { return out_ids_; } - - private: - std::vector in_dists_; - std::vector in_ids_; - std::vector out_dists_; - std::vector out_ids_; - - auto topk_sort_permutation(const std::vector& vec, - const std::vector& inds, - uint32_t k, - bool select_min) -> std::vector - { - std::vector p(vec.size()); - std::iota(p.begin(), p.end(), 0); - if (select_min) { - std::sort(p.begin(), p.end(), [&vec, &inds, k](IdxT i, IdxT j) { - const IdxT ik = i / k; - const IdxT jk = j / k; - if (ik == jk) { - if (vec[i] == vec[j]) { return inds[i] < inds[j]; } - return vec[i] < vec[j]; - } - return ik < jk; - }); - } else { - std::sort(p.begin(), p.end(), [&vec, &inds, k](IdxT i, IdxT j) { - const IdxT ik = i / k; - const IdxT jk = j / k; - if (ik == jk) { - if (vec[i] == vec[j]) { return inds[i] < inds[j]; } - return vec[i] > vec[j]; - } - return ik < jk; - }); - } - return p; - } - - template - void apply_permutation(std::vector& vec, const std::vector& p) // NOLINT - { - for (auto i = IdxT(vec.size()) - 1; i > 0; i--) { - auto j = p[i]; - while (j > i) - j = p[j]; - std::swap(vec[j], vec[i]); - } - } -}; - -template - -using Params = std::tuple; - -template typename ParamsReader> -struct SelectK // NOLINT - : public testing::TestWithParam::params_t> { - const select::params spec; - const select::Algo algo; - typename ParamsReader::io_t ref; - io_computed res; - - explicit SelectK(Params::io_t> ps) - : spec(std::get<0>(ps)), - algo(std::get<1>(ps)), // NOLINT - ref(std::get<2>(ps)), // NOLINT - res(spec, algo, ref.get_in_dists(), ref.get_in_ids()) // NOLINT - { - } - - explicit SelectK(typename ParamsReader::params_t ps) - : SelectK(ParamsReader::read(ps)) - { - } - - SelectK() - : SelectK(testing::TestWithParam::params_t>::GetParam()) - { - } - - void run() - { - if (ref.not_supported || res.not_supported) { GTEST_SKIP(); } - ASSERT_TRUE(hostVecMatch(ref.get_out_dists(), res.get_out_dists(), Compare())); - - // If the dists (keys) are the same, different corresponding ids may end up in the selection due - // to non-deterministic nature of some implementations. - auto& in_ids = ref.get_in_ids(); - auto& in_dists = ref.get_in_dists(); - auto compare_ids = [&in_ids, &in_dists](const IdxT& i, const IdxT& j) { - if (i == j) return true; - auto ix_i = static_cast(std::find(in_ids.begin(), in_ids.end(), i) - in_ids.begin()); - auto ix_j = static_cast(std::find(in_ids.begin(), in_ids.end(), j) - in_ids.begin()); - if (static_cast(ix_i) >= in_ids.size() || static_cast(ix_j) >= in_ids.size()) - return false; - auto dist_i = in_dists[ix_i]; - auto dist_j = in_dists[ix_j]; - if (dist_i == dist_j) return true; - std::cout << "ERROR: ref[" << ix_i << "] = " << dist_i << " != " - << "res[" << ix_j << "] = " << dist_j << std::endl; - return false; - }; - ASSERT_TRUE(hostVecMatch(ref.get_out_ids(), res.get_out_ids(), compare_ids)); - } -}; - -template -struct params_simple { - using io_t = io_simple; - using input_t = - std::tuple, std::vector, std::vector>; - using params_t = std::tuple; - - static auto read(params_t ps) -> Params - { - auto ins = std::get<0>(ps); - auto algo = std::get<1>(ps); - return std::make_tuple( - std::get<0>(ins), - algo, - io_simple( - std::get<0>(ins), std::get<1>(ins), std::get<2>(ins), std::get<3>(ins))); - } -}; - -auto inputs_simple_f = testing::Values( - params_simple::input_t( - {5, 5, 5, true, true}, - {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, - 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, - {1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, - 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0}, - {4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}), - params_simple::input_t( - {5, 5, 3, true, true}, - {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, - 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, - {1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, - {4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}), - params_simple::input_t( - {5, 5, 5, true, false}, - {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, - 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, - {1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, - 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0}, - {4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}), - params_simple::input_t( - {5, 5, 3, true, false}, - {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, - 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, - {1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, - {4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}), - params_simple::input_t( - {5, 7, 3, true, true}, - {5.0, 4.0, 3.0, 2.0, 1.3, 7.5, 19.0, 9.0, 2.0, 3.0, 3.0, 5.0, 6.0, 4.0, 2.0, 3.0, 5.0, 1.0, - 4.0, 1.0, 1.0, 5.0, 7.0, 2.5, 4.0, 7.0, 8.0, 8.0, 1.0, 3.0, 2.0, 5.0, 4.0, 1.1, 1.2}, - {1.3, 2.0, 3.0, 2.0, 3.0, 3.0, 1.0, 1.0, 1.0, 2.5, 4.0, 5.0, 1.0, 1.1, 1.2}, - {4, 3, 2, 1, 2, 3, 3, 5, 6, 2, 3, 0, 0, 5, 6}), - params_simple::input_t( - {1, 7, 3, true, true}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {1.0, 1.0, 1.0}, {3, 5, 6}), - params_simple::input_t( - {1, 7, 3, false, false}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {5.0, 4.0, 3.0}, {2, 4, 1}), - params_simple::input_t( - {1, 7, 3, false, true}, {2.0, 3.0, 5.0, 9.0, 4.0, 9.0, 9.0}, {9.0, 9.0, 9.0}, {3, 5, 6}), - params_simple::input_t( - {1, 130, 5, false, true}, - {19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, - 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, - 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 4, 2, 3, 2, 3, 2, 3, 2, 3, 2, 20}, - {20, 19, 18, 17, 16}, - {129, 0, 117, 116, 115}), - params_simple::input_t( - {1, 130, 15, false, true}, - {19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, - 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, - 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, - 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 4, 2, 3, 2, 3, 2, 3, 2, 3, 2, 20}, - {20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6}, - {129, 0, 117, 116, 115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105})); - -using SimpleFloatInt = SelectK; -TEST_P(SimpleFloatInt, Run) { run(); } // NOLINT -INSTANTIATE_TEST_CASE_P( // NOLINT - SelectK, - SimpleFloatInt, - testing::Combine(inputs_simple_f, - testing::Values(select::Algo::kPublicApi, - select::Algo::kRadix8bits, - select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass, - select::Algo::kWarpImmediate, - select::Algo::kWarpFiltered, - select::Algo::kWarpDistributed))); - -template -struct with_ref { - template - struct params_random { - using io_t = io_computed; - using params_t = std::tuple; - - static auto read(params_t ps) -> Params - { - auto spec = std::get<0>(ps); - auto algo = std::get<1>(ps); - std::vector dists(spec.len * spec.batch_size); - - raft::resources handle; - { - auto s = resource::get_cuda_stream(handle); - rmm::device_uvector dists_d(spec.len * spec.batch_size, s); - raft::random::RngState r(42); - normal(handle, r, dists_d.data(), dists_d.size(), KeyT(10.0), KeyT(100.0)); - update_host(dists.data(), dists_d.data(), dists_d.size(), s); - s.synchronize(); - } - - return std::make_tuple(spec, algo, io_computed(spec, RefAlgo, dists)); - } - }; -}; - auto inputs_random_longlist = testing::Values(select::params{1, 130, 15, false}, select::params{1, 128, 15, false}, select::params{20, 700, 1, true}, @@ -412,7 +65,7 @@ auto inputs_random_largesize = testing::Values(select::params{100, 100000, 1, tr select::params{1, 1000000000, 256, false, false}); auto inputs_random_largek = testing::Values(select::params{100, 100000, 1000, true}, - select::params{100, 100000, 2000, true}, + select::params{100, 100000, 2000, false}, select::params{100, 100000, 100000, true, false}, select::params{100, 100000, 2048, false}, select::params{100, 100000, 1237, true}); @@ -458,14 +111,4 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kRadix8bits, select::Algo::kRadix11bits, select::Algo::kRadix11bitsExtraPass))); - -using ReferencedRandomFloatSizeT = - SelectK::params_random>; -TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } // NOLINT -INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT - ReferencedRandomFloatSizeT, - testing::Combine(inputs_random_largek, - testing::Values(select::Algo::kRadix11bits, - select::Algo::kRadix11bitsExtraPass))); - } // namespace raft::matrix diff --git a/cpp/test/matrix/select_k.cuh b/cpp/test/matrix/select_k.cuh new file mode 100644 index 0000000000..e0e0cad225 --- /dev/null +++ b/cpp/test/matrix/select_k.cuh @@ -0,0 +1,366 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "../test_utils.cuh" +#include + +#include + +#include +#include +#include +#include + +#include + +#include +#include + +#include +#include + +namespace raft::matrix { + +template +auto gen_simple_ids(uint32_t batch_size, uint32_t len) -> std::vector +{ + std::vector out(batch_size * len); + auto s = rmm::cuda_stream_default; + rmm::device_uvector out_d(out.size(), s); + sparse::iota_fill(out_d.data(), IdxT(batch_size), IdxT(len), s); + update_host(out.data(), out_d.data(), out.size(), s); + s.synchronize(); + return out; +} + +template +struct io_simple { + public: + bool not_supported = false; + + io_simple(const select::params& spec, + const std::vector& in_dists, + const std::vector& out_dists, + const std::vector& out_ids) + : in_dists_(in_dists), + in_ids_(gen_simple_ids(spec.batch_size, spec.len)), + out_dists_(out_dists), + out_ids_(out_ids) + { + } + + auto get_in_dists() -> std::vector& { return in_dists_; } + auto get_in_ids() -> std::vector& { return in_ids_; } + auto get_out_dists() -> std::vector& { return out_dists_; } + auto get_out_ids() -> std::vector& { return out_ids_; } + + private: + std::vector in_dists_; + std::vector in_ids_; + std::vector out_dists_; + std::vector out_ids_; +}; + +template +struct io_computed { + public: + bool not_supported = false; + + io_computed(const select::params& spec, + const select::Algo& algo, + const std::vector& in_dists, + const std::optional>& in_ids = std::nullopt) + : in_dists_(in_dists), + in_ids_(in_ids.value_or(gen_simple_ids(spec.batch_size, spec.len))), + out_dists_(spec.batch_size * spec.k), + out_ids_(spec.batch_size * spec.k) + { + // check if the size is supported by the algorithm + switch (algo) { + case select::Algo::kWarpAuto: + case select::Algo::kWarpImmediate: + case select::Algo::kWarpFiltered: + case select::Algo::kWarpDistributed: + case select::Algo::kWarpDistributedShm: { + if (spec.k > raft::matrix::detail::select::warpsort::kMaxCapacity) { + not_supported = true; + return; + } + } break; + default: break; + } + + resources handle{}; + auto stream = resource::get_cuda_stream(handle); + + rmm::device_uvector in_dists_d(in_dists_.size(), stream); + rmm::device_uvector in_ids_d(in_ids_.size(), stream); + rmm::device_uvector out_dists_d(out_dists_.size(), stream); + rmm::device_uvector out_ids_d(out_ids_.size(), stream); + + update_device(in_dists_d.data(), in_dists_.data(), in_dists_.size(), stream); + update_device(in_ids_d.data(), in_ids_.data(), in_ids_.size(), stream); + + select::select_k_impl(handle, + algo, + in_dists_d.data(), + spec.use_index_input ? in_ids_d.data() : nullptr, + spec.batch_size, + spec.len, + spec.k, + out_dists_d.data(), + out_ids_d.data(), + spec.select_min); + + update_host(out_dists_.data(), out_dists_d.data(), out_dists_.size(), stream); + update_host(out_ids_.data(), out_ids_d.data(), out_ids_.size(), stream); + + interruptible::synchronize(stream); + + auto p = topk_sort_permutation(out_dists_, out_ids_, spec.k, spec.select_min); + apply_permutation(out_dists_, p); + apply_permutation(out_ids_, p); + } + + auto get_in_dists() -> std::vector& { return in_dists_; } + auto get_in_ids() -> std::vector& { return in_ids_; } + auto get_out_dists() -> std::vector& { return out_dists_; } + auto get_out_ids() -> std::vector& { return out_ids_; } + + private: + std::vector in_dists_; + std::vector in_ids_; + std::vector out_dists_; + std::vector out_ids_; + + auto topk_sort_permutation(const std::vector& vec, + const std::vector& inds, + uint32_t k, + bool select_min) -> std::vector + { + std::vector p(vec.size()); + std::iota(p.begin(), p.end(), 0); + if (select_min) { + std::sort(p.begin(), p.end(), [&vec, &inds, k](IdxT i, IdxT j) { + const IdxT ik = i / k; + const IdxT jk = j / k; + if (ik == jk) { + if (vec[i] == vec[j]) { return inds[i] < inds[j]; } + return vec[i] < vec[j]; + } + return ik < jk; + }); + } else { + std::sort(p.begin(), p.end(), [&vec, &inds, k](IdxT i, IdxT j) { + const IdxT ik = i / k; + const IdxT jk = j / k; + if (ik == jk) { + if (vec[i] == vec[j]) { return inds[i] < inds[j]; } + return vec[i] > vec[j]; + } + return ik < jk; + }); + } + return p; + } + + template + void apply_permutation(std::vector& vec, const std::vector& p) // NOLINT + { + for (auto i = IdxT(vec.size()) - 1; i > 0; i--) { + auto j = p[i]; + while (j > i) + j = p[j]; + std::swap(vec[j], vec[i]); + } + } +}; + +template +using Params = std::tuple; + +template typename ParamsReader> +struct SelectK // NOLINT + : public testing::TestWithParam::params_t> { + const select::params spec; + const select::Algo algo; + typename ParamsReader::io_t ref; + io_computed res; + + explicit SelectK(Params::io_t> ps) + : spec(std::get<0>(ps)), + algo(std::get<1>(ps)), // NOLINT + ref(std::get<2>(ps)), // NOLINT + res(spec, algo, ref.get_in_dists(), ref.get_in_ids()) // NOLINT + { + } + + explicit SelectK(typename ParamsReader::params_t ps) + : SelectK(ParamsReader::read(ps)) + { + } + + SelectK() + : SelectK(testing::TestWithParam::params_t>::GetParam()) + { + } + + void run() + { + if (ref.not_supported || res.not_supported) { GTEST_SKIP(); } + ASSERT_TRUE(hostVecMatch(ref.get_out_dists(), res.get_out_dists(), Compare())); + + // If the dists (keys) are the same, different corresponding ids may end up in the selection due + // to non-deterministic nature of some implementations. + auto& in_ids = ref.get_in_ids(); + auto& in_dists = ref.get_in_dists(); + auto compare_ids = [&in_ids, &in_dists](const IdxT& i, const IdxT& j) { + if (i == j) return true; + auto ix_i = static_cast(std::find(in_ids.begin(), in_ids.end(), i) - in_ids.begin()); + auto ix_j = static_cast(std::find(in_ids.begin(), in_ids.end(), j) - in_ids.begin()); + if (static_cast(ix_i) >= in_ids.size() || static_cast(ix_j) >= in_ids.size()) + return false; + auto dist_i = in_dists[ix_i]; + auto dist_j = in_dists[ix_j]; + if (dist_i == dist_j) return true; + std::cout << "ERROR: ref[" << ix_i << "] = " << dist_i << " != " + << "res[" << ix_j << "] = " << dist_j << std::endl; + return false; + }; + ASSERT_TRUE(hostVecMatch(ref.get_out_ids(), res.get_out_ids(), compare_ids)); + } +}; + +template +struct params_simple { + using io_t = io_simple; + using input_t = + std::tuple, std::vector, std::vector>; + using params_t = std::tuple; + + static auto read(params_t ps) -> Params + { + auto ins = std::get<0>(ps); + auto algo = std::get<1>(ps); + return std::make_tuple( + std::get<0>(ins), + algo, + io_simple( + std::get<0>(ins), std::get<1>(ins), std::get<2>(ins), std::get<3>(ins))); + } +}; + +auto inputs_simple_f = testing::Values( + params_simple::input_t( + {5, 5, 5, true, true}, + {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, + 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, + 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0}, + {4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}), + params_simple::input_t( + {5, 5, 3, true, true}, + {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, + 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + {1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, + {4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}), + params_simple::input_t( + {5, 5, 5, true, false}, + {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, + 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + {1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, + 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0, 1.0, 2.0, 3.0, 4.0, 5.0}, + {4, 3, 2, 1, 0, 0, 1, 2, 3, 4, 3, 0, 1, 4, 2, 4, 2, 1, 3, 0, 0, 2, 1, 4, 3}), + params_simple::input_t( + {5, 5, 3, true, false}, + {5.0, 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 5.0, + 1.0, 4.0, 5.0, 3.0, 2.0, 4.0, 1.0, 1.0, 3.0, 2.0, 5.0, 4.0}, + {1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 1.0, 2.0, 3.0}, + {4, 3, 2, 0, 1, 2, 3, 0, 1, 4, 2, 1, 0, 2, 1}), + params_simple::input_t( + {5, 7, 3, true, true}, + {5.0, 4.0, 3.0, 2.0, 1.3, 7.5, 19.0, 9.0, 2.0, 3.0, 3.0, 5.0, 6.0, 4.0, 2.0, 3.0, 5.0, 1.0, + 4.0, 1.0, 1.0, 5.0, 7.0, 2.5, 4.0, 7.0, 8.0, 8.0, 1.0, 3.0, 2.0, 5.0, 4.0, 1.1, 1.2}, + {1.3, 2.0, 3.0, 2.0, 3.0, 3.0, 1.0, 1.0, 1.0, 2.5, 4.0, 5.0, 1.0, 1.1, 1.2}, + {4, 3, 2, 1, 2, 3, 3, 5, 6, 2, 3, 0, 0, 5, 6}), + params_simple::input_t( + {1, 7, 3, true, true}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {1.0, 1.0, 1.0}, {3, 5, 6}), + params_simple::input_t( + {1, 7, 3, false, false}, {2.0, 3.0, 5.0, 1.0, 4.0, 1.0, 1.0}, {5.0, 4.0, 3.0}, {2, 4, 1}), + params_simple::input_t( + {1, 7, 3, false, true}, {2.0, 3.0, 5.0, 9.0, 4.0, 9.0, 9.0}, {9.0, 9.0, 9.0}, {3, 5, 6}), + params_simple::input_t( + {1, 130, 5, false, true}, + {19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, + 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 4, 2, 3, 2, 3, 2, 3, 2, 3, 2, 20}, + {20, 19, 18, 17, 16}, + {129, 0, 117, 116, 115}), + params_simple::input_t( + {1, 130, 15, false, true}, + {19, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, + 0, 1, 0, 1, 0, 1, 0, 1, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, + 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 4, 4, 2, 3, 2, 3, 2, 3, 2, 3, 2, 20}, + {20, 19, 18, 17, 16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6}, + {129, 0, 117, 116, 115, 114, 113, 112, 111, 110, 109, 108, 107, 106, 105})); + +using SimpleFloatInt = SelectK; +TEST_P(SimpleFloatInt, Run) { run(); } // NOLINT +INSTANTIATE_TEST_CASE_P( // NOLINT + SelectK, + SimpleFloatInt, + testing::Combine(inputs_simple_f, + testing::Values(select::Algo::kPublicApi, + select::Algo::kRadix8bits, + select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass, + select::Algo::kWarpImmediate, + select::Algo::kWarpFiltered, + select::Algo::kWarpDistributed))); + +template +struct with_ref { + template + struct params_random { + using io_t = io_computed; + using params_t = std::tuple; + + static auto read(params_t ps) -> Params + { + auto spec = std::get<0>(ps); + auto algo = std::get<1>(ps); + std::vector dists(spec.len * spec.batch_size); + + raft::resources handle; + { + auto s = resource::get_cuda_stream(handle); + rmm::device_uvector dists_d(spec.len * spec.batch_size, s); + raft::random::RngState r(42); + normal(handle, r, dists_d.data(), dists_d.size(), KeyT(10.0), KeyT(100.0)); + update_host(dists.data(), dists_d.data(), dists_d.size(), s); + s.synchronize(); + } + + return std::make_tuple(spec, algo, io_computed(spec, RefAlgo, dists)); + } + }; +}; + +} // namespace raft::matrix diff --git a/cpp/test/matrix/select_large_k.cu b/cpp/test/matrix/select_large_k.cu new file mode 100644 index 0000000000..2772e84eb3 --- /dev/null +++ b/cpp/test/matrix/select_large_k.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "select_k.cuh" + +namespace raft::matrix { + +auto inputs_random_largek = testing::Values(select::params{100, 100000, 1000, true}, + select::params{100, 100000, 2000, false}, + select::params{100, 100000, 100000, true, false}, + select::params{100, 100000, 2048, false}, + select::params{100, 100000, 1237, true}); + +using ReferencedRandomFloatSizeT = + SelectK::params_random>; +TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } // NOLINT +INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT + ReferencedRandomFloatSizeT, + testing::Combine(inputs_random_largek, + testing::Values(select::Algo::kRadix11bits, + select::Algo::kRadix11bitsExtraPass))); + +} // namespace raft::matrix diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 9969bfd7c1..b11abf13a5 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -52,9 +52,9 @@ void RandomSuffle(raft::host_matrix_view index) IdxT* const row_ptr = index.data_handle() + i * index.extent(1); for (unsigned j = 0; j < index.extent(1); j++) { // Swap two indices at random - rand = raft::neighbors::experimental::cagra::detail::device::xorshift64(rand); + rand = raft::neighbors::cagra::detail::device::xorshift64(rand); const auto i0 = rand % index.extent(1); - rand = raft::neighbors::experimental::cagra::detail::device::xorshift64(rand); + rand = raft::neighbors::cagra::detail::device::xorshift64(rand); const auto i1 = rand % index.extent(1); const auto tmp = row_ptr[i0]; diff --git a/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh b/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh index 562e5ac2ca..53278b9666 100644 --- a/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh +++ b/cpp/test/neighbors/ann_cagra/search_kernel_uint64_t.cuh @@ -17,7 +17,7 @@ #include // RAFT_EXPLICIT -namespace raft::neighbors::experimental::cagra::detail { +namespace raft::neighbors::cagra::detail { namespace multi_cta_search { #define instantiate_kernel_selection(TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T) \ @@ -90,4 +90,4 @@ instantiate_single_cta_select_and_run(16, 256, float, uint64_t, float); instantiate_single_cta_select_and_run(32, 512, float, uint64_t, float); } // namespace single_cta_search -} // namespace raft::neighbors::experimental::cagra::detail +} // namespace raft::neighbors::cagra::detail diff --git a/cpp/test/neighbors/selection.cu b/cpp/test/neighbors/selection.cu index 5d63338b45..6030e2a1a6 100644 --- a/cpp/test/neighbors/selection.cu +++ b/cpp/test/neighbors/selection.cu @@ -441,7 +441,7 @@ auto inputs_random_largesize = testing::Values(SelectTestSpec{100, 100000, 1, tr SelectTestSpec{1, 100000000, 256, false, false}); auto inputs_random_largek = testing::Values(SelectTestSpec{100, 100000, 1000, true}, - SelectTestSpec{100, 100000, 2000, true}, + SelectTestSpec{100, 100000, 2000, false}, SelectTestSpec{100, 100000, 100000, true, false}, SelectTestSpec{100, 100000, 2048, false}, SelectTestSpec{100, 100000, 1237, true}); @@ -482,6 +482,11 @@ INSTANTIATE_TEST_CASE_P(SelectionTest, * SelectionTest/ReferencedRandomFloatSizeT.LargeK/0 * Indicices do not match! ref[91628] = 131.359 != res[36504] = 158.438 * Actual: false (actual=36504 != expected=91628 @38999; + * + * SelectionTest/ReferencedRandomFloatSizeT.LargeK/1 + * ERROR: ref[57977] = 58.9079 != res[21973] = 54.9354 + * Actual: false (actual=21973 != expected=57977 @107999; + * */ typedef SelectionTest::params_random> ReferencedRandomFloatSizeT; diff --git a/cpp/test/random/rng_discrete.cu b/cpp/test/random/rng_discrete.cu index 799f44735e..d1293f34ea 100644 --- a/cpp/test/random/rng_discrete.cu +++ b/cpp/test/random/rng_discrete.cu @@ -193,15 +193,16 @@ const std::vector> inputs_i64 = { {1, 10000, 5, 5, GenPhilox, 1234ULL}, }; -#define RNG_DISCRETE_TEST(test_type, test_name, test_inputs) \ - typedef RAFT_DEPAREN(test_type) test_name; \ - TEST_P(test_name, Result) \ - { \ - ASSERT_TRUE(devArrMatchHost(exp_histogram.data(), \ - histogram.data(), \ - exp_histogram.size(), \ - CompareApprox(tolerance))); \ - } \ +#define RNG_DISCRETE_TEST(test_type, test_name, test_inputs) \ + typedef RAFT_DEPAREN(test_type) test_name; \ + TEST_P(test_name, Result) \ + { \ + ASSERT_TRUE(devArrMatchHost(exp_histogram.data(), \ + histogram.data(), \ + exp_histogram.size(), \ + CompareApprox(tolerance), \ + stream)); \ + } \ INSTANTIATE_TEST_CASE_P(ReduceTests, test_name, ::testing::ValuesIn(test_inputs)) RNG_DISCRETE_TEST((RngDiscreteTest), RngDiscreteTestI32FI32, inputs_i32); diff --git a/cpp/test/sparse/gram.cu b/cpp/test/sparse/gram.cu index 87cebd3519..7b4736a08c 100644 --- a/cpp/test/sparse/gram.cu +++ b/cpp/test/sparse/gram.cu @@ -157,6 +157,8 @@ class GramMatrixTest : public ::testing::TestWithParam { raft::random::Rng r(42137ULL); r.uniform(x1.data(), x1.size(), math_t(0), math_t(1), stream); r.uniform(x2.data(), x2.size(), math_t(0), math_t(1), stream); + + RAFT_CUDA_TRY(cudaStreamSynchronize(stream)); } ~GramMatrixTest() override { RAFT_CUDA_TRY_NO_THROW(cudaStreamDestroy(stream)); } @@ -204,7 +206,6 @@ class GramMatrixTest : public ::testing::TestWithParam { raft::update_device(indices, indices_host.data(), nnz, stream); raft::update_device(data, data_host.data(), nnz, stream); resource::sync_stream(handle, stream); - return nnz; } @@ -273,7 +274,9 @@ class GramMatrixTest : public ::testing::TestWithParam { (*kernel)(handle, x1_csr, x2_csr, out_span); } } - + // Something in gram is executing not on the 'stream' and therefore + // a full device sync is required + RAFT_CUDA_TRY(cudaDeviceSynchronize()); naiveGramMatrixKernel(params.n1, params.n2, params.n_cols, @@ -287,11 +290,10 @@ class GramMatrixTest : public ::testing::TestWithParam { params.kernel, stream, handle); - resource::sync_stream(handle, stream); ASSERT_TRUE(raft::devArrMatchHost( - gram_host.data(), gram.data(), gram.size(), raft::CompareApprox(1e-6f))); + gram_host.data(), gram.data(), gram.size(), raft::CompareApprox(1e-6f), stream)); } raft::resources handle; diff --git a/cpp/test/sparse/neighbors/connect_components.cu b/cpp/test/sparse/neighbors/connect_components.cu deleted file mode 100644 index 373963b653..0000000000 --- a/cpp/test/sparse/neighbors/connect_components.cu +++ /dev/null @@ -1,357 +0,0 @@ -/* - * Copyright (c) 2018-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -// XXX: We allow the instantiation of fused_l2_nn here: -// raft::linkage::FixConnectivitiesRedOp red_op(colors.data(), params.n_row); -// raft::linkage::connect_components( -// handle, out_edges, data.data(), colors.data(), params.n_row, params.n_col, red_op); -// -// TODO: consider adding this to libraft.so or creating an instance in a -// separate translation unit for this test. -#undef RAFT_EXPLICIT_INSTANTIATE_ONLY - -#include -#include - -#include - -#include -#include -#include - -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include - -#include "../../test_utils.cuh" - -namespace raft { -namespace sparse { - -using namespace std; - -template -struct ConnectComponentsInputs { - value_idx n_row; - value_idx n_col; - std::vector data; - - int c; -}; - -template -class ConnectComponentsTest - : public ::testing::TestWithParam> { - protected: - void basicTest() - { - raft::resources handle; - - auto stream = resource::get_cuda_stream(handle); - - params = ::testing::TestWithParam>::GetParam(); - - raft::sparse::COO out_edges(resource::get_cuda_stream(handle)); - - rmm::device_uvector data(params.n_row * params.n_col, - resource::get_cuda_stream(handle)); - - raft::copy(data.data(), params.data.data(), data.size(), resource::get_cuda_stream(handle)); - - rmm::device_uvector indptr(params.n_row + 1, stream); - - /** - * 1. Construct knn graph - */ - raft::sparse::COO knn_graph_coo(stream); - - raft::sparse::neighbors::knn_graph(handle, - data.data(), - params.n_row, - params.n_col, - raft::distance::DistanceType::L2SqrtExpanded, - knn_graph_coo, - params.c); - - raft::sparse::convert::sorted_coo_to_csr( - knn_graph_coo.rows(), knn_graph_coo.nnz, indptr.data(), params.n_row + 1, stream); - - /** - * 2. Construct MST, sorted by weights - */ - rmm::device_uvector colors(params.n_row, stream); - - auto mst_coo = raft::mst::mst(handle, - indptr.data(), - knn_graph_coo.cols(), - knn_graph_coo.vals(), - params.n_row, - knn_graph_coo.nnz, - colors.data(), - stream, - false, - true); - - /** - * 3. connect_components to fix connectivities - */ - raft::linkage::FixConnectivitiesRedOp red_op(colors.data(), params.n_row); - raft::linkage::connect_components( - handle, out_edges, data.data(), colors.data(), params.n_row, params.n_col, red_op); - - /** - * Construct final edge list - */ - rmm::device_uvector indptr2(params.n_row + 1, stream); - - raft::sparse::convert::sorted_coo_to_csr( - out_edges.rows(), out_edges.nnz, indptr2.data(), params.n_row + 1, stream); - - auto output_mst = raft::mst::mst(handle, - indptr2.data(), - out_edges.cols(), - out_edges.vals(), - params.n_row, - out_edges.nnz, - colors.data(), - stream, - false, - false); - - resource::sync_stream(handle, stream); - - // The sum of edges for both MST runs should be n_rows - 1 - final_edges = output_mst.n_edges + mst_coo.n_edges; - } - - void SetUp() override { basicTest(); } - - void TearDown() override {} - - protected: - ConnectComponentsInputs params; - - value_idx final_edges; -}; - -const std::vector> fix_conn_inputsf2 = { - // Test n_clusters == n_points - {10, - 5, - {0.21390334, 0.50261639, 0.91036676, 0.59166485, 0.71162682, 0.10248392, 0.77782677, 0.43772379, - 0.4035871, 0.3282796, 0.47544681, 0.59862974, 0.12319357, 0.06239463, 0.28200272, 0.1345717, - 0.50498218, 0.5113505, 0.16233086, 0.62165332, 0.42281548, 0.933117, 0.41386077, 0.23264562, - 0.73325968, 0.37537541, 0.70719873, 0.14522645, 0.73279625, 0.9126674, 0.84854131, 0.28890216, - 0.85267903, 0.74703138, 0.83842071, 0.34942792, 0.27864171, 0.70911132, 0.21338564, 0.32035554, - 0.73788331, 0.46926692, 0.57570162, 0.42559178, 0.87120209, 0.22734951, 0.01847905, 0.75549396, - 0.76166195, 0.66613745}, - -1}, - // Test n_points == 100 - {100, - 10, - {6.26168372e-01, 9.30437651e-01, 6.02450208e-01, 2.73025296e-01, 9.53050619e-01, 3.32164396e-01, - 6.88942598e-01, 5.79163537e-01, 6.70341547e-01, 2.70140602e-02, 9.30429671e-01, 7.17721157e-01, - 9.89948537e-01, 7.75253347e-01, 1.34491522e-02, 2.48522428e-02, 3.51413378e-01, 7.64405834e-01, - 7.86373507e-01, 7.18748577e-01, 8.66998621e-01, 6.80316582e-01, 2.51288712e-01, 4.91078420e-01, - 3.76246281e-01, 4.86828710e-01, 5.67464772e-01, 5.30734742e-01, 8.99478296e-01, 7.66699088e-01, - 9.49339111e-01, 3.55248484e-01, 9.06046929e-01, 4.48407772e-01, 6.96395305e-01, 2.44277335e-01, - 7.74840000e-01, 5.21046603e-01, 4.66423971e-02, 5.12019638e-02, 8.95019614e-01, 5.28956953e-01, - 4.31536306e-01, 5.83857744e-01, 4.41787364e-01, 4.68656523e-01, 5.73971433e-01, 6.79989654e-01, - 3.19650588e-01, 6.12579596e-01, 6.49126442e-02, 8.39131142e-01, 2.85252117e-01, 5.84848929e-01, - 9.46507115e-01, 8.58440748e-01, 3.61528940e-01, 2.44215959e-01, 3.80101125e-01, 4.57128957e-02, - 8.82216988e-01, 8.31498633e-01, 7.23474381e-01, 7.75788607e-01, 1.40864146e-01, 6.62092382e-01, - 5.13985168e-01, 3.00686418e-01, 8.70109949e-01, 2.43187753e-01, 2.89391938e-01, 2.84214238e-01, - 8.70985521e-01, 8.77491176e-01, 6.72537226e-01, 3.30929686e-01, 1.85934324e-01, 9.16222614e-01, - 6.18239142e-01, 2.64768597e-01, 5.76145451e-01, 8.62961369e-01, 6.84757925e-01, 7.60549082e-01, - 1.27645356e-01, 4.51004673e-01, 3.92292980e-01, 4.63170803e-01, 4.35449330e-02, 2.17583404e-01, - 5.71832605e-02, 2.06763039e-01, 3.70116249e-01, 2.09750028e-01, 6.17283019e-01, 8.62549231e-01, - 9.84156240e-02, 2.66249156e-01, 3.87635103e-01, 2.85591012e-02, 4.24826068e-01, 4.45795088e-01, - 6.86227676e-01, 1.08848960e-01, 5.96731841e-02, 3.71770228e-01, 1.91548833e-01, 6.95136078e-01, - 9.00700636e-01, 8.76363105e-01, 2.67334632e-01, 1.80619709e-01, 7.94060419e-01, 1.42854171e-02, - 1.09372387e-01, 8.74028108e-01, 6.46403232e-01, 4.86588834e-01, 5.93446175e-02, 6.11886291e-01, - 8.83865057e-01, 3.15879821e-01, 2.27043992e-01, 9.76764951e-01, 6.15620336e-01, 9.76199360e-01, - 2.40548962e-01, 3.21795663e-01, 8.75087904e-02, 8.11234663e-01, 6.96070480e-01, 8.12062321e-01, - 1.21958818e-01, 3.44348628e-02, 8.72630414e-01, 3.06162776e-01, 1.76043529e-02, 9.45894971e-01, - 5.33896401e-01, 6.21642973e-01, 4.93062535e-01, 4.48984262e-01, 2.24560379e-01, 4.24052195e-02, - 4.43447610e-01, 8.95646149e-01, 6.05220676e-01, 1.81840491e-01, 9.70831206e-01, 2.12563586e-02, - 6.92582693e-01, 7.55946922e-01, 7.95086143e-01, 6.05328941e-01, 3.99350764e-01, 4.32846636e-01, - 9.81114529e-01, 4.98266428e-01, 6.37127930e-03, 1.59085889e-01, 6.34682067e-05, 5.59429440e-01, - 7.38827633e-01, 8.93214770e-01, 2.16494306e-01, 9.35430573e-02, 4.75665868e-02, 7.80503518e-01, - 7.86240041e-01, 7.06854594e-01, 2.13725879e-02, 7.68246091e-01, 4.50234808e-01, 5.21231104e-01, - 5.01989826e-03, 4.22081572e-02, 1.65337732e-01, 8.54134740e-01, 4.99430262e-01, 8.94525601e-01, - 1.14028379e-01, 3.69739861e-01, 1.32955599e-01, 2.65563824e-01, 2.52811151e-01, 1.44792843e-01, - 6.88449594e-01, 4.44921417e-01, 8.23296587e-01, 1.93266317e-01, 1.19033309e-01, 1.36368966e-01, - 3.42600285e-01, 5.64505195e-01, 5.57594559e-01, 7.44257892e-01, 8.38231569e-02, 4.11548847e-01, - 3.21010077e-01, 8.55081359e-01, 4.30105779e-01, 1.16229135e-01, 9.87731964e-02, 3.14712335e-01, - 4.50880592e-01, 2.72289598e-01, 6.31615256e-01, 8.97432958e-01, 4.44764250e-01, 8.03776440e-01, - 2.68767748e-02, 2.43374608e-01, 4.02141103e-01, 4.98881209e-01, 5.33173003e-01, 8.82890436e-01, - 7.16149148e-01, 4.19664401e-01, 2.29335357e-01, 2.88637806e-01, 3.44696803e-01, 6.78171906e-01, - 5.69849716e-01, 5.86454477e-01, 3.54474989e-01, 9.03876540e-01, 6.45980000e-01, 6.34887593e-01, - 7.88039746e-02, 2.04814126e-01, 7.82251754e-01, 2.43147074e-01, 7.50951808e-01, 1.72799092e-02, - 2.95349590e-01, 6.57991826e-01, 8.81214312e-01, 5.73970708e-01, 2.77610881e-01, 1.82155097e-01, - 7.69797417e-02, 6.44792402e-01, 9.46950998e-01, 7.73064845e-01, 6.04733624e-01, 5.80094567e-01, - 1.67498426e-01, 2.66514296e-01, 6.50140368e-01, 1.91170299e-01, 2.08752199e-01, 3.01664091e-01, - 9.85033484e-01, 2.92909152e-01, 8.65816607e-01, 1.85222119e-01, 2.28814559e-01, 1.34286382e-02, - 2.89234322e-01, 8.18668708e-01, 4.71706924e-01, 9.23199803e-01, 2.80879188e-01, 1.47319284e-01, - 4.13915748e-01, 9.31274932e-02, 6.66322195e-01, 9.66953974e-01, 3.19405786e-01, 6.69486551e-01, - 5.03096313e-02, 6.95225201e-01, 5.78469859e-01, 6.29481655e-01, 1.39252534e-01, 1.22564968e-01, - 6.80663678e-01, 6.34607157e-01, 6.42765834e-01, 1.57127410e-02, 2.92132086e-01, 5.24423878e-01, - 4.68676824e-01, 2.86003928e-01, 7.18608322e-01, 8.95617933e-01, 5.48844309e-01, 1.74517278e-01, - 5.24379196e-01, 2.13526524e-01, 5.88375435e-01, 9.88560185e-01, 4.17435771e-01, 6.14438688e-01, - 9.53760881e-01, 5.27151288e-01, 7.03017278e-01, 3.44448559e-01, 4.47059676e-01, 2.83414901e-01, - 1.98979011e-01, 4.24917361e-01, 5.73172761e-01, 2.32398853e-02, 1.65887230e-01, 4.05552785e-01, - 9.29665524e-01, 2.26135696e-01, 9.20563384e-01, 7.65259963e-01, 4.54820075e-01, 8.97710267e-01, - 3.78559302e-03, 9.15219382e-01, 3.55705698e-01, 6.94905124e-01, 8.58540202e-01, 3.89790666e-01, - 2.49478206e-01, 7.93679304e-01, 4.75830027e-01, 4.40425353e-01, 3.70579459e-01, 1.40578049e-01, - 1.70386675e-01, 7.04056121e-01, 4.85963102e-01, 9.68450060e-01, 6.77178001e-01, 2.65934654e-01, - 2.58915007e-01, 6.70052890e-01, 2.61945109e-01, 8.46207759e-01, 1.01928951e-01, 2.85611334e-01, - 2.45776933e-01, 2.66658783e-01, 3.71724077e-01, 4.34319025e-01, 4.24407347e-01, 7.15417683e-01, - 8.07997684e-01, 1.64296275e-01, 6.01638065e-01, 8.60606804e-02, 2.68719187e-01, 5.11764101e-01, - 9.75844338e-01, 7.81226782e-01, 2.20925515e-01, 7.18135040e-01, 9.82395577e-01, 8.39160243e-01, - 9.08058083e-01, 6.88010677e-01, 8.14271847e-01, 5.12460821e-01, 1.17311345e-01, 5.96075228e-01, - 9.17455497e-01, 2.12052706e-01, 7.04074603e-01, 8.72872565e-02, 8.76047818e-01, 6.96235046e-01, - 8.54801557e-01, 2.49729159e-01, 9.76594604e-01, 2.87386363e-01, 2.36461559e-02, 9.94075254e-01, - 4.25193986e-01, 7.61869994e-01, 5.13334255e-01, 6.44711165e-02, 8.92156689e-01, 3.55235167e-01, - 1.08154647e-01, 8.78446825e-01, 2.43833016e-01, 9.23071293e-01, 2.72724115e-01, 9.46631338e-01, - 3.74510294e-01, 4.08451278e-02, 9.78392777e-01, 3.65079221e-01, 6.37199516e-01, 5.51144906e-01, - 5.25978080e-01, 1.42803678e-01, 4.05451674e-01, 7.79788219e-01, 6.26009784e-01, 3.35249497e-01, - 1.43159543e-02, 1.80363779e-01, 5.05096904e-01, 2.82619947e-01, 5.83561392e-01, 3.10951324e-01, - 8.73223968e-01, 4.38545619e-01, 4.81348800e-01, 6.68497085e-01, 3.79345401e-01, 9.58832501e-01, - 1.89869550e-01, 2.34083070e-01, 2.94066207e-01, 5.74892667e-02, 6.92106828e-02, 9.61127686e-02, - 6.72650672e-02, 8.47345378e-01, 2.80916761e-01, 7.32177357e-03, 9.80785961e-01, 5.73192225e-02, - 8.48781331e-01, 8.83225408e-01, 7.34398275e-01, 7.70381941e-01, 6.20778343e-01, 8.96822048e-01, - 5.40732486e-01, 3.69704071e-01, 5.77305837e-01, 2.08221827e-01, 7.34275341e-01, 1.06110900e-01, - 3.49496706e-01, 8.34948910e-01, 1.56403291e-02, 6.78576376e-01, 8.96141268e-01, 5.94835119e-01, - 1.43943153e-01, 3.49618530e-01, 2.10440392e-01, 3.46585620e-01, 1.05153093e-01, 3.45446174e-01, - 2.72177079e-01, 7.07946300e-01, 4.33717726e-02, 3.31232203e-01, 3.91874320e-01, 4.76338141e-01, - 6.22777789e-01, 2.95989228e-02, 4.32855769e-01, 7.61049310e-01, 3.63279149e-01, 9.47210350e-01, - 6.43721247e-01, 6.58025802e-01, 1.05247633e-02, 5.29974442e-01, 7.30675767e-01, 4.30041079e-01, - 6.62634841e-01, 8.25936616e-01, 9.91253704e-01, 6.79399281e-01, 5.44177006e-01, 7.52876048e-01, - 3.32139049e-01, 7.98732398e-01, 7.38865223e-01, 9.16055132e-01, 6.11736493e-01, 9.63672879e-01, - 1.83778839e-01, 7.27558919e-02, 5.91602822e-01, 3.25235484e-01, 2.34741217e-01, 9.52346277e-01, - 9.18556407e-01, 9.35373324e-01, 6.89209070e-01, 2.56049054e-01, 6.17975395e-01, 7.82285691e-01, - 9.84983432e-01, 6.62322741e-01, 2.04144457e-01, 3.98446577e-01, 1.38918297e-01, 3.05919921e-01, - 3.14043787e-01, 5.91072666e-01, 7.44703771e-01, 8.92272567e-01, 9.78017873e-01, 9.01203161e-01, - 1.41526372e-01, 4.14878484e-01, 6.80683651e-01, 5.01733152e-02, 8.14635389e-01, 2.27926375e-01, - 9.03269815e-01, 8.68443745e-01, 9.86939190e-01, 7.40779486e-01, 2.61005311e-01, 3.19276232e-01, - 9.69509248e-01, 1.11908818e-01, 4.49198556e-01, 1.27056715e-01, 3.84064823e-01, 5.14591811e-01, - 2.10747488e-01, 9.53884090e-01, 8.43167950e-01, 4.51187972e-01, 3.75331782e-01, 6.23566461e-01, - 3.55290379e-01, 2.95705968e-01, 1.69622690e-01, 1.42981830e-01, 2.72180991e-01, 9.46468040e-01, - 3.70932500e-01, 9.94292830e-01, 4.62587505e-01, 7.14817405e-01, 2.45370540e-02, 3.00906377e-01, - 5.75768304e-01, 9.71448393e-01, 6.95574827e-02, 3.93693854e-01, 5.29306116e-01, 5.04694554e-01, - 6.73797120e-02, 6.76596969e-01, 5.50948898e-01, 3.24909641e-01, 7.70337719e-01, 6.51842631e-03, - 3.03264879e-01, 7.61037886e-03, 2.72289601e-01, 1.50502041e-01, 6.71103888e-02, 7.41503703e-01, - 1.92088941e-01, 2.19043977e-01, 9.09320161e-01, 2.37993569e-01, 6.18107973e-02, 8.31447852e-01, - 2.23355609e-01, 1.84789435e-01, 4.16104518e-01, 4.21573859e-01, 8.72446305e-02, 2.97294197e-01, - 4.50328256e-01, 8.72199917e-01, 2.51279916e-01, 4.86219272e-01, 7.57071329e-01, 4.85655942e-01, - 1.06187277e-01, 4.92341327e-01, 1.46017513e-01, 5.25421017e-01, 4.22637906e-01, 2.24685018e-01, - 8.72648431e-01, 5.54051490e-01, 1.80745062e-01, 2.12756336e-01, 5.20883169e-01, 7.60363654e-01, - 8.30254678e-01, 5.00003328e-01, 4.69017439e-01, 6.38105527e-01, 3.50638261e-02, 5.22217353e-02, - 9.06516882e-02, 8.52975842e-01, 1.19985883e-01, 3.74926753e-01, 6.50302066e-01, 1.98875727e-01, - 6.28362507e-02, 4.32693501e-01, 3.10500685e-01, 6.20732833e-01, 4.58503272e-01, 3.20790034e-01, - 7.91284868e-01, 7.93054570e-01, 2.93406765e-01, 8.95399023e-01, 1.06441034e-01, 7.53085241e-02, - 8.67523104e-01, 1.47963482e-01, 1.25584706e-01, 3.81545040e-02, 6.34338619e-01, 1.76368938e-02, - 5.75553531e-02, 5.31607516e-01, 2.63869588e-01, 9.41945823e-01, 9.24028838e-02, 5.21496463e-01, - 7.74866558e-01, 5.65210610e-01, 7.28015327e-02, 6.51963790e-01, 8.94727453e-01, 4.49571590e-01, - 1.29932405e-01, 8.64026259e-01, 9.92599934e-01, 7.43721560e-01, 8.87300215e-01, 1.06369925e-01, - 8.11335531e-01, 7.87734900e-01, 9.87344678e-01, 5.32502820e-01, 4.42612382e-01, 9.64041183e-01, - 1.66085871e-01, 1.12937664e-01, 5.24423470e-01, 6.54689333e-01, 4.59119726e-01, 5.22774091e-01, - 3.08722276e-02, 6.26979315e-01, 4.49754105e-01, 8.07495757e-01, 2.34199499e-01, 1.67765675e-01, - 9.22168418e-01, 3.73210378e-01, 8.04432575e-01, 5.61890354e-01, 4.47025593e-01, 6.43155678e-01, - 2.40407640e-01, 5.91631279e-01, 1.59369206e-01, 7.75799090e-01, 8.32067212e-01, 5.59791576e-02, - 6.39105224e-01, 4.85274738e-01, 2.12630838e-01, 2.81431312e-02, 7.16205363e-01, 6.83885011e-01, - 5.23869697e-01, 9.99418314e-01, 8.35331599e-01, 4.69877463e-02, 6.74712562e-01, 7.99273684e-01, - 2.77001890e-02, 5.75809742e-01, 2.78513031e-01, 8.36209905e-01, 7.25472379e-01, 4.87173943e-01, - 7.88311357e-01, 9.64676177e-01, 1.75752651e-01, 4.98112580e-01, 8.08850418e-02, 6.40981131e-01, - 4.06647450e-01, 8.46539387e-01, 2.12620694e-01, 9.11012851e-01, 8.25041445e-01, 8.90065575e-01, - 9.63626055e-01, 5.96689242e-01, 1.63372670e-01, 4.51640148e-01, 3.43026542e-01, 5.80658851e-01, - 2.82327625e-01, 4.75535418e-01, 6.27760926e-01, 8.46314115e-01, 9.61961932e-01, 3.19806094e-01, - 5.05508062e-01, 5.28102944e-01, 6.13045057e-01, 7.44714938e-01, 1.50586073e-01, 7.91878033e-01, - 4.89839179e-01, 3.10496849e-01, 8.82309038e-01, 2.86922314e-01, 4.84687559e-01, 5.20838630e-01, - 4.62955493e-01, 2.38185305e-01, 5.47259907e-02, 7.10916137e-01, 7.31887202e-01, 6.25602317e-01, - 8.77741168e-01, 4.19881322e-01, 4.81222328e-01, 1.28224501e-01, 2.46034010e-01, 3.34971854e-01, - 7.37216484e-01, 5.62134821e-02, 7.14089724e-01, 9.85549393e-01, 4.66295827e-01, 3.08722434e-03, - 4.70237690e-01, 2.66524167e-01, 7.93875484e-01, 4.54795911e-02, 8.09702944e-01, 1.47709735e-02, - 1.70082405e-01, 6.35905179e-01, 3.75379109e-01, 4.30315011e-01, 3.15788760e-01, 5.58065230e-01, - 2.24643800e-01, 2.42142981e-01, 6.57283636e-01, 3.34921891e-01, 1.26588975e-01, 7.68064155e-01, - 9.43856291e-01, 4.47518596e-01, 5.44453573e-01, 9.95764932e-01, 7.16444391e-01, 8.51019765e-01, - 1.01179183e-01, 4.45473958e-01, 4.60327322e-01, 4.96895844e-02, 4.72907738e-01, 5.58987444e-01, - 3.41027487e-01, 1.56175026e-01, 7.58283148e-01, 6.83600909e-01, 2.14623396e-01, 3.27348880e-01, - 3.92517893e-01, 6.70418431e-01, 5.16440832e-01, 8.63140348e-01, 5.73277464e-01, 3.46608058e-01, - 7.39396341e-01, 7.20852434e-01, 2.35653246e-02, 3.89935659e-01, 7.53783745e-01, 6.34563528e-01, - 8.79339335e-01, 7.41599159e-02, 5.62433904e-01, 6.15553852e-01, 4.56956324e-01, 5.20047447e-01, - 5.26845015e-02, 5.58471266e-01, 1.63632233e-01, 5.38936665e-02, 6.49593683e-01, 2.56838748e-01, - 8.99035326e-01, 7.20847756e-01, 5.68954684e-01, 7.43684755e-01, 5.70924238e-01, 3.82318724e-01, - 4.89328290e-01, 5.62208561e-01, 4.97540804e-02, 4.18011085e-01, 6.88041565e-01, 2.16234653e-01, - 7.89548214e-01, 8.46136387e-01, 8.46816189e-01, 1.73842353e-01, 6.11627842e-02, 8.44440559e-01, - 4.50646654e-01, 3.74785037e-01, 4.87196697e-01, 4.56276448e-01, 9.13284391e-01, 4.15715464e-01, - 7.13597697e-01, 1.23641270e-02, 5.10031271e-01, 4.74601930e-02, 2.55731159e-01, 3.22090006e-01, - 1.91165703e-01, 4.51170940e-01, 7.50843157e-01, 4.42420576e-01, 4.25380660e-01, 4.50667257e-01, - 6.55689206e-01, 9.68257670e-02, 1.96528793e-01, 8.97343028e-01, 4.99940904e-01, 6.65504083e-01, - 9.41828079e-01, 4.54397338e-01, 5.61893331e-01, 5.09839880e-01, 4.53117514e-01, 8.96804127e-02, - 1.74888861e-01, 6.65641378e-01, 2.81668336e-01, 1.89532742e-01, 5.61668382e-01, 8.68330157e-02, - 8.25092797e-01, 5.18106324e-01, 1.71904024e-01, 3.68385523e-01, 1.62005436e-01, 7.48507399e-01, - 9.30274827e-01, 2.38198517e-01, 9.52222901e-01, 5.23587800e-01, 6.94384557e-01, 1.09338652e-01, - 4.83356794e-01, 2.73050402e-01, 3.68027050e-01, 5.92366466e-01, 1.83192289e-01, 8.60376029e-01, - 7.13926203e-01, 8.16750052e-01, 1.57890291e-01, 6.25691951e-01, 5.24831646e-01, 1.73873797e-01, - 1.02429784e-01, 9.17488471e-01, 4.03584434e-01, 9.31170884e-01, 2.79386137e-01, 8.77745206e-01, - 2.45200576e-01, 1.28896951e-01, 3.15713052e-01, 5.27874291e-01, 2.16444335e-01, 7.03883817e-01, - 7.74738919e-02, 8.42422142e-01, 3.75598924e-01, 3.51002411e-01, 6.22752776e-01, 4.82407943e-01, - 7.43107867e-01, 9.46182666e-01, 9.44344819e-01, 3.28124763e-01, 1.06147431e-01, 1.65102684e-01, - 3.84060507e-01, 2.91057722e-01, 7.68173662e-02, 1.03543651e-01, 6.76698940e-01, 1.43141994e-01, - 7.21342202e-01, 6.69471294e-03, 9.07298311e-01, 5.57080171e-01, 8.10954489e-01, 4.11120526e-01, - 2.06407453e-01, 2.59590556e-01, 7.58512718e-01, 5.79873897e-01, 2.92875650e-01, 2.83686529e-01, - 2.42829343e-01, 9.19323719e-01, 3.46832864e-01, 3.58238858e-01, 7.42827585e-01, 2.05760059e-01, - 9.58438860e-01, 5.66326411e-01, 6.60292846e-01, 5.61095078e-02, 6.79465531e-01, 7.05118513e-01, - 4.44713264e-01, 2.09732933e-01, 5.22732436e-01, 1.74396512e-01, 5.29356748e-01, 4.38475687e-01, - 4.94036404e-01, 4.09785794e-01, 6.40025507e-01, 5.79371821e-01, 1.57726118e-01, 6.04572263e-01, - 5.41072639e-01, 5.18847173e-01, 1.97093284e-01, 8.91767002e-01, 4.29050835e-01, 8.25490570e-01, - 3.87699807e-01, 4.50705808e-01, 2.49371643e-01, 3.36074898e-01, 9.29925118e-01, 6.65393649e-01, - 9.07275994e-01, 3.73075859e-01, 4.14044139e-03, 2.37463702e-01, 2.25893784e-01, 2.46900245e-01, - 4.50350196e-01, 3.48618117e-01, 5.07193932e-01, 5.23435142e-01, 8.13611417e-01, 8.92715622e-01, - 1.02623450e-01, 3.06088345e-01, 7.80461650e-01, 2.21453645e-01, 2.01419652e-01, 2.84254457e-01, - 3.68286735e-01, 7.39358243e-01, 8.97879394e-01, 9.81599566e-01, 7.56526442e-01, 7.37645545e-01, - 4.23976657e-02, 8.25922012e-01, 2.60956996e-01, 2.90702065e-01, 8.98388344e-01, 3.03733299e-01, - 8.49071471e-01, 3.45835425e-01, 7.65458276e-01, 5.68094872e-01, 8.93770930e-01, 9.93161641e-01, - 5.63368667e-02, 4.26548945e-01, 5.46745780e-01, 5.75674571e-01, 7.94599487e-01, 7.18935553e-02, - 4.46492976e-01, 6.40240123e-01, 2.73246969e-01, 2.00465968e-01, 1.30718835e-01, 1.92492005e-01, - 1.96617189e-01, 6.61271644e-01, 8.12687657e-01, 8.66342445e-01 - - }, - -4}}; - -typedef ConnectComponentsTest ConnectComponentsTestF_Int; -TEST_P(ConnectComponentsTestF_Int, Result) -{ - /** - * Verify the src & dst vertices on each edge have different colors - */ - EXPECT_TRUE(final_edges == params.n_row - 1); -} - -INSTANTIATE_TEST_CASE_P(ConnectComponentsTest, - ConnectComponentsTestF_Int, - ::testing::ValuesIn(fix_conn_inputsf2)); -}; // namespace sparse -}; // end namespace raft diff --git a/cpp/test/sparse/neighbors/cross_component_nn.cu b/cpp/test/sparse/neighbors/cross_component_nn.cu new file mode 100644 index 0000000000..7cadf25e88 --- /dev/null +++ b/cpp/test/sparse/neighbors/cross_component_nn.cu @@ -0,0 +1,1036 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// XXX: We allow the instantiation of masked_l2_nn here: +// raft::linkage::FixConnectivitiesRedOp red_op(params.n_row); +// raft::linkage::cross_component_nn( +// handle, out_edges, data.data(), colors.data(), params.n_row, params.n_col, red_op); +// +// TODO: consider adding this to libraft.so or creating an instance in a +// separate translation unit for this test. +// +// TODO: edge case testing. Reference: https://github.com/rapidsai/raft/issues/1669 + +#include +#include + +#include + +#include +#include +#include + +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../../test_utils.cuh" + +namespace raft { +namespace sparse { + +using namespace std; + +template +struct ConnectComponentsInputs { + value_idx n_row; + value_idx n_col; + std::vector data; + + int c; +}; + +template +class ConnectComponentsTest + : public ::testing::TestWithParam> { + protected: + void basicTest() + { + raft::resources handle; + + auto stream = resource::get_cuda_stream(handle); + + params = ::testing::TestWithParam>::GetParam(); + + raft::sparse::COO out_edges(resource::get_cuda_stream(handle)); + raft::sparse::COO out_edges_batched(resource::get_cuda_stream(handle)); + + rmm::device_uvector data(params.n_row * params.n_col, + resource::get_cuda_stream(handle)); + + raft::copy(data.data(), params.data.data(), data.size(), resource::get_cuda_stream(handle)); + + rmm::device_uvector indptr(params.n_row + 1, stream); + + /** + * 1. Construct knn graph + */ + raft::sparse::COO knn_graph_coo(stream); + + raft::sparse::neighbors::knn_graph(handle, + data.data(), + params.n_row, + params.n_col, + raft::distance::DistanceType::L2SqrtExpanded, + knn_graph_coo, + params.c); + + raft::sparse::convert::sorted_coo_to_csr( + knn_graph_coo.rows(), knn_graph_coo.nnz, indptr.data(), params.n_row + 1, stream); + + /** + * 2. Construct MST, sorted by weights + */ + rmm::device_uvector colors(params.n_row, stream); + + auto mst_coo = raft::mst::mst(handle, + indptr.data(), + knn_graph_coo.cols(), + knn_graph_coo.vals(), + params.n_row, + knn_graph_coo.nnz, + colors.data(), + stream, + false, + true); + + /** + * 3. cross_component_nn to fix connectivities + */ + raft::linkage::FixConnectivitiesRedOp red_op(params.n_row); + raft::linkage::cross_component_nn(handle, + out_edges, + data.data(), + colors.data(), + params.n_row, + params.n_col, + red_op, + params.n_row, + params.n_col); + + raft::linkage::cross_component_nn(handle, + out_edges_batched, + data.data(), + colors.data(), + params.n_row, + params.n_col, + red_op, + params.n_row / 2, + params.n_col / 2); + + ASSERT_TRUE(out_edges.nnz == out_edges_batched.nnz); + + ASSERT_TRUE( + devArrMatch(out_edges.rows(), out_edges_batched.rows(), out_edges.nnz, Compare())); + + ASSERT_TRUE( + devArrMatch(out_edges.cols(), out_edges_batched.cols(), out_edges.nnz, Compare())); + + ASSERT_TRUE(devArrMatch( + out_edges.vals(), out_edges_batched.vals(), out_edges.nnz, CompareApprox(1e-4))); + + /** + * Construct final edge list + */ + rmm::device_uvector indptr2(params.n_row + 1, stream); + + raft::sparse::convert::sorted_coo_to_csr( + out_edges.rows(), out_edges.nnz, indptr2.data(), params.n_row + 1, stream); + + auto output_mst = raft::mst::mst(handle, + indptr2.data(), + out_edges.cols(), + out_edges.vals(), + params.n_row, + out_edges.nnz, + colors.data(), + stream, + false, + false); + + resource::sync_stream(handle, stream); + + // The sum of edges for both MST runs should be n_rows - 1 + final_edges = output_mst.n_edges + mst_coo.n_edges; + } + + void SetUp() override { basicTest(); } + + void TearDown() override {} + + protected: + ConnectComponentsInputs params; + + value_idx final_edges; +}; + +const std::vector> fix_conn_inputsf2 = { + // Test n_clusters == n_points + {10, + 5, + {0.21390334, 0.50261639, 0.91036676, 0.59166485, 0.71162682, 0.10248392, 0.77782677, 0.43772379, + 0.4035871, 0.3282796, 0.47544681, 0.59862974, 0.12319357, 0.06239463, 0.28200272, 0.1345717, + 0.50498218, 0.5113505, 0.16233086, 0.62165332, 0.42281548, 0.933117, 0.41386077, 0.23264562, + 0.73325968, 0.37537541, 0.70719873, 0.14522645, 0.73279625, 0.9126674, 0.84854131, 0.28890216, + 0.85267903, 0.74703138, 0.83842071, 0.34942792, 0.27864171, 0.70911132, 0.21338564, 0.32035554, + 0.73788331, 0.46926692, 0.57570162, 0.42559178, 0.87120209, 0.22734951, 0.01847905, 0.75549396, + 0.76166195, 0.66613745}, + -1}, + // Test n_points == 100 + {100, + 10, + {6.26168372e-01, 9.30437651e-01, 6.02450208e-01, 2.73025296e-01, 9.53050619e-01, 3.32164396e-01, + 6.88942598e-01, 5.79163537e-01, 6.70341547e-01, 2.70140602e-02, 9.30429671e-01, 7.17721157e-01, + 9.89948537e-01, 7.75253347e-01, 1.34491522e-02, 2.48522428e-02, 3.51413378e-01, 7.64405834e-01, + 7.86373507e-01, 7.18748577e-01, 8.66998621e-01, 6.80316582e-01, 2.51288712e-01, 4.91078420e-01, + 3.76246281e-01, 4.86828710e-01, 5.67464772e-01, 5.30734742e-01, 8.99478296e-01, 7.66699088e-01, + 9.49339111e-01, 3.55248484e-01, 9.06046929e-01, 4.48407772e-01, 6.96395305e-01, 2.44277335e-01, + 7.74840000e-01, 5.21046603e-01, 4.66423971e-02, 5.12019638e-02, 8.95019614e-01, 5.28956953e-01, + 4.31536306e-01, 5.83857744e-01, 4.41787364e-01, 4.68656523e-01, 5.73971433e-01, 6.79989654e-01, + 3.19650588e-01, 6.12579596e-01, 6.49126442e-02, 8.39131142e-01, 2.85252117e-01, 5.84848929e-01, + 9.46507115e-01, 8.58440748e-01, 3.61528940e-01, 2.44215959e-01, 3.80101125e-01, 4.57128957e-02, + 8.82216988e-01, 8.31498633e-01, 7.23474381e-01, 7.75788607e-01, 1.40864146e-01, 6.62092382e-01, + 5.13985168e-01, 3.00686418e-01, 8.70109949e-01, 2.43187753e-01, 2.89391938e-01, 2.84214238e-01, + 8.70985521e-01, 8.77491176e-01, 6.72537226e-01, 3.30929686e-01, 1.85934324e-01, 9.16222614e-01, + 6.18239142e-01, 2.64768597e-01, 5.76145451e-01, 8.62961369e-01, 6.84757925e-01, 7.60549082e-01, + 1.27645356e-01, 4.51004673e-01, 3.92292980e-01, 4.63170803e-01, 4.35449330e-02, 2.17583404e-01, + 5.71832605e-02, 2.06763039e-01, 3.70116249e-01, 2.09750028e-01, 6.17283019e-01, 8.62549231e-01, + 9.84156240e-02, 2.66249156e-01, 3.87635103e-01, 2.85591012e-02, 4.24826068e-01, 4.45795088e-01, + 6.86227676e-01, 1.08848960e-01, 5.96731841e-02, 3.71770228e-01, 1.91548833e-01, 6.95136078e-01, + 9.00700636e-01, 8.76363105e-01, 2.67334632e-01, 1.80619709e-01, 7.94060419e-01, 1.42854171e-02, + 1.09372387e-01, 8.74028108e-01, 6.46403232e-01, 4.86588834e-01, 5.93446175e-02, 6.11886291e-01, + 8.83865057e-01, 3.15879821e-01, 2.27043992e-01, 9.76764951e-01, 6.15620336e-01, 9.76199360e-01, + 2.40548962e-01, 3.21795663e-01, 8.75087904e-02, 8.11234663e-01, 6.96070480e-01, 8.12062321e-01, + 1.21958818e-01, 3.44348628e-02, 8.72630414e-01, 3.06162776e-01, 1.76043529e-02, 9.45894971e-01, + 5.33896401e-01, 6.21642973e-01, 4.93062535e-01, 4.48984262e-01, 2.24560379e-01, 4.24052195e-02, + 4.43447610e-01, 8.95646149e-01, 6.05220676e-01, 1.81840491e-01, 9.70831206e-01, 2.12563586e-02, + 6.92582693e-01, 7.55946922e-01, 7.95086143e-01, 6.05328941e-01, 3.99350764e-01, 4.32846636e-01, + 9.81114529e-01, 4.98266428e-01, 6.37127930e-03, 1.59085889e-01, 6.34682067e-05, 5.59429440e-01, + 7.38827633e-01, 8.93214770e-01, 2.16494306e-01, 9.35430573e-02, 4.75665868e-02, 7.80503518e-01, + 7.86240041e-01, 7.06854594e-01, 2.13725879e-02, 7.68246091e-01, 4.50234808e-01, 5.21231104e-01, + 5.01989826e-03, 4.22081572e-02, 1.65337732e-01, 8.54134740e-01, 4.99430262e-01, 8.94525601e-01, + 1.14028379e-01, 3.69739861e-01, 1.32955599e-01, 2.65563824e-01, 2.52811151e-01, 1.44792843e-01, + 6.88449594e-01, 4.44921417e-01, 8.23296587e-01, 1.93266317e-01, 1.19033309e-01, 1.36368966e-01, + 3.42600285e-01, 5.64505195e-01, 5.57594559e-01, 7.44257892e-01, 8.38231569e-02, 4.11548847e-01, + 3.21010077e-01, 8.55081359e-01, 4.30105779e-01, 1.16229135e-01, 9.87731964e-02, 3.14712335e-01, + 4.50880592e-01, 2.72289598e-01, 6.31615256e-01, 8.97432958e-01, 4.44764250e-01, 8.03776440e-01, + 2.68767748e-02, 2.43374608e-01, 4.02141103e-01, 4.98881209e-01, 5.33173003e-01, 8.82890436e-01, + 7.16149148e-01, 4.19664401e-01, 2.29335357e-01, 2.88637806e-01, 3.44696803e-01, 6.78171906e-01, + 5.69849716e-01, 5.86454477e-01, 3.54474989e-01, 9.03876540e-01, 6.45980000e-01, 6.34887593e-01, + 7.88039746e-02, 2.04814126e-01, 7.82251754e-01, 2.43147074e-01, 7.50951808e-01, 1.72799092e-02, + 2.95349590e-01, 6.57991826e-01, 8.81214312e-01, 5.73970708e-01, 2.77610881e-01, 1.82155097e-01, + 7.69797417e-02, 6.44792402e-01, 9.46950998e-01, 7.73064845e-01, 6.04733624e-01, 5.80094567e-01, + 1.67498426e-01, 2.66514296e-01, 6.50140368e-01, 1.91170299e-01, 2.08752199e-01, 3.01664091e-01, + 9.85033484e-01, 2.92909152e-01, 8.65816607e-01, 1.85222119e-01, 2.28814559e-01, 1.34286382e-02, + 2.89234322e-01, 8.18668708e-01, 4.71706924e-01, 9.23199803e-01, 2.80879188e-01, 1.47319284e-01, + 4.13915748e-01, 9.31274932e-02, 6.66322195e-01, 9.66953974e-01, 3.19405786e-01, 6.69486551e-01, + 5.03096313e-02, 6.95225201e-01, 5.78469859e-01, 6.29481655e-01, 1.39252534e-01, 1.22564968e-01, + 6.80663678e-01, 6.34607157e-01, 6.42765834e-01, 1.57127410e-02, 2.92132086e-01, 5.24423878e-01, + 4.68676824e-01, 2.86003928e-01, 7.18608322e-01, 8.95617933e-01, 5.48844309e-01, 1.74517278e-01, + 5.24379196e-01, 2.13526524e-01, 5.88375435e-01, 9.88560185e-01, 4.17435771e-01, 6.14438688e-01, + 9.53760881e-01, 5.27151288e-01, 7.03017278e-01, 3.44448559e-01, 4.47059676e-01, 2.83414901e-01, + 1.98979011e-01, 4.24917361e-01, 5.73172761e-01, 2.32398853e-02, 1.65887230e-01, 4.05552785e-01, + 9.29665524e-01, 2.26135696e-01, 9.20563384e-01, 7.65259963e-01, 4.54820075e-01, 8.97710267e-01, + 3.78559302e-03, 9.15219382e-01, 3.55705698e-01, 6.94905124e-01, 8.58540202e-01, 3.89790666e-01, + 2.49478206e-01, 7.93679304e-01, 4.75830027e-01, 4.40425353e-01, 3.70579459e-01, 1.40578049e-01, + 1.70386675e-01, 7.04056121e-01, 4.85963102e-01, 9.68450060e-01, 6.77178001e-01, 2.65934654e-01, + 2.58915007e-01, 6.70052890e-01, 2.61945109e-01, 8.46207759e-01, 1.01928951e-01, 2.85611334e-01, + 2.45776933e-01, 2.66658783e-01, 3.71724077e-01, 4.34319025e-01, 4.24407347e-01, 7.15417683e-01, + 8.07997684e-01, 1.64296275e-01, 6.01638065e-01, 8.60606804e-02, 2.68719187e-01, 5.11764101e-01, + 9.75844338e-01, 7.81226782e-01, 2.20925515e-01, 7.18135040e-01, 9.82395577e-01, 8.39160243e-01, + 9.08058083e-01, 6.88010677e-01, 8.14271847e-01, 5.12460821e-01, 1.17311345e-01, 5.96075228e-01, + 9.17455497e-01, 2.12052706e-01, 7.04074603e-01, 8.72872565e-02, 8.76047818e-01, 6.96235046e-01, + 8.54801557e-01, 2.49729159e-01, 9.76594604e-01, 2.87386363e-01, 2.36461559e-02, 9.94075254e-01, + 4.25193986e-01, 7.61869994e-01, 5.13334255e-01, 6.44711165e-02, 8.92156689e-01, 3.55235167e-01, + 1.08154647e-01, 8.78446825e-01, 2.43833016e-01, 9.23071293e-01, 2.72724115e-01, 9.46631338e-01, + 3.74510294e-01, 4.08451278e-02, 9.78392777e-01, 3.65079221e-01, 6.37199516e-01, 5.51144906e-01, + 5.25978080e-01, 1.42803678e-01, 4.05451674e-01, 7.79788219e-01, 6.26009784e-01, 3.35249497e-01, + 1.43159543e-02, 1.80363779e-01, 5.05096904e-01, 2.82619947e-01, 5.83561392e-01, 3.10951324e-01, + 8.73223968e-01, 4.38545619e-01, 4.81348800e-01, 6.68497085e-01, 3.79345401e-01, 9.58832501e-01, + 1.89869550e-01, 2.34083070e-01, 2.94066207e-01, 5.74892667e-02, 6.92106828e-02, 9.61127686e-02, + 6.72650672e-02, 8.47345378e-01, 2.80916761e-01, 7.32177357e-03, 9.80785961e-01, 5.73192225e-02, + 8.48781331e-01, 8.83225408e-01, 7.34398275e-01, 7.70381941e-01, 6.20778343e-01, 8.96822048e-01, + 5.40732486e-01, 3.69704071e-01, 5.77305837e-01, 2.08221827e-01, 7.34275341e-01, 1.06110900e-01, + 3.49496706e-01, 8.34948910e-01, 1.56403291e-02, 6.78576376e-01, 8.96141268e-01, 5.94835119e-01, + 1.43943153e-01, 3.49618530e-01, 2.10440392e-01, 3.46585620e-01, 1.05153093e-01, 3.45446174e-01, + 2.72177079e-01, 7.07946300e-01, 4.33717726e-02, 3.31232203e-01, 3.91874320e-01, 4.76338141e-01, + 6.22777789e-01, 2.95989228e-02, 4.32855769e-01, 7.61049310e-01, 3.63279149e-01, 9.47210350e-01, + 6.43721247e-01, 6.58025802e-01, 1.05247633e-02, 5.29974442e-01, 7.30675767e-01, 4.30041079e-01, + 6.62634841e-01, 8.25936616e-01, 9.91253704e-01, 6.79399281e-01, 5.44177006e-01, 7.52876048e-01, + 3.32139049e-01, 7.98732398e-01, 7.38865223e-01, 9.16055132e-01, 6.11736493e-01, 9.63672879e-01, + 1.83778839e-01, 7.27558919e-02, 5.91602822e-01, 3.25235484e-01, 2.34741217e-01, 9.52346277e-01, + 9.18556407e-01, 9.35373324e-01, 6.89209070e-01, 2.56049054e-01, 6.17975395e-01, 7.82285691e-01, + 9.84983432e-01, 6.62322741e-01, 2.04144457e-01, 3.98446577e-01, 1.38918297e-01, 3.05919921e-01, + 3.14043787e-01, 5.91072666e-01, 7.44703771e-01, 8.92272567e-01, 9.78017873e-01, 9.01203161e-01, + 1.41526372e-01, 4.14878484e-01, 6.80683651e-01, 5.01733152e-02, 8.14635389e-01, 2.27926375e-01, + 9.03269815e-01, 8.68443745e-01, 9.86939190e-01, 7.40779486e-01, 2.61005311e-01, 3.19276232e-01, + 9.69509248e-01, 1.11908818e-01, 4.49198556e-01, 1.27056715e-01, 3.84064823e-01, 5.14591811e-01, + 2.10747488e-01, 9.53884090e-01, 8.43167950e-01, 4.51187972e-01, 3.75331782e-01, 6.23566461e-01, + 3.55290379e-01, 2.95705968e-01, 1.69622690e-01, 1.42981830e-01, 2.72180991e-01, 9.46468040e-01, + 3.70932500e-01, 9.94292830e-01, 4.62587505e-01, 7.14817405e-01, 2.45370540e-02, 3.00906377e-01, + 5.75768304e-01, 9.71448393e-01, 6.95574827e-02, 3.93693854e-01, 5.29306116e-01, 5.04694554e-01, + 6.73797120e-02, 6.76596969e-01, 5.50948898e-01, 3.24909641e-01, 7.70337719e-01, 6.51842631e-03, + 3.03264879e-01, 7.61037886e-03, 2.72289601e-01, 1.50502041e-01, 6.71103888e-02, 7.41503703e-01, + 1.92088941e-01, 2.19043977e-01, 9.09320161e-01, 2.37993569e-01, 6.18107973e-02, 8.31447852e-01, + 2.23355609e-01, 1.84789435e-01, 4.16104518e-01, 4.21573859e-01, 8.72446305e-02, 2.97294197e-01, + 4.50328256e-01, 8.72199917e-01, 2.51279916e-01, 4.86219272e-01, 7.57071329e-01, 4.85655942e-01, + 1.06187277e-01, 4.92341327e-01, 1.46017513e-01, 5.25421017e-01, 4.22637906e-01, 2.24685018e-01, + 8.72648431e-01, 5.54051490e-01, 1.80745062e-01, 2.12756336e-01, 5.20883169e-01, 7.60363654e-01, + 8.30254678e-01, 5.00003328e-01, 4.69017439e-01, 6.38105527e-01, 3.50638261e-02, 5.22217353e-02, + 9.06516882e-02, 8.52975842e-01, 1.19985883e-01, 3.74926753e-01, 6.50302066e-01, 1.98875727e-01, + 6.28362507e-02, 4.32693501e-01, 3.10500685e-01, 6.20732833e-01, 4.58503272e-01, 3.20790034e-01, + 7.91284868e-01, 7.93054570e-01, 2.93406765e-01, 8.95399023e-01, 1.06441034e-01, 7.53085241e-02, + 8.67523104e-01, 1.47963482e-01, 1.25584706e-01, 3.81545040e-02, 6.34338619e-01, 1.76368938e-02, + 5.75553531e-02, 5.31607516e-01, 2.63869588e-01, 9.41945823e-01, 9.24028838e-02, 5.21496463e-01, + 7.74866558e-01, 5.65210610e-01, 7.28015327e-02, 6.51963790e-01, 8.94727453e-01, 4.49571590e-01, + 1.29932405e-01, 8.64026259e-01, 9.92599934e-01, 7.43721560e-01, 8.87300215e-01, 1.06369925e-01, + 8.11335531e-01, 7.87734900e-01, 9.87344678e-01, 5.32502820e-01, 4.42612382e-01, 9.64041183e-01, + 1.66085871e-01, 1.12937664e-01, 5.24423470e-01, 6.54689333e-01, 4.59119726e-01, 5.22774091e-01, + 3.08722276e-02, 6.26979315e-01, 4.49754105e-01, 8.07495757e-01, 2.34199499e-01, 1.67765675e-01, + 9.22168418e-01, 3.73210378e-01, 8.04432575e-01, 5.61890354e-01, 4.47025593e-01, 6.43155678e-01, + 2.40407640e-01, 5.91631279e-01, 1.59369206e-01, 7.75799090e-01, 8.32067212e-01, 5.59791576e-02, + 6.39105224e-01, 4.85274738e-01, 2.12630838e-01, 2.81431312e-02, 7.16205363e-01, 6.83885011e-01, + 5.23869697e-01, 9.99418314e-01, 8.35331599e-01, 4.69877463e-02, 6.74712562e-01, 7.99273684e-01, + 2.77001890e-02, 5.75809742e-01, 2.78513031e-01, 8.36209905e-01, 7.25472379e-01, 4.87173943e-01, + 7.88311357e-01, 9.64676177e-01, 1.75752651e-01, 4.98112580e-01, 8.08850418e-02, 6.40981131e-01, + 4.06647450e-01, 8.46539387e-01, 2.12620694e-01, 9.11012851e-01, 8.25041445e-01, 8.90065575e-01, + 9.63626055e-01, 5.96689242e-01, 1.63372670e-01, 4.51640148e-01, 3.43026542e-01, 5.80658851e-01, + 2.82327625e-01, 4.75535418e-01, 6.27760926e-01, 8.46314115e-01, 9.61961932e-01, 3.19806094e-01, + 5.05508062e-01, 5.28102944e-01, 6.13045057e-01, 7.44714938e-01, 1.50586073e-01, 7.91878033e-01, + 4.89839179e-01, 3.10496849e-01, 8.82309038e-01, 2.86922314e-01, 4.84687559e-01, 5.20838630e-01, + 4.62955493e-01, 2.38185305e-01, 5.47259907e-02, 7.10916137e-01, 7.31887202e-01, 6.25602317e-01, + 8.77741168e-01, 4.19881322e-01, 4.81222328e-01, 1.28224501e-01, 2.46034010e-01, 3.34971854e-01, + 7.37216484e-01, 5.62134821e-02, 7.14089724e-01, 9.85549393e-01, 4.66295827e-01, 3.08722434e-03, + 4.70237690e-01, 2.66524167e-01, 7.93875484e-01, 4.54795911e-02, 8.09702944e-01, 1.47709735e-02, + 1.70082405e-01, 6.35905179e-01, 3.75379109e-01, 4.30315011e-01, 3.15788760e-01, 5.58065230e-01, + 2.24643800e-01, 2.42142981e-01, 6.57283636e-01, 3.34921891e-01, 1.26588975e-01, 7.68064155e-01, + 9.43856291e-01, 4.47518596e-01, 5.44453573e-01, 9.95764932e-01, 7.16444391e-01, 8.51019765e-01, + 1.01179183e-01, 4.45473958e-01, 4.60327322e-01, 4.96895844e-02, 4.72907738e-01, 5.58987444e-01, + 3.41027487e-01, 1.56175026e-01, 7.58283148e-01, 6.83600909e-01, 2.14623396e-01, 3.27348880e-01, + 3.92517893e-01, 6.70418431e-01, 5.16440832e-01, 8.63140348e-01, 5.73277464e-01, 3.46608058e-01, + 7.39396341e-01, 7.20852434e-01, 2.35653246e-02, 3.89935659e-01, 7.53783745e-01, 6.34563528e-01, + 8.79339335e-01, 7.41599159e-02, 5.62433904e-01, 6.15553852e-01, 4.56956324e-01, 5.20047447e-01, + 5.26845015e-02, 5.58471266e-01, 1.63632233e-01, 5.38936665e-02, 6.49593683e-01, 2.56838748e-01, + 8.99035326e-01, 7.20847756e-01, 5.68954684e-01, 7.43684755e-01, 5.70924238e-01, 3.82318724e-01, + 4.89328290e-01, 5.62208561e-01, 4.97540804e-02, 4.18011085e-01, 6.88041565e-01, 2.16234653e-01, + 7.89548214e-01, 8.46136387e-01, 8.46816189e-01, 1.73842353e-01, 6.11627842e-02, 8.44440559e-01, + 4.50646654e-01, 3.74785037e-01, 4.87196697e-01, 4.56276448e-01, 9.13284391e-01, 4.15715464e-01, + 7.13597697e-01, 1.23641270e-02, 5.10031271e-01, 4.74601930e-02, 2.55731159e-01, 3.22090006e-01, + 1.91165703e-01, 4.51170940e-01, 7.50843157e-01, 4.42420576e-01, 4.25380660e-01, 4.50667257e-01, + 6.55689206e-01, 9.68257670e-02, 1.96528793e-01, 8.97343028e-01, 4.99940904e-01, 6.65504083e-01, + 9.41828079e-01, 4.54397338e-01, 5.61893331e-01, 5.09839880e-01, 4.53117514e-01, 8.96804127e-02, + 1.74888861e-01, 6.65641378e-01, 2.81668336e-01, 1.89532742e-01, 5.61668382e-01, 8.68330157e-02, + 8.25092797e-01, 5.18106324e-01, 1.71904024e-01, 3.68385523e-01, 1.62005436e-01, 7.48507399e-01, + 9.30274827e-01, 2.38198517e-01, 9.52222901e-01, 5.23587800e-01, 6.94384557e-01, 1.09338652e-01, + 4.83356794e-01, 2.73050402e-01, 3.68027050e-01, 5.92366466e-01, 1.83192289e-01, 8.60376029e-01, + 7.13926203e-01, 8.16750052e-01, 1.57890291e-01, 6.25691951e-01, 5.24831646e-01, 1.73873797e-01, + 1.02429784e-01, 9.17488471e-01, 4.03584434e-01, 9.31170884e-01, 2.79386137e-01, 8.77745206e-01, + 2.45200576e-01, 1.28896951e-01, 3.15713052e-01, 5.27874291e-01, 2.16444335e-01, 7.03883817e-01, + 7.74738919e-02, 8.42422142e-01, 3.75598924e-01, 3.51002411e-01, 6.22752776e-01, 4.82407943e-01, + 7.43107867e-01, 9.46182666e-01, 9.44344819e-01, 3.28124763e-01, 1.06147431e-01, 1.65102684e-01, + 3.84060507e-01, 2.91057722e-01, 7.68173662e-02, 1.03543651e-01, 6.76698940e-01, 1.43141994e-01, + 7.21342202e-01, 6.69471294e-03, 9.07298311e-01, 5.57080171e-01, 8.10954489e-01, 4.11120526e-01, + 2.06407453e-01, 2.59590556e-01, 7.58512718e-01, 5.79873897e-01, 2.92875650e-01, 2.83686529e-01, + 2.42829343e-01, 9.19323719e-01, 3.46832864e-01, 3.58238858e-01, 7.42827585e-01, 2.05760059e-01, + 9.58438860e-01, 5.66326411e-01, 6.60292846e-01, 5.61095078e-02, 6.79465531e-01, 7.05118513e-01, + 4.44713264e-01, 2.09732933e-01, 5.22732436e-01, 1.74396512e-01, 5.29356748e-01, 4.38475687e-01, + 4.94036404e-01, 4.09785794e-01, 6.40025507e-01, 5.79371821e-01, 1.57726118e-01, 6.04572263e-01, + 5.41072639e-01, 5.18847173e-01, 1.97093284e-01, 8.91767002e-01, 4.29050835e-01, 8.25490570e-01, + 3.87699807e-01, 4.50705808e-01, 2.49371643e-01, 3.36074898e-01, 9.29925118e-01, 6.65393649e-01, + 9.07275994e-01, 3.73075859e-01, 4.14044139e-03, 2.37463702e-01, 2.25893784e-01, 2.46900245e-01, + 4.50350196e-01, 3.48618117e-01, 5.07193932e-01, 5.23435142e-01, 8.13611417e-01, 8.92715622e-01, + 1.02623450e-01, 3.06088345e-01, 7.80461650e-01, 2.21453645e-01, 2.01419652e-01, 2.84254457e-01, + 3.68286735e-01, 7.39358243e-01, 8.97879394e-01, 9.81599566e-01, 7.56526442e-01, 7.37645545e-01, + 4.23976657e-02, 8.25922012e-01, 2.60956996e-01, 2.90702065e-01, 8.98388344e-01, 3.03733299e-01, + 8.49071471e-01, 3.45835425e-01, 7.65458276e-01, 5.68094872e-01, 8.93770930e-01, 9.93161641e-01, + 5.63368667e-02, 4.26548945e-01, 5.46745780e-01, 5.75674571e-01, 7.94599487e-01, 7.18935553e-02, + 4.46492976e-01, 6.40240123e-01, 2.73246969e-01, 2.00465968e-01, 1.30718835e-01, 1.92492005e-01, + 1.96617189e-01, 6.61271644e-01, 8.12687657e-01, 8.66342445e-01 + + }, + -4}}; + +typedef ConnectComponentsTest ConnectComponentsTestF_Int; +TEST_P(ConnectComponentsTestF_Int, Result) +{ + /** + * Verify the src & dst vertices on each edge have different colors + */ + EXPECT_TRUE(final_edges == params.n_row - 1); +} + +INSTANTIATE_TEST_CASE_P(ConnectComponentsTest, + ConnectComponentsTestF_Int, + ::testing::ValuesIn(fix_conn_inputsf2)); + +template +struct MutualReachabilityFixConnectivitiesRedOp { + value_t* core_dists; + value_idx m; + + DI MutualReachabilityFixConnectivitiesRedOp() : m(0) {} + + MutualReachabilityFixConnectivitiesRedOp(value_t* core_dists_, value_idx m_) + : core_dists(core_dists_), m(m_){}; + + typedef typename raft::KeyValuePair KVP; + DI void operator()(value_idx rit, KVP* out, const KVP& other) const + { + if (rit < m && other.value < std::numeric_limits::max()) { + value_t core_dist_rit = core_dists[rit]; + value_t core_dist_other = max(core_dist_rit, max(core_dists[other.key], other.value)); + + value_t core_dist_out; + if (out->key > -1) { + core_dist_out = max(core_dist_rit, max(core_dists[out->key], out->value)); + } else { + core_dist_out = out->value; + } + + bool smaller = core_dist_other < core_dist_out; + out->key = smaller ? other.key : out->key; + out->value = smaller ? core_dist_other : core_dist_out; + } + } + + DI KVP operator()(value_idx rit, const KVP& a, const KVP& b) const + { + if (rit < m && a.key > -1) { + value_t core_dist_rit = core_dists[rit]; + value_t core_dist_a = max(core_dist_rit, max(core_dists[a.key], a.value)); + + value_t core_dist_b; + if (b.key > -1) { + core_dist_b = max(core_dist_rit, max(core_dists[b.key], b.value)); + } else { + core_dist_b = b.value; + } + + return core_dist_a < core_dist_b ? KVP(a.key, core_dist_a) : KVP(b.key, core_dist_b); + } + + return b; + } + + DI void init(value_t* out, value_t maxVal) const { *out = maxVal; } + DI void init(KVP* out, value_t maxVal) const + { + out->key = -1; + out->value = maxVal; + } + + DI void init_key(value_t& out, value_idx idx) const { return; } + DI void init_key(KVP& out, value_idx idx) const { out.key = idx; } + + DI value_t get_value(KVP& out) const { return out.value; } + DI value_t get_value(value_t& out) const { return out; } + + void gather(const raft::resources& handle, value_idx* map) + { + auto tmp_core_dists = raft::make_device_vector(handle, m); + thrust::gather(raft::resource::get_thrust_policy(handle), + map, + map + m, + core_dists, + tmp_core_dists.data_handle()); + raft::copy_async( + core_dists, tmp_core_dists.data_handle(), m, raft::resource::get_cuda_stream(handle)); + } + + void scatter(const raft::resources& handle, value_idx* map) + { + auto tmp_core_dists = raft::make_device_vector(handle, m); + thrust::scatter(raft::resource::get_thrust_policy(handle), + core_dists, + core_dists + m, + map, + tmp_core_dists.data_handle()); + raft::copy_async( + core_dists, tmp_core_dists.data_handle(), m, raft::resource::get_cuda_stream(handle)); + } +}; + +template +struct ConnectComponentsMutualReachabilityInputs { + value_idx n_row; + value_idx n_col; + std::vector data; + std::vector core_dists; + std::vector colors; + std::vector expected_rows; + std::vector expected_cols; + std::vector expected_vals; +}; + +template +class ConnectComponentsEdgesTest + : public ::testing::TestWithParam> { + protected: + void basicTest() + { + raft::resources handle; + + auto stream = resource::get_cuda_stream(handle); + + params = ::testing::TestWithParam< + ConnectComponentsMutualReachabilityInputs>::GetParam(); + + raft::sparse::COO out_edges_unbatched(resource::get_cuda_stream(handle)); + raft::sparse::COO out_edges_batched(resource::get_cuda_stream(handle)); + + rmm::device_uvector data(params.n_row * params.n_col, + resource::get_cuda_stream(handle)); + rmm::device_uvector core_dists(params.n_row, resource::get_cuda_stream(handle)); + rmm::device_uvector colors(params.n_row, resource::get_cuda_stream(handle)); + + raft::copy(data.data(), params.data.data(), data.size(), resource::get_cuda_stream(handle)); + raft::copy(core_dists.data(), + params.core_dists.data(), + core_dists.size(), + resource::get_cuda_stream(handle)); + raft::copy( + colors.data(), params.colors.data(), colors.size(), resource::get_cuda_stream(handle)); + + /** + * 3. cross_component_nn to fix connectivities + */ + MutualReachabilityFixConnectivitiesRedOp red_op(core_dists.data(), + params.n_row); + + raft::linkage::cross_component_nn(handle, + out_edges_unbatched, + data.data(), + colors.data(), + params.n_row, + params.n_col, + red_op, + params.n_row, + params.n_col); + + raft::linkage::cross_component_nn(handle, + out_edges_batched, + data.data(), + colors.data(), + params.n_row, + params.n_col, + red_op, + 11, + 1); + + ASSERT_TRUE(out_edges_unbatched.nnz == out_edges_batched.nnz && + out_edges_unbatched.nnz == params.expected_rows.size()); + + ASSERT_TRUE(devArrMatch(out_edges_unbatched.rows(), + params.expected_rows.data(), + out_edges_unbatched.nnz, + Compare())); + + ASSERT_TRUE(devArrMatch(out_edges_unbatched.cols(), + params.expected_cols.data(), + out_edges_unbatched.nnz, + Compare())); + + ASSERT_TRUE(devArrMatch(out_edges_unbatched.vals(), + params.expected_vals.data(), + out_edges_unbatched.nnz, + CompareApprox(1e-4))); + + ASSERT_TRUE(devArrMatch(out_edges_batched.rows(), + params.expected_rows.data(), + out_edges_batched.nnz, + Compare())); + + ASSERT_TRUE(devArrMatch(out_edges_batched.cols(), + params.expected_cols.data(), + out_edges_batched.nnz, + Compare())); + + ASSERT_TRUE(devArrMatch(out_edges_batched.vals(), + params.expected_vals.data(), + out_edges_batched.nnz, + CompareApprox(1e-4))); + } + + void SetUp() override { basicTest(); } + + void TearDown() override {} + + protected: + ConnectComponentsMutualReachabilityInputs params; +}; + +const std::vector> mr_fix_conn_inputsf2 = { + {100, + 2, + {-7.72642, -8.39496, 5.4534, 0.742305, -2.97867, 9.55685, 6.04267, 0.571319, -6.52184, + -6.31932, 3.64934, 1.40687, -2.17793, 9.98983, 4.42021, 2.33028, 4.73696, 2.94181, + -3.66019, 9.38998, -3.05358, 9.12521, -6.65217, -5.57297, -6.35769, -6.58313, -3.61553, + 7.81808, -1.77073, 9.18565, -7.95052, -6.39764, -6.60294, -6.05293, -2.58121, 10.0178, + -7.76348, -6.72638, -6.40639, -6.95294, -2.97262, 8.54856, -6.95673, -6.53896, -7.32614, + -6.02371, -2.1478, 10.5523, -2.54502, 10.5789, -2.96984, 10.0714, 3.22451, 1.55252, + -6.25396, -7.73727, -7.85431, -6.09303, -8.11658, -8.20057, -7.55965, -6.64786, 4.936, + 2.23423, 4.44752, 2.27472, -5.72103, -7.70079, -0.929985, 9.78172, -3.10984, 8.72259, + -2.44167, 7.58954, -2.18511, 8.6292, 5.55528, 2.30192, 4.73164, -0.0143992, -8.2573, + -7.81793, -2.98837, 8.82863, 4.60517, 0.804492, -3.83738, 9.21115, -2.62485, 8.71318, + 3.57758, 2.44676, -8.48711, -6.69548, -6.70645, -6.49479, -6.86663, -5.42658, 3.83139, + 1.47141, 2.02013, 2.79507, 4.64499, 1.73858, -1.69667, 10.3705, -6.61974, -6.09829, + -6.05757, -4.98332, -7.10309, -6.16611, -3.52203, 9.32853, -2.26724, 7.10101, 6.11777, + 1.4549, -4.23412, 8.452, -6.58655, -7.59446, 3.93783, 1.64551, -7.12502, -7.63385, + 2.72111, 1.94666, -7.14428, -4.15994, -6.66553, -8.12585, 4.70011, 4.43641, -7.76914, + -7.69592, 4.11012, 2.48644, 4.89743, 1.89872, 4.29716, 1.17089, -6.62913, -6.53366, + -8.07093, -6.22356, -2.16558, 7.25125, 4.73953, 1.46969, -5.91625, -6.46733, 5.43091, + 1.06378, -6.82142, -8.02308, 6.52606, 2.14775, 3.08922, 2.04173, -2.14756, 8.36917, + 3.85663, 1.65111, -1.68665, 7.79344, -5.01385, -6.40628, -2.52269, 7.95658, -2.30033, + 7.05462, -1.04355, 8.78851, 3.72045, 3.5231, -3.98772, 8.29444, 4.24777, 0.509655, + 4.72693, 1.67416, 5.7827, 2.7251, -3.41722, 7.60198, 5.22674, 4.16363, -3.1109, + 10.8666, -3.18612, 9.62596, -1.4782, 9.94557, 4.47859, 2.37722, -5.79658, -5.82631, + -3.34842, 8.70507}, + {0.978428, 1.01917, 0.608673, 1.45629, 0.310713, 0.689461, 0.701126, 0.63296, 0.774788, + 0.701648, 0.513282, 0.757651, 0.45638, 0.973111, 0.901396, 0.613692, 0.482497, 0.688143, + 0.72428, 0.666345, 0.58232, 0.554756, 0.710315, 0.903611, 0.694115, 0.796099, 0.639759, + 0.798998, 0.639839, 1.30727, 0.663729, 0.57476, 0.571348, 1.14662, 1.26518, 0.485068, + 0.78207, 0.791621, 1.01678, 1.28509, 1.14715, 0.381395, 0.850507, 0.788511, 0.588341, + 0.878516, 0.928669, 0.405874, 0.776421, 0.612274, 1.84963, 0.57476, 0.95226, 0.488078, + 1.24868, 0.515136, 0.589378, 0.903632, 1.01678, 1.09964, 0.666345, 0.713265, 0.877168, + 1.10053, 1.96887, 1.03574, 2.03728, 0.969553, 0.774788, 0.586338, 0.65168, 0.435472, + 0.664396, 0.790584, 0.678637, 0.715964, 0.865494, 0.978428, 1.59242, 0.861109, 0.833259, + 0.65168, 0.903632, 1.49599, 0.76347, 0.960453, 1.1848, 1.37398, 0.928957, 1.07848, + 0.661798, 1.21104, 1.04579, 1.89047, 1.24288, 0.529553, 0.903611, 0.620897, 0.882467, + 0.647189}, + {0, 1, 2, 1, 0, 1, 2, 1, 1, 2, 2, 0, 0, 2, 2, 0, 0, 2, 0, 0, 2, 0, 0, 2, 2, + 2, 1, 0, 0, 0, 0, 1, 1, 0, 2, 2, 2, 2, 1, 1, 0, 2, 1, 2, 2, 1, 0, 0, 0, 1, + 1, 1, 2, 0, 0, 0, 2, 2, 1, 2, 0, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 0, 0, 2, 1, + 0, 1, 0, 1, 1, 2, 1, 2, 0, 2, 2, 2, 1, 2, 1, 1, 1, 2, 1, 2, 2, 2, 1, 0, 2}, + {50, 54, 57, 63, 82, 87}, + {57, 63, 50, 54, 87, 82}, + {6.0764, 11.1843, 6.0764, 11.1843, 6.89004, 6.89004}}, + {1000, + 2, + {-6.59634, -7.13901, -6.13753, -6.58082, 5.19821, 2.04918, -2.96856, 8.16444, + -2.76879, 7.51114, -6.82261, -6.61152, 5.02008, 2.58376, 5.55621, 2.31966, + 4.86379, 3.33731, 5.84639, 1.15623, -2.17159, 8.60241, -4.97844, -6.94077, + -2.31014, 8.41407, 5.5582, 0.402669, 5.25265, 0.919754, 5.85298, 2.11489, + -3.29245, 8.69222, -1.9621, 8.81209, -1.53408, 8.86723, -2.18227, 8.79519, + 4.60519, 2.20738, -6.4759, -6.9043, -7.18766, -6.10045, -9.00148, -7.48793, + 4.01674, 1.41769, -2.45347, 10.1085, -3.20892, 9.22827, -3.18612, 9.62596, + 4.81977, 3.36517, 4.90693, 2.8628, -6.44269, -5.68946, -8.30144, -5.37878, + 4.61485, 2.79094, -1.98726, 9.31127, -3.66019, 9.38998, -6.58607, -8.23669, + -7.46015, -6.29153, 4.08468, 3.85433, -6.36842, -5.50645, -6.83602, -5.18506, + -0.627173, 10.3597, 3.98846, 1.48928, -2.9968, 8.58173, -7.2144, -7.28376, + -0.660242, 10.1409, -4.23528, -8.38308, -3.15984, 8.52716, -2.40987, 9.76567, + -8.7548, -6.76508, 4.56971, 0.312209, -7.5487, -5.8402, -1.6096, 9.32159, + 5.04813, 0.270586, -7.6525, -6.47306, -1.79758, 7.88964, -9.0153, -3.74236, + -3.5715, 9.48788, -1.65154, 8.85435, -3.47412, 9.70034, 6.31245, 2.39219, + 4.03851, 2.29295, -3.17098, 9.86672, -6.90693, -7.81338, -6.22373, -6.68537, + -3.22204, 9.12072, -0.365254, 9.6482, -7.76712, -7.31757, 4.15669, 3.54716, + 4.1937, 0.083629, -3.03896, 9.52755, -6.29293, -7.35501, -2.95926, 9.63714, + 4.02709, 1.58547, 4.56828, 1.93595, 5.6242, 1.75918, -7.36237, -7.83344, + 5.32177, 3.81988, -2.43183, 8.153, -1.97939, 10.4559, -3.49492, 9.51833, + 3.39602, 1.28026, -2.42215, 8.71528, -3.57682, 8.87191, -2.77385, 11.7345, + 5.71351, 0.946654, -6.50253, -6.90937, 4.08239, 0.603367, -5.64134, -6.85884, + -2.76177, 7.7665, -2.25165, 8.93984, -3.49071, 9.47639, -1.06792, 7.57842, + 5.15754, 1.24743, 3.63574, 1.20537, -6.07969, -8.49642, 4.12227, 2.19696, + -7.17144, -8.4433, -1.92234, 11.2047, 3.23237, 1.19535, 3.85389, 0.641937, + 4.82665, 1.21779, -7.68923, -6.45605, -7.00816, -8.76196, -5.12894, 9.83619, + -5.66247, -5.35879, 3.05598, 2.73358, 6.06038, 1.40242, -1.69568, 7.78342, + 5.13391, 2.23384, -2.96984, 10.0714, -5.36618, -6.2493, 5.55896, 1.6829, + 3.55882, 2.58911, 5.36155, 0.844118, -0.0634456, 9.14351, 4.88368, 1.40909, + -7.04675, -6.59753, -7.78333, -6.55575, 5.39881, 2.25436, -2.85189, 8.64285, + -2.22821, 8.39159, 3.88591, 1.69249, -7.55481, -7.02463, 4.60032, 2.65467, + -6.90615, -7.76198, -6.76005, -7.85318, 4.15044, 3.01733, -7.18884, -7.63227, + 4.68874, 2.01376, 3.51716, 2.35558, -3.81367, 9.68396, 4.42644, 3.4639, + 4.81758, 0.637825, -6.20705, -4.98023, -1.68603, 9.0876, -4.99504, -5.33687, + -1.77073, 9.18565, 4.86433, 3.02027, 4.20538, 1.664, 4.59042, 2.64799, + -3.09856, 9.86389, -3.02306, 7.95507, -6.32402, -6.79053, -7.67205, -7.18807, + -8.10918, -6.38341, -1.67979, 6.80315, 4.00249, 3.16219, -2.54391, 7.84561, + -3.22764, 8.80084, -2.63712, 8.05875, -2.41744, 7.02672, -6.71117, -5.56251, + 5.18348, 1.60256, -7.40824, -6.29375, -4.22233, 10.3682, 4.8509, 1.87646, + -2.99456, 9.09616, 5.1332, 2.15801, -2.27358, 9.78515, -6.73874, -8.64855, + 4.96124, 2.39509, -3.70949, 8.67978, -4.13674, 9.06237, 2.80367, 2.48116, + -0.876786, 7.58414, -3.7005, 9.67084, 6.48652, 0.903085, 6.28189, 2.98299, + -6.07922, -6.12582, -5.67921, -7.537, 4.55014, 3.41329, -1.63688, 9.19763, + -4.02439, 10.3812, 5.23053, 3.08187, -2.2951, 7.76855, -6.24491, -5.77041, + 6.02415, 2.53708, -6.91286, -7.08823, 4.83193, 1.66405, -7.07454, -5.74634, + -2.09576, 10.8911, 3.29543, 1.05452, -3.49973, 8.44799, 5.2922, 0.396778, + -2.54502, 10.5789, -6.38865, -6.14523, -1.75221, 8.09212, -9.30387, -5.99606, + -2.98113, 10.1032, -6.2017, -7.36802, 4.63628, 0.814805, -1.81905, 8.61307, + 4.88926, 3.55062, 3.08325, 2.57918, -2.51717, 10.4942, -5.75358, -6.9315, + 6.36742, 2.40949, 5.74806, 0.933264, 4.74408, 1.91058, -7.41496, -6.97064, + -2.98414, 8.36096, 6.72825, 1.83358, -2.95349, 9.39159, -3.35599, 7.49944, + 6.18738, 3.76905, -3.17182, 9.58488, 5.17863, 1.0525, -3.0397, 8.43847, + -2.23874, 8.96405, 3.04689, 2.41364, 6.14064, 2.82339, -6.33334, -6.87369, + -7.92444, -8.84647, 3.65129, 0.86958, 5.29842, 3.98337, -2.06538, 9.78892, + -6.89494, -6.30082, -2.52144, 8.11703, -8.11398, -7.47257, 5.3381, 2.36666, + -6.93452, -6.59456, -7.50634, -6.01772, 6.23438, 1.12621, -2.15218, 8.32138, + -7.04777, -7.3522, -2.52771, 8.72563, -2.77907, 8.03552, 4.29123, 1.62391, + -8.07551, -6.43551, -3.28202, 8.77747, -2.21308, 9.27534, -8.25153, -8.49367, + -3.54644, 8.82395, -8.05867, -5.69243, 4.46681, 1.98875, 3.8362, 3.61229, + -6.96231, -7.00186, 5.18993, 1.00483, -5.35116, -6.37227, 5.23298, 1.66362, + -5.68306, -7.03864, -9.03144, -7.59926, -6.10127, -7.4313, 4.83572, 0.994797, + -7.32695, -5.59909, 0.569683, 10.1339, 3.35957, 2.84563, -2.4122, 9.60944, + 5.00855, 1.57983, -2.57528, 7.80327, 3.96349, 3.77411, 4.59429, 2.21651, + -6.54765, -6.68961, 4.76798, 1.29212, -1.67351, 7.88458, 5.63615, 1.47941, + -2.5301, 9.13161, 4.26075, 1.76959, 4.67788, 2.0932, 4.39955, 1.59835, + 3.91274, 1.72565, -4.1786, 9.55765, -7.34566, -8.47481, 4.8364, 2.68217, + -7.36848, -7.99973, -5.84708, -5.7534, 5.37252, 1.89245, -2.1707, 8.599, + -1.3299, 9.0818, -6.79122, -5.40258, 5.56391, 1.78827, -0.194539, 7.14702, + 4.60489, 3.74397, 5.50995, 2.46885, -3.98772, 8.29444, -5.21837, -7.33721, + -1.63959, 10.3699, -5.92932, -5.1695, -5.88358, -7.6369, 4.11716, 3.02218, + -6.54114, -7.17551, 3.97179, 2.96521, -6.75325, -4.94118, 5.26169, 0.402945, + 3.25031, 0.327771, -0.44845, 10.7696, -2.15141, 9.57507, 7.04329, 1.91555, + -3.74615, 7.69383, -7.52318, -5.85015, -6.80419, -8.48208, -4.57664, 8.92517, + 4.57574, 2.30193, 4.84098, 3.02382, -9.43355, -5.94579, -3.52203, 9.32853, + 3.43018, 2.5731, -6.15725, -7.25294, -6.69861, -8.17694, -2.40955, 8.51081, + -4.82342, -7.98332, -7.10611, -6.51274, 5.86755, 0.763529, -6.56045, -5.53966, + -3.61553, 7.81808, 4.3825, 0.304586, -6.52818, -5.80996, 4.59972, 0.542395, + -6.90603, -6.59995, -6.3585, -6.23489, -6.01915, -7.46319, -5.38694, -7.15123, + -7.83475, -6.45651, 5.89564, 1.07856, -5.15266, -7.27975, -6.97978, -7.08378, + 5.83493, 0.449983, -2.62374, 10.2521, -7.34494, -6.98606, -6.79719, -8.33766, + 3.54757, 1.65676, -8.40528, -5.61753, -5.85556, -6.28758, 4.66862, 3.25162, + -6.26047, -4.82261, 4.61552, 4.11544, -1.36637, 9.76622, 4.2517, 2.14359, + -2.45099, 7.87132, -0.376164, 7.0622, 4.34493, 3.22091, 6.95921, 2.36649, + -6.70319, -7.24714, -5.56932, -5.48443, -7.43149, -4.32191, -3.23956, 9.23074, + -5.77255, -7.00049, 4.96601, 0.722056, -7.88617, -5.74023, 4.18757, -0.45071, + -7.12569, -7.72336, 5.27366, 2.38697, 3.93487, 1.9174, 3.19186, -0.225636, + -3.41722, 7.60198, -3.08286, 8.46743, -5.87905, -7.55073, -5.26425, -7.20243, + -2.97867, 9.55685, -1.23153, 8.42272, -2.33602, 9.3996, -3.33819, 8.45411, + -3.58009, 9.49676, 3.78152, 2.67348, -1.54582, 9.42707, -4.04331, 10.292, + 3.3452, 3.134, -2.75494, 8.74156, -3.26555, 7.59203, -7.27139, -7.80252, + 3.5293, 3.72544, 6.11642, 3.35326, 4.01611, 3.8872, 4.89591, 2.95586, + -7.06677, -5.89438, 4.19438, 3.42655, -6.11355, -5.65318, -7.59645, -8.74665, + -5.80362, -6.8588, 3.80453, 4.11832, 5.70655, 3.14247, -4.98084, 8.21739, + -1.87642, 11.285, 4.39864, 2.32523, -3.48388, 9.80137, 4.02836, 0.566509, + -2.41212, 9.98293, -5.40846, -7.08943, 4.01506, 1.99926, -3.43613, 8.95476, + -7.24458, -7.71932, 6.02204, 2.62188, -6.29999, -6.55431, 6.19038, 0.974816, + 3.55882, 3.02632, -7.06011, -3.687, -1.55877, 8.43738, -5.14711, -4.64881, + 4.7167, 0.690177, -7.90381, -5.02602, 4.17218, 2.31967, -0.643423, 9.48812, + -7.95237, -6.64086, -4.05986, 9.08285, -6.24158, -6.37927, -6.6105, -7.2233, + -6.21675, -5.70664, -3.29967, 9.48575, 3.41775, 2.68617, -2.24948, 8.10997, + -2.24931, 9.79611, -9.0523, -6.03269, -2.2587, 9.36073, 5.20965, 2.42088, + -3.10159, 8.1503, -6.67906, -5.73147, 4.0687, 2.54575, -1.24229, 8.30662, + -2.09627, 8.45056, -7.87801, -6.57832, 4.72216, 3.03865, -0.929985, 9.78172, + -8.56307, -7.68598, -7.05257, -5.1684, -7.09076, -7.86729, 4.61432, 3.1459, + -6.34133, -5.8076, -3.82943, 10.8457, -8.46082, -5.98507, 5.34763, 1.4107, + -1.68714, 10.9111, -1.67886, 8.1582, -0.623012, 9.18886, -4.21258, 8.95874, + -2.16744, 10.8905, -6.57158, -7.27176, 2.14047, 4.26411, -8.44217, -7.40916, + 5.29008, 1.87399, 4.31824, 4.04992, -3.77008, 9.93215, -2.72688, 10.1131, + -6.14278, -7.16144, -3.92457, 8.59364, -5.92649, -6.59299, 4.68369, 1.82617, + -6.89905, -7.18329, 3.95173, 4.22561, -7.66453, -6.23183, -2.44167, 7.58954, + -6.36603, -7.41281, -6.45081, -6.187, -6.6125, -6.37138, 5.46036, 2.48044, + -2.14756, 8.36917, -2.3889, 9.52872, 3.80752, 2.44459, -3.98778, 10.158, + -6.63887, -4.27843, -8.65266, -5.61819, -7.97003, -5.46918, -5.9604, -7.54825, + -0.916011, 8.50307, -3.69246, 6.97505, -7.98533, -7.09503, -2.30033, 7.05462, + 4.76218, 2.51647, -7.04981, -7.33334, 3.66401, 3.02681, -2.50408, 8.7797, + 7.19996, 1.87711, 4.01291, 3.78562, -0.356015, 8.24694, -0.958046, 9.12996, + 4.60675, 3.76773, 6.21945, 1.45031, 4.27744, 0.8535, -4.72232, -7.48582, + 6.03923, 2.8978, -3.26833, 9.16468, -7.97059, -7.29092, -2.3998, 9.74005, + -2.66721, 8.58741, -7.36269, -6.73332, -7.87893, -7.38488, 4.65023, 0.661333, + -4.8171, -7.94764, -4.11564, 9.21775, 4.80633, 2.46562, -2.72887, 9.3714, + -5.26735, -5.5652, 4.9826, 2.42992, -6.17018, -7.3156, 4.38084, 1.77682, + 5.35084, 2.41743, -2.61796, 9.416, 5.27229, 2.94572, -7.52315, -5.95227, + -1.45077, 7.25555, -3.79916, 7.71921, -2.23251, 9.84147, 3.70054, 1.82908, + -1.93831, 10.1499, -6.18324, -5.9248, -3.33142, 9.25797, -6.08536, -8.1344, + 5.95727, 2.17077, 4.87366, 0.417274, -6.529, -6.39092, -9.24256, -7.88984, + -6.36652, -7.13966, -3.90777, 9.57726, -7.06252, -5.50523, -2.26423, 8.50734, + -2.84498, 10.6833, 5.0391, 2.62037, -2.74815, 8.10672, 3.35945, 3.72796, + -4.11668, 9.19892, 5.66903, 2.44577, -1.63807, 8.68826, -7.42587, -6.48831, + 6.17063, 3.19193, -2.28511, 9.02688, -7.10088, -7.15692, 4.46293, 1.17487, + -5.91017, -6.45292, -2.26724, 7.10101, -2.43339, 8.33712, -4.63309, 8.48853, + -3.31769, 8.51253, -2.49078, 10.6907, -1.30798, 8.60621, 6.30535, 2.98754, + -5.79384, -6.78213, -1.93213, 8.81124, 4.55773, 3.09047, 6.37584, 2.17108, + 4.3927, 1.29119, -3.2245, 9.69388, -1.69634, 9.64392, 2.799, 0.693593, + -2.1426, 8.07441, -8.4505, -8.00688, 4.736, 1.51089, -2.5863, 9.35544, + -2.94924, 9.14503, 6.2054, 1.90742, 5.67172, 0.487609, -5.69071, -6.17181, + -8.24651, -7.10488, -7.34424, -6.67895, -6.71977, -7.90778, -1.82294, 7.40157, + -9.40991, -7.16611, -4.37999, 8.66277, -1.42615, 10.0681, -2.00828, 8.03673, + -7.50228, -6.6855, -5.65859, -6.29801, -8.02335, -6.77155, -3.40761, 9.50621, + -2.82447, 9.77326, -1.5938, 9.34304, -3.5213, 7.35943, -3.36961, 8.62973, + -7.01708, -5.92724, 5.20886, 3.60157, -1.71817, 8.1049, -2.46363, 8.36269, + -2.77809, 7.90776, -2.75459, 8.26055, -2.03596, 8.94146, -4.53434, 9.20074, + -7.44387, -6.69556, -6.90099, -7.62732, 3.29169, 2.71643, 6.08686, 2.16972, + -2.31111, 8.86993, -5.75046, 7.9899, 4.69951, 1.32623, 4.71851, -0.025031, + -6.42374, -4.71511, -8.04974, -8.68209, -3.16103, 9.06168, -6.18267, -7.21393, + -7.94202, -6.4518, -7.07697, -7.03138, 3.93554, 0.564708, -1.20372, 9.03529, + -7.10611, -7.83955, -7.47529, -5.50567, -6.15453, -6.36393, -2.98024, 9.24634, + -7.75761, -7.70699, -3.08597, 9.76968, -8.04954, -9.75237, 5.2534, 0.950377, + 5.63789, -0.923086, -5.7065, -6.51047, -8.02132, -7.07377, -8.28594, -6.96322, + -7.70722, -6.79397, -2.4962, 10.4678, 5.02846, 4.46617, 4.02648, 1.6707, + -0.319395, 8.20599, 4.74525, 0.639144, -1.0313, 8.49602, 4.08766, 2.6061, + 3.63826, 1.69207, 2.55795, 3.66963, 5.2826, 3.30232, -1.04355, 8.78851, + -6.84762, -7.63353, -4.70868, -7.056, 3.53651, -0.179721, -3.38482, 7.63149, + -5.9265, -6.36702, -0.986074, 9.5532, -2.42261, 8.85861, -7.42835, -6.78726, + -4.02857, 8.53005, -8.22675, -7.85172, -5.57529, -8.5426, 6.03009, 2.53098, + -7.10448, -7.53011, -3.4988, 8.8885, -2.62485, 8.71318, -6.39489, -7.72647, + 3.93789, 1.31027, 4.27627, 1.91622, -0.923181, 7.77647, -5.16017, 10.1058, + -6.44307, -5.97617, -7.24495, -6.69543, 6.27331, 0.826824, -6.55655, -7.13246, + 5.66245, 4.41292, -2.13805, 8.4103, 5.23463, 2.82659, -4.86624, -6.74357, + -6.14082, -6.26474, -2.67048, 9.41834, -1.26311, 6.9409, -7.20231, -7.13094, + -1.35109, 9.80595, 3.9906, 0.749229, -6.75696, -5.25543, 4.84826, -0.0685652, + -7.4914, -6.91715, 4.46725, 2.85683, -2.95571, 9.87068, 6.32381, 1.51429, + -6.81177, -6.02734, -2.57188, 9.96943, -4.28792, 10.5103, 3.65025, 2.91394, + -7.11856, -7.24693, -6.98693, -6.43239, 4.7651, 1.54376, 4.00092, 0.65008, + -7.14816, -7.7713, -7.58803, -8.39382, 4.3321, 2.19232, -7.89545, -6.81843, + -2.11475, 8.5933, -0.743743, 9.41927, 3.64849, -0.18022, -1.68665, 7.79344, + 4.00214, 1.44217, -6.96799, -7.25012, -1.58302, 10.9237, -6.68524, -7.23328, + 4.65831, 2.32075, 4.62024, 2.52566, -4.23412, 8.452, -0.822056, 9.89593, + -7.19868, -7.67614, -3.32742, 11.1067, 5.27861, 0.830165, 4.48982, 2.09875, + -6.58087, -7.6319, -0.880582, 7.63418, -7.01088, -6.80326, -7.31601, -6.98972, + -6.85883, -7.60811, 6.14328, 2.85053, -7.49206, -6.51861, -2.28174, 10.3214, + 4.81074, 1.78919, -5.58987, -6.20693, 4.08096, 2.35038, -1.5029, 8.43739, + 4.11536, 2.46254, -3.28299, 7.76963, 4.31953, 2.39734, 4.91146, 0.696421, + -1.4782, 9.94557, -3.34842, 8.70507, -6.97822, -6.86126, 4.10012, 1.19486, + -2.50395, 9.06127, 4.41891, 2.00006, -2.73266, 9.72829, 3.5436, 0.533119, + 5.78864, 0.233456, -6.62589, -6.41242, -2.21942, 11.0897, -6.76636, -8.31839, + -2.71732, 8.52129, -5.20972, -6.48544, 3.26056, 1.24224, 3.45228, 2.28299, + 4.72171, 1.87428, -7.52585, -5.1048, 5.0695, 2.18086, -6.55646, -7.02771, + 3.23727, 3.72275, 3.41411, 0.508795, -7.80698, -6.64174, -5.90443, -6.37902, + -0.387041, 10.0468, -1.3506, 8.1936, -6.08614, -8.62864, -5.91478, -5.26453, + -2.61623, 7.97904, 4.45459, 1.84335, -6.66643, -7.63208, 3.6729, 1.92546, + -1.32976, 8.54511, 6.31758, 1.41958, 4.63381, 2.81166, -7.01394, -6.0693, + -2.7786, 9.73183, -2.90131, 7.55077, -7.13842, -5.28146, 6.71514, 1.28398, + -6.98408, -7.04893, -3.03946, 8.22141, -2.76417, 10.5183, -7.35347, -6.89456, + 4.19345, 2.16726, -2.02819, 9.23817, 4.97076, 2.8067, -0.544473, 9.04955, + 4.90727, 2.29487, -6.31871, -7.17559, 3.71665, 0.621485, 4.7903, 2.33813, + -6.47994, -7.53147, -6.80958, -5.71823, -8.07326, -5.96096, 4.77342, 1.8207, + 5.71856, 1.93466, -2.70156, 9.31583, -2.1478, 10.5523, 4.78855, 1.63608, + 5.53507, 2.60834, -7.00058, -6.46058, 5.4738, 2.43235, -1.34603, 9.02452, + -7.5337, -8.71074, -7.30893, -7.57253, -5.33752, -4.87402, -7.01364, -6.86542, + -7.93331, -7.94791, -5.69392, -6.16116, -7.32291, -7.76491, -6.41965, -7.55783, + -7.87996, -7.55785, -6.69005, -5.87906, 3.92147, 2.86809, -1.5552, 9.66568, + 5.07989, 1.47112, -7.48524, -5.0541, -1.82724, 8.70402, -2.00421, 9.88004, + -2.62153, 8.79332, -7.52111, -6.44819, 4.06424, 2.09518, -6.65494, -5.94752, + 6.93878, 1.61033, -3.95728, 7.60682, 5.67016, 2.21196, -7.81507, -5.79413, + -2.41152, 8.24128, -3.83738, 9.21115, 4.5516, 4.55288, -5.75551, -5.93258, + 4.56545, 2.59384, -7.45614, -9.47115, -2.39568, 9.67642, 5.57816, 1.45712, + -7.48184, -6.41134, -1.99415, 12.867, -8.35854, -6.69675, -7.52559, -7.6793, + 5.7454, 3.1602, 2.94692, 1.87483, -8.77324, -6.66682, -3.21125, 8.68662, + -6.25806, -7.24972, 5.17639, 1.0747, -2.44897, 11.4775, -3.30172, 8.89955, + -2.85191, 8.21201, -8.85893, -6.1322, 4.08957, 1.30155, -5.88132, -7.31173, + -7.10309, -7.22943, -2.46068, 8.18334, -7.01226, -7.85464, 4.75411, 2.12347, + -3.42862, 10.5642, 7.16681, 1.4423, 5.42568, 2.39863, -6.00833, -8.22609, + -1.7619, 9.62466, -2.49527, 8.99016, -2.98837, 8.82863, -2.97262, 8.54856, + -1.34142, 9.26871, -5.99652, -6.95795, -1.87061, 7.35277, -8.68277, -8.46425, + -7.01808, -8.10441, -7.04269, -7.62501, -7.69783, -6.88348, -2.19829, 10.4896, + 4.67396, 1.2032, -5.58263, -6.90298, -5.69224, -4.29055, 4.77285, 1.27305, + -3.33469, 8.6929, -2.54195, 8.47086, 4.46492, 1.21742, 5.41158, -0.875373, + -8.68069, -7.42278, -3.88687, 8.07646, 4.6682, 2.00293, -8.29799, -8.64092, + -1.86382, 10.3829, -6.51234, -5.04193, 4.54458, 2.25219, -1.93264, 9.32554, + -3.06285, 7.81641, -6.90714, -5.10786, 4.69653, 2.50286, 6.43757, 2.61401, + -1.85483, 8.9587, 4.60224, 3.07647, 4.4492, 2.1906, 5.02181, 2.40321, + -2.22923, 7.8888, 5.68943, 1.43793, -6.71097, -6.43817, -5.00633, -5.80006, + -2.43763, 8.53663, 5.72577, 2.44787, -6.57079, -5.17789, -5.77867, -4.92176, + -6.57222, -6.06437, 3.96639, 2.25216, -7.95177, -9.80146, 4.92574, 2.30763, + -7.6221, -8.20013, -6.4132, -6.91575, 4.01432, 2.36897, 3.0833, 1.54505, + -1.99416, 9.52807, -7.85128, -8.25973, -0.86423, 8.76525, -6.31412, -8.64087, + -8.07355, -6.73717, -2.52821, 8.01176, -5.82357, -6.65687, -7.08865, -7.73063, + -5.56251, -6.99818, -2.12513, 8.98159, -6.89834, -7.26863, -7.92654, -6.34346, + 4.86201, 1.49442, 4.92905, 4.42847, -5.57789, -5.3186, 4.34232, 3.34888, + 2.64614, 2.34723, -4.10363, 8.41491, -2.18648, 8.18706, -3.39871, 8.19848, + -2.66098, 9.6026, -6.95927, -6.42774, -5.61392, -7.74628, 5.60376, 4.18369, + 5.28536, 4.13642, 4.8428, 0.457426, -6.33816, -6.12095, -2.4394, 8.62897, + 4.56938, 2.45967, 4.0582, 0.958413, 5.62164, 1.64834, 5.73119, 2.58231, + 4.66806, 1.96405, -6.71905, -6.87706, -2.18503, 8.88414, -6.03901, -6.33338, + -8.38435, -6.12005, 0.0641622, 9.0735, 5.19967, 3.05395, -5.48716, -7.13016, + -6.85541, -5.46789, -1.88353, 8.15713, 4.27891, 3.1325, -2.75816, 9.98586, + -2.03022, 9.34795, -7.66741, -7.50096, -3.39305, 9.16801, -8.49476, -5.71537, + -1.68378, 9.8278, -7.41559, -6.07205, -3.15577, 7.93274, 5.22381, 1.61388, + 3.65739, 1.74854, 4.94251, 1.21889, -7.12832, -5.27276, -9.58286, -6.20223, + -2.21613, 8.29993, 5.34799, 2.92987, 4.09496, 2.37231, -7.25183, -5.79136, + -6.46981, -7.12137, -6.28607, -9.8205, 4.52865, 1.06926, -3.10984, 8.72259, + 3.61865, 2.68153, -5.96604, -7.68329, 3.11435, 1.28126, -1.1064, 7.61243, + -2.17688, 8.2658, -3.27246, 7.2094, -5.55143, -6.32388, -1.69667, 10.3705, + -2.16558, 7.25125, -6.36572, -6.70053, 4.12259, 3.38252, -4.80554, -7.79949, + -5.23966, -6.13798, 4.21969, 1.69139, -1.98985, 10.547, -2.52269, 7.95658, + -6.75642, -6.32862, -3.51521, 7.8001, 4.70435, -0.00229688, 6.25359, 2.4267, + 5.82935, 0.745562, 5.24778, 2.15978, 5.48052, 1.32055, -3.05358, 9.12521, + -3.18922, 9.24654, 4.47276, 2.11988, 5.36751, 2.02512, -2.18511, 8.6292, + -2.48469, 9.51228, 5.57556, 3.24472, -2.58121, 10.0178, -6.12629, -6.49895, + -4.54732, 8.0062, -4.20166, 10.5438, -7.61422, -7.69036, -4.42797, 8.98777, + 4.45301, 1.53344, 4.59296, 2.45021, -6.81264, -6.36417, 4.62346, 3.16156, + -5.93007, -8.36501, -2.78425, 6.71237, -6.17141, -6.64689, -5.20608, 8.95999, + -7.30598, -5.73166, 4.39572, 2.93726, -1.89503, 9.77179, -5.683, -7.48989, + 4.80924, 0.559455, -2.17793, 9.98983, 5.23728, 2.67434, -7.03976, -6.20877, + 3.90435, 3.20926, -7.78536, -7.53388, -1.00684, 9.08838, -5.26741, -5.98327, + 3.28002, 2.71942, -1.47166, 8.50427, -2.32733, 9.26251, 5.16271, 1.39947, + -6.59093, -6.61979, -2.44492, 7.93654, -1.05805, 9.97356, -3.1109, 10.8666, + 3.38834, 3.41693, 4.83098, 2.01961, -2.74013, 9.71049, -3.34892, 8.41489, + 4.94768, 0.263001, 3.57477, 1.66795, 5.78915, 1.26999, -4.81812, -5.67174, + -1.88508, 9.64263, 3.69048, 4.60555, 4.03037, 1.7862, -7.4418, -7.08933}, + {0.127717, 0.211407, 0.195547, 0.21633, 0.39671, 0.229008, 0.20839, 0.169236, 0.314314, + 0.322473, 0.169506, 0.45499, 0.147819, 0.296502, 0.15198, 0.356444, 0.0992833, 0.220833, + 0.296206, 0.178067, 0.135359, 0.189725, 0.243099, 0.519986, 0.168105, 0.273465, 0.126033, + 0.18045, 0.282832, 0.193901, 0.213704, 0.425046, 0.203191, 0.228674, 0.209267, 0.355039, + 0.212918, 0.315495, 0.294112, 0.257576, 0.5786, 0.186019, 0.171919, 0.171919, 0.449151, + 1.34947, 0.171919, 0.16341, 0.641387, 0.342115, 0.267343, 0.246125, 0.277612, 0.181462, + 0.22944, 1.95598, 0.164897, 0.235803, 0.228273, 0.314629, 0.127403, 0.241241, 0.189362, + 0.151691, 0.130085, 0.526707, 0.217069, 0.282306, 0.531523, 0.177035, 0.169776, 0.20395, + 0.177165, 0.146628, 0.280013, 0.223033, 0.50947, 0.184133, 0.295329, 0.183219, 0.28166, + 0.179348, 0.276462, 1.00283, 0.248147, 0.214453, 0.231732, 0.170672, 0.256893, 0.133271, + 0.151137, 0.500823, 0.23678, 0.376983, 0.362061, 0.140013, 0.388863, 0.398552, 0.38015, + 0.190081, 0.167115, 0.206884, 0.473849, 1.05117, 0.435665, 0.323618, 0.326201, 0.32226, + 0.201787, 0.246496, 0.28325, 0.226596, 0.238153, 0.277268, 0.674629, 0.179433, 0.175651, + 0.154778, 0.178195, 0.192796, 0.103571, 0.227621, 0.201124, 0.160525, 0.160964, 0.240099, + 0.258027, 0.134127, 0.127717, 0.341378, 0.311595, 0.282306, 0.168988, 0.40775, 0.246125, + 0.583131, 0.236804, 0.238633, 0.194824, 0.169315, 0.244227, 0.249511, 0.189725, 0.305662, + 0.301415, 0.658641, 0.250944, 0.151792, 0.141383, 0.143843, 0.563347, 0.184216, 0.204155, + 0.221764, 0.314908, 0.144518, 0.228808, 0.255785, 0.163457, 0.424705, 0.170202, 0.312598, + 0.300629, 0.532614, 0.661392, 0.228273, 0.543432, 0.257175, 0.258994, 0.281413, 0.273897, + 0.246837, 0.293489, 0.25533, 0.260492, 0.213704, 0.3091, 0.17103, 0.172285, 0.241399, + 0.35999, 0.372243, 0.269191, 0.390239, 0.31761, 0.200593, 0.22197, 0.752914, 0.266571, + 0.13102, 0.268659, 0.293723, 0.356294, 0.296258, 0.264531, 0.15468, 0.358535, 0.243711, + 0.112147, 0.121659, 0.197101, 0.515292, 0.245628, 0.279863, 0.789807, 0.195156, 0.196073, + 0.149564, 0.118675, 0.389373, 0.233821, 0.176128, 0.481088, 0.360027, 0.553152, 0.208207, + 0.171608, 0.160489, 0.334298, 0.139426, 0.168603, 0.266199, 0.326458, 0.103571, 0.171208, + 0.130961, 0.190887, 0.177229, 0.241651, 0.115152, 0.196753, 0.481088, 0.230965, 0.354631, + 0.14591, 0.328543, 0.141544, 0.195888, 0.290379, 0.245954, 0.184547, 0.575214, 0.186929, + 0.28527, 0.292213, 1.20033, 0.281528, 0.15625, 0.211524, 0.186398, 0.298061, 0.147393, + 0.245349, 0.164527, 0.224771, 0.222382, 0.251643, 0.148835, 0.135359, 0.204967, 0.193024, + 0.486309, 0.389686, 0.211921, 0.307405, 0.38666, 0.26802, 0.16605, 0.323134, 0.268397, + 0.217894, 0.974118, 0.371618, 0.156201, 0.305787, 0.339305, 0.371032, 0.381765, 0.22747, + 0.24906, 0.100884, 0.253192, 0.314253, 0.388289, 0.580947, 1.00267, 0.241998, 0.489101, + 0.341501, 0.247423, 0.328311, 0.440281, 0.14927, 0.244469, 0.846828, 0.191725, 0.217429, + 0.123403, 0.322875, 0.145373, 0.757259, 0.190086, 0.316286, 0.268397, 0.296721, 0.440472, + 0.186848, 0.232134, 0.180239, 0.219724, 0.205886, 0.250975, 0.145636, 0.312476, 0.366418, + 0.128135, 0.315235, 0.264531, 0.161815, 0.31631, 0.296489, 0.37171, 0.197217, 0.195625, + 0.479579, 0.443037, 0.323347, 0.193616, 0.160251, 0.8952, 0.256291, 0.593345, 0.177165, + 0.409514, 0.847863, 0.111448, 0.210031, 0.251347, 0.351953, 0.705204, 0.117901, 0.182343, + 0.230179, 0.83632, 0.22104, 0.145163, 0.200326, 0.23431, 0.21868, 0.253575, 0.186562, + 0.192757, 0.172716, 0.27396, 0.258581, 0.327892, 0.376138, 0.223477, 0.302375, 0.145845, + 0.436902, 0.421794, 0.328543, 0.19246, 0.238889, 0.254866, 0.284674, 0.457849, 0.202937, + 0.392568, 0.453083, 0.782713, 0.465401, 0.178623, 0.304863, 0.190081, 0.228641, 0.255135, + 0.245037, 0.217526, 0.109584, 0.276462, 0.182301, 0.38582, 0.349942, 1.3889, 0.30235, + 0.796353, 0.160168, 0.643204, 0.153752, 0.410268, 0.186439, 0.256834, 0.185783, 0.0957629, + 0.226596, 0.197951, 0.17123, 0.192836, 0.18405, 0.575784, 0.228874, 0.201787, 0.241209, + 0.217386, 0.195751, 0.291585, 0.144531, 0.14176, 0.157635, 0.410268, 0.476338, 0.308148, + 0.148077, 0.152093, 0.196791, 0.568087, 0.414026, 0.250587, 0.473463, 0.293645, 0.396768, + 0.2766, 0.38664, 0.135034, 1.50827, 0.472527, 0.268418, 0.40383, 0.375914, 0.246496, + 0.176474, 0.340405, 0.220833, 0.138782, 0.159009, 0.444219, 0.259582, 0.33638, 0.195586, + 0.210974, 0.200288, 0.148129, 0.0974216, 0.211588, 0.280081, 0.44113, 0.773921, 0.553848, + 0.448079, 0.183136, 0.380854, 0.685021, 0.308767, 0.553276, 0.181578, 0.164759, 0.313889, + 0.137886, 0.545387, 0.278449, 0.736895, 0.360054, 0.358929, 0.457315, 0.343278, 0.507662, + 0.280829, 0.113886, 0.23146, 0.160584, 0.192796, 0.147561, 0.241272, 0.168988, 0.730511, + 0.27836, 0.179847, 0.22555, 0.418069, 0.158348, 0.128965, 0.179454, 0.126366, 0.164434, + 0.273633, 0.309556, 0.500823, 0.367852, 0.192875, 0.230262, 0.32724, 0.249969, 0.142618, + 0.494229, 0.36108, 0.227931, 0.23113, 0.742825, 0.190126, 0.33741, 0.280598, 0.145268, + 0.378423, 0.211921, 0.183594, 0.59201, 0.279563, 0.195683, 0.248101, 0.199754, 0.342494, + 0.174343, 0.14149, 0.28085, 0.175781, 0.518738, 0.17223, 0.489904, 0.181167, 0.354286, + 0.297824, 0.280829, 0.219412, 0.22814, 0.195625, 0.313949, 0.294708, 0.211551, 0.236255, + 0.666933, 0.204808, 0.52591, 0.180725, 0.186889, 0.246589, 0.410575, 0.338348, 0.206219, + 0.361766, 0.158143, 0.280816, 0.4149, 0.773082, 0.340046, 0.369672, 0.256923, 0.167195, + 0.197217, 0.252339, 0.172716, 0.191526, 0.263085, 0.345698, 0.168286, 0.243099, 0.434631, + 0.22944, 0.161862, 0.206589, 0.23457, 0.181924, 0.419063, 0.183427, 0.186152, 0.236352, + 0.306336, 0.149002, 1.50086, 0.188231, 0.442757, 0.485602, 0.466662, 0.17329, 0.141329, + 0.180619, 0.160061, 0.192569, 0.270999, 0.117901, 0.362693, 0.217561, 0.208975, 0.233658, + 0.175173, 1.10307, 0.14625, 1.31124, 0.237608, 0.286784, 0.325112, 0.2485, 0.259641, + 0.553152, 0.179039, 0.780781, 0.174758, 0.297824, 0.2558, 0.235949, 0.952186, 0.356744, + 0.312646, 0.189362, 0.574524, 0.705204, 0.213168, 0.225956, 0.424165, 0.169506, 0.137109, + 0.352451, 0.454554, 0.653302, 0.31261, 0.194412, 0.23719, 0.137886, 0.31498, 0.199085, + 0.203875, 0.597248, 1.10036, 0.196869, 0.22104, 0.451345, 0.105613, 0.683928, 0.135204, + 0.25533, 0.607871, 0.219724, 0.184464, 0.725001, 0.160061, 0.333407, 0.192569, 0.234147, + 0.47178, 0.161815, 0.242455, 0.215305, 0.410575, 0.242376, 0.211335, 0.462804, 0.275065, + 0.126878, 0.170404, 0.179433, 0.147244, 0.109584, 0.352905, 0.158215, 0.197604, 0.172407, + 0.407506, 0.645446, 0.313061, 0.165602, 0.136663, 0.55444, 0.15527, 0.133128, 0.125912, + 0.340405, 0.44521, 0.122783, 0.814526, 0.243773, 0.15743, 0.266743, 0.684458, 0.22221, + 0.181294, 0.193901, 0.258802, 0.167195, 0.292056, 0.132309, 0.227671, 0.117334, 0.271758, + 0.146185, 0.225042, 0.225964, 0.194863, 0.290274, 0.138438, 0.196714, 0.266012, 0.267771, + 0.162544, 0.244258, 0.358038, 0.522617, 0.192875, 0.45066, 0.330396, 0.223477, 0.42967, + 0.350884, 0.404655, 0.123155, 0.431583, 0.191675, 0.147354, 0.609034, 0.459487, 0.187337, + 0.215128, 0.604169, 0.330165, 0.494229, 0.40775, 0.167377, 0.192648, 0.234635, 0.275578, + 0.253094, 0.420063, 0.228299, 0.206478, 0.20395, 0.377656, 0.317393, 0.478623, 0.159009, + 0.217034, 0.300933, 0.139754, 0.153901, 0.261077, 0.22834, 0.449609, 0.157672, 0.176474, + 0.285704, 0.180186, 0.212738, 0.266428, 0.388313, 0.0954637, 0.298093, 0.251643, 0.330696, + 0.159572, 0.210666, 0.149411, 0.139618, 0.338472, 0.450304, 0.208793, 0.583609, 0.185865, + 0.400576, 0.21626, 0.174867, 0.239144, 0.249113, 0.200402, 0.275065, 0.238793, 0.205784, + 0.4475, 0.231262, 0.259082, 0.20934, 0.16806, 0.193616, 0.213811, 0.395632, 0.482465, + 0.274649, 0.307405, 0.165866, 0.334275, 0.683337, 0.368825, 0.14625, 0.780742, 0.163457, + 0.226596, 0.138713, 1.79155, 0.400443, 0.233658, 0.426399, 0.623024, 0.670955, 0.123588, + 0.110899, 0.173751, 0.651068, 0.199983, 0.190887, 0.541435, 0.21324, 0.266571, 0.134638, + 0.179348, 0.145636, 0.170929, 0.623252, 0.587738, 0.109688, 0.515314, 0.217666, 0.213311, + 0.249144, 0.187947, 0.270999, 0.268311, 0.469782, 0.763609, 0.32124, 0.146315, 0.265223, + 0.298694, 0.197623, 0.21349, 0.845778, 0.175466, 0.123588, 0.17223, 0.258603, 1.17119, + 0.538142, 0.407675, 0.120288, 0.587238, 0.244664, 0.333956, 0.132812, 0.21399, 0.302375, + 0.275882, 0.134284, 0.377555, 0.228541, 0.187307, 0.143804, 0.180545, 0.222451, 0.239638, + 0.188028, 0.46334, 0.175868, 0.242392, 0.314762, 0.44473, 0.21962, 0.175966, 1.12364, + 0.138837, 0.400576, 0.18184, 0.137706, 0.409763, 0.216894, 0.466662, 0.376604, 0.487155, + 0.283143, 0.118547, 0.221591, 0.122783, 0.179007, 0.16628, 0.180999, 0.239845, 0.169607, + 0.578402, 0.396537, 0.222288, 0.563237, 0.371238, 0.138658, 0.324336, 0.191526, 0.168603, + 0.357715, 0.640905, 0.460706, 0.220902, 0.240797, 0.164062, 0.157853, 0.34457, 0.196092, + 0.289353, 0.104597, 0.259641, 0.126878, 0.175781, 0.441458, 0.820108, 0.261864, 0.23431, + 0.254506, 0.271955, 0.227529, 0.22834, 0.196753, 0.224906, 0.193783, 0.419481, 0.236933, + 0.229706, 0.29785, 0.222947, 0.177606, 0.216911, 0.305188, 0.933438, 0.116666, 0.278483, + 0.0973824, 0.271224, 0.127717, 1.28139, 0.276283, 0.180704, 0.234554, 0.285984, 0.290172, + 0.49594, 0.135879, 0.436784, 0.206219, 0.342215, 0.374165, 0.182217, 0.274864, 0.625, + 0.356925, 0.194324, 0.342215, 0.113012, 0.155123, 0.254207, 0.438919, 0.262548, 0.302299, + 0.179528, 0.312744, 0.168513, 0.142618, 0.150543, 0.231361, 0.166004, 0.186725, 0.38848, + 0.179857, 0.182301, 0.629476, 0.44113, 0.289669, 0.328543, 0.279938, 0.14625, 0.187174, + 0.157635, 0.396749, 0.798931, 0.201541, 0.778619, 0.265883, 0.258027, 0.218576, 0.266571, + 0.160168, 0.230303, 0.273633, 0.233298, 0.30175, 0.217069, 0.345145, 0.397901, 0.224499, + 0.248101, 0.241335, 0.222947, 0.237094, 0.176518, 0.380032, 0.634775, 0.426193, 0.16362, + 0.231097, 0.219898, 0.343789, 0.275578, 0.282022, 0.628542, 0.232184, 0.848367, 0.200754, + 0.179177}, + {0, 0, 2, 3, 3, 0, 2, 2, 2, 2, 3, 0, 3, 2, 2, 2, 3, 3, 3, 3, 2, 0, 0, 0, 2, 3, 3, 3, 2, 2, 0, 0, + 2, 3, 3, 0, 0, 2, 0, 0, 3, 2, 3, 0, 3, 0, 3, 3, 0, 2, 0, 3, 2, 0, 3, 0, 3, 3, 3, 2, 2, 3, 0, 0, + 3, 3, 0, 2, 2, 3, 0, 3, 2, 2, 2, 0, 2, 3, 3, 3, 2, 3, 3, 3, 2, 0, 2, 0, 3, 3, 3, 3, 2, 2, 0, 2, + 0, 3, 2, 2, 2, 0, 0, 3, 0, 2, 2, 3, 2, 3, 0, 2, 2, 2, 3, 2, 0, 0, 2, 3, 3, 2, 0, 2, 0, 0, 2, 0, + 2, 2, 3, 2, 2, 0, 3, 0, 3, 2, 2, 2, 3, 3, 0, 0, 0, 3, 2, 3, 3, 3, 3, 0, 2, 0, 3, 2, 3, 2, 3, 0, + 2, 3, 3, 2, 3, 3, 2, 2, 0, 0, 2, 3, 3, 2, 3, 0, 2, 0, 2, 0, 3, 2, 3, 2, 3, 0, 3, 0, 3, 0, 2, 3, + 2, 2, 3, 0, 2, 2, 2, 0, 3, 2, 3, 3, 2, 3, 2, 3, 3, 2, 2, 0, 0, 2, 2, 3, 0, 3, 0, 2, 0, 0, 2, 3, + 0, 3, 3, 2, 0, 3, 3, 0, 3, 0, 2, 2, 0, 2, 0, 2, 0, 0, 0, 2, 0, 3, 2, 3, 2, 3, 2, 2, 0, 2, 3, 2, + 3, 2, 2, 2, 2, 3, 0, 2, 0, 0, 2, 3, 3, 0, 2, 3, 2, 2, 3, 0, 3, 0, 0, 2, 0, 2, 0, 2, 2, 3, 3, 2, + 3, 0, 0, 3, 2, 2, 0, 3, 2, 0, 0, 3, 0, 0, 2, 0, 3, 2, 0, 2, 0, 0, 0, 0, 0, 2, 0, 0, 2, 3, 0, 0, + 2, 0, 0, 2, 0, 2, 3, 2, 3, 3, 2, 2, 0, 0, 0, 3, 0, 2, 0, 2, 0, 2, 2, 2, 3, 3, 0, 0, 3, 3, 3, 3, + 3, 2, 3, 3, 2, 3, 3, 0, 2, 2, 2, 2, 0, 2, 0, 0, 0, 2, 2, 3, 3, 2, 3, 2, 3, 0, 2, 3, 0, 2, 0, 2, + 2, 0, 3, 0, 2, 0, 2, 3, 0, 3, 0, 0, 0, 3, 2, 3, 3, 0, 3, 2, 3, 0, 2, 3, 3, 0, 2, 3, 0, 0, 0, 2, + 0, 3, 0, 2, 3, 3, 3, 3, 3, 0, 2, 0, 2, 2, 3, 3, 0, 3, 0, 2, 0, 2, 0, 3, 0, 0, 0, 2, 3, 3, 2, 3, + 0, 0, 0, 0, 3, 3, 0, 3, 2, 0, 2, 3, 2, 2, 3, 3, 2, 2, 2, 0, 2, 3, 0, 3, 3, 0, 0, 2, 0, 3, 2, 3, + 0, 2, 0, 2, 2, 3, 2, 0, 3, 3, 3, 2, 3, 0, 3, 0, 2, 2, 0, 0, 0, 3, 0, 3, 3, 2, 3, 2, 3, 2, 3, 0, + 2, 3, 0, 2, 0, 3, 3, 3, 3, 3, 3, 2, 0, 3, 2, 2, 2, 3, 3, 2, 3, 0, 2, 3, 3, 2, 2, 0, 0, 0, 0, 3, + 0, 3, 3, 3, 0, 0, 0, 3, 3, 3, 3, 3, 0, 2, 3, 3, 3, 3, 3, 3, 0, 0, 2, 2, 3, 3, 2, 2, 0, 0, 3, 0, + 0, 0, 2, 3, 0, 0, 0, 3, 0, 3, 0, 2, 2, 0, 0, 0, 0, 3, 2, 2, 3, 2, 3, 2, 2, 2, 2, 3, 0, 0, 2, 3, + 0, 3, 3, 0, 3, 0, 0, 2, 0, 3, 3, 0, 2, 2, 3, 3, 0, 0, 2, 0, 2, 3, 2, 0, 0, 3, 3, 0, 3, 2, 0, 2, + 0, 2, 3, 2, 0, 3, 3, 2, 0, 0, 2, 2, 0, 0, 2, 0, 3, 3, 2, 3, 2, 0, 3, 0, 2, 2, 3, 3, 0, 3, 2, 2, + 0, 3, 0, 0, 0, 2, 0, 3, 2, 0, 2, 3, 2, 3, 2, 2, 3, 3, 0, 2, 3, 2, 3, 2, 2, 0, 3, 0, 3, 0, 2, 2, + 2, 0, 2, 0, 2, 2, 0, 0, 3, 3, 0, 0, 3, 2, 0, 2, 3, 2, 2, 0, 3, 3, 0, 2, 0, 3, 3, 0, 2, 3, 2, 3, + 2, 0, 2, 2, 0, 0, 0, 2, 2, 3, 3, 2, 2, 0, 2, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 2, 0, 3, 3, + 3, 0, 2, 0, 2, 3, 2, 0, 3, 3, 2, 0, 2, 0, 3, 2, 0, 3, 0, 0, 2, 2, 0, 3, 0, 2, 3, 3, 3, 0, 2, 0, + 0, 3, 0, 2, 3, 2, 2, 0, 3, 3, 3, 3, 3, 0, 3, 0, 0, 0, 0, 3, 2, 0, 0, 2, 3, 3, 2, 2, 0, 3, 2, 0, + 3, 0, 2, 3, 3, 0, 2, 2, 3, 2, 2, 2, 3, 2, 0, 0, 3, 2, 0, 0, 0, 2, 0, 2, 0, 0, 2, 2, 3, 0, 3, 0, + 0, 3, 0, 0, 0, 3, 0, 0, 2, 2, 0, 2, 2, 3, 3, 3, 3, 0, 0, 2, 2, 2, 0, 3, 2, 2, 2, 2, 2, 0, 3, 0, + 0, 3, 2, 0, 0, 3, 2, 3, 3, 0, 3, 0, 3, 0, 3, 2, 2, 2, 0, 0, 3, 2, 2, 0, 0, 0, 2, 3, 2, 0, 2, 3, + 3, 3, 0, 3, 3, 0, 2, 0, 0, 2, 3, 3, 0, 3, 2, 2, 2, 2, 2, 3, 3, 2, 2, 3, 3, 2, 3, 0, 3, 3, 0, 3, + 2, 2, 0, 2, 0, 3, 0, 3, 0, 2, 3, 0, 2, 3, 2, 0, 2, 0, 3, 0, 2, 3, 3, 2, 0, 3, 3, 3, 2, 2, 3, 3, + 2, 2, 2, 0, 3, 2, 2, 0}, + {271, 271, 329, 343, 387, 426, 426, 601}, + {426, 601, 426, 387, 343, 271, 329, 271}, + {3.70991, 4.43491, 3.76334, 9.43944, 9.43944, 3.70991, 3.76334, 4.43491}}}; + +typedef ConnectComponentsEdgesTest ConnectComponentsEdgesTestF_Int; +TEST_P(ConnectComponentsEdgesTestF_Int, Result) { EXPECT_TRUE(true); } + +INSTANTIATE_TEST_CASE_P(ConnectComponentsEdgesTest, + ConnectComponentsEdgesTestF_Int, + ::testing::ValuesIn(mr_fix_conn_inputsf2)); + +}; // namespace sparse +}; // end namespace raft diff --git a/cpp/test/util/device_atomics.cu b/cpp/test/util/device_atomics.cu index 5e8a67c8f6..355cb0d4dd 100644 --- a/cpp/test/util/device_atomics.cu +++ b/cpp/test/util/device_atomics.cu @@ -51,12 +51,12 @@ TEST(Raft, AtomicIncWarp) // Write all 1M thread indices to a unique location in `out_device` test_atomic_inc_warp_kernel<<>>(counter.data(), out_device.data()); - // Copy data to host - RAFT_CUDA_TRY(cudaMemcpy(out_host.data(), - (const void*)out_device.data(), - num_elts * sizeof(int), - cudaMemcpyDeviceToHost)); + RAFT_CUDA_TRY(cudaMemcpyAsync(out_host.data(), + (const void*)out_device.data(), + num_elts * sizeof(int), + cudaMemcpyDeviceToHost, + s)); // Check that count is correct and that each thread index is contained in the // array exactly once. diff --git a/dependencies.yaml b/dependencies.yaml index ffb3108b0b..60869603f6 100644 --- a/dependencies.yaml +++ b/dependencies.yaml @@ -109,7 +109,7 @@ dependencies: common: - output_types: [conda, requirements, pyproject] packages: - - cmake>=3.23.1,!=3.25.0 + - &cmake_ver cmake>=3.23.1,!=3.25.0 - cython>=0.29,<0.30 - ninja - scikit-build>=0.13.1 @@ -246,6 +246,7 @@ dependencies: common: - output_types: [conda] packages: + - *cmake_ver - gtest>=1.13.0 - gmock>=1.13.0 docs: diff --git a/docs/source/cpp_api/neighbors_cagra.rst b/docs/source/cpp_api/neighbors_cagra.rst index 68372bbb71..6613b0b06d 100644 --- a/docs/source/cpp_api/neighbors_cagra.rst +++ b/docs/source/cpp_api/neighbors_cagra.rst @@ -11,7 +11,7 @@ Please note that the CAGRA implementation is currently experimental and the API ``#include `` -namespace *raft::neighbors::experimental::cagra* +namespace *raft::neighbors::cagra* .. doxygengroup:: cagra :project: RAFT diff --git a/python/pylibraft/CMakeLists.txt b/python/pylibraft/CMakeLists.txt index 6a69aa44b4..44d02b2551 100644 --- a/python/pylibraft/CMakeLists.txt +++ b/python/pylibraft/CMakeLists.txt @@ -14,9 +14,14 @@ cmake_minimum_required(VERSION 3.23.1 FATAL_ERROR) +include(../../fetch_rapids.cmake) + set(pylibraft_version 23.10.00) -include(../../fetch_rapids.cmake) +# We always need CUDA for pylibraft because the raft dependency brings in a header-only cuco +# dependency that enables CUDA unconditionally. +include(rapids-cuda) +rapids_cuda_init_architectures(pylibraft) project( pylibraft @@ -25,7 +30,7 @@ project( # language to be enabled here. The test project that is built in scikit-build to verify # various linking options for the python library is hardcoded to build with C, so until # that is fixed we need to keep C. - C CXX + C CXX CUDA ) option(FIND_RAFT_CPP "Search for existing RAFT C++ installations before defaulting to local files" @@ -51,15 +56,6 @@ endif() include(rapids-cython) if(NOT raft_FOUND) - # TODO: This will not be necessary once we upgrade to CMake 3.22, which will pull in the required - # languages for the C++ project even if this project does not require those languages. - include(rapids-cuda) - rapids_cuda_init_architectures(pylibraft) - enable_language(CUDA) - # Since pylibraft only enables CUDA optionally we need to manually include the file that - # rapids_cuda_init_architectures relies on `project` including. - include("${CMAKE_PROJECT_pylibraft_INCLUDE}") - set(BUILD_TESTS OFF) set(BUILD_PRIMS_BENCH OFF) set(BUILD_ANN_BENCH OFF)