From 5f2bd19276d4314556795c86fa2230ccac988583 Mon Sep 17 00:00:00 2001 From: Bradley Dice Date: Thu, 28 Mar 2024 12:33:39 -0500 Subject: [PATCH 1/4] Use `conda env create --yes` instead of `--force`. (#2247) --- ci/build_docs.sh | 2 +- ci/check_style.sh | 4 ++-- ci/test_cpp.sh | 2 +- ci/test_python.sh | 2 +- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/ci/build_docs.sh b/ci/build_docs.sh index 3d72c815db..9605b52f8b 100755 --- a/ci/build_docs.sh +++ b/ci/build_docs.sh @@ -11,7 +11,7 @@ rapids-dependency-file-generator \ --file_key docs \ --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" | tee env.yaml -rapids-mamba-retry env create --force -f env.yaml -n docs +rapids-mamba-retry env create --yes -f env.yaml -n docs conda activate docs rapids-print-env diff --git a/ci/check_style.sh b/ci/check_style.sh index 0ee6e88e58..d7baa88e8f 100755 --- a/ci/check_style.sh +++ b/ci/check_style.sh @@ -1,5 +1,5 @@ #!/bin/bash -# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# Copyright (c) 2020-2024, NVIDIA CORPORATION. set -euo pipefail @@ -11,7 +11,7 @@ rapids-dependency-file-generator \ --file_key checks \ --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" | tee env.yaml -rapids-mamba-retry env create --force -f env.yaml -n checks +rapids-mamba-retry env create --yes -f env.yaml -n checks conda activate checks # Run pre-commit checks diff --git a/ci/test_cpp.sh b/ci/test_cpp.sh index fb2d025f9c..f83ddf616d 100755 --- a/ci/test_cpp.sh +++ b/ci/test_cpp.sh @@ -14,7 +14,7 @@ rapids-dependency-file-generator \ --file_key test_cpp \ --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch)" | tee env.yaml -rapids-mamba-retry env create --force -f env.yaml -n test +rapids-mamba-retry env create --yes -f env.yaml -n test # Temporarily allow unbound variables for conda activation. set +u diff --git a/ci/test_python.sh b/ci/test_python.sh index aae8ae03ea..f5b188ca0b 100755 --- a/ci/test_python.sh +++ b/ci/test_python.sh @@ -14,7 +14,7 @@ rapids-dependency-file-generator \ --file_key test_python \ --matrix "cuda=${RAPIDS_CUDA_VERSION%.*};arch=$(arch);py=${RAPIDS_PY_VERSION}" | tee env.yaml -rapids-mamba-retry env create --force -f env.yaml -n test +rapids-mamba-retry env create --yes -f env.yaml -n test # Temporarily allow unbound variables for conda activation. set +u From eabe3b00dad4225b00cb93d16cf1918b213b0ae3 Mon Sep 17 00:00:00 2001 From: tsuki <12711693+enp1s0@users.noreply.github.com> Date: Thu, 4 Apr 2024 02:11:48 +0900 Subject: [PATCH 2/4] Add CAGRA-Q subspace dim = 4 support (#2244) This PR adds the support for subspace dim (pq_dim) = 4 in CAGRA-Q Authors: - tsuki (https://github.com/enp1s0) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2244 --- .../neighbors/detail/cagra/cagra_search.cuh | 3 +- .../detail/cagra/compute_distance_vpq.cuh | 29 ++++++++++--------- .../raft/neighbors/detail/vpq_dataset.cuh | 2 +- cpp/test/neighbors/ann_cagra_vpq.cuh | 4 +-- 4 files changed, 21 insertions(+), 17 deletions(-) mode change 100755 => 100644 cpp/test/neighbors/ann_cagra_vpq.cuh diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index d30f69ddcd..ccfe3c7e2d 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -166,7 +166,8 @@ void launch_vpq_search_main_core( CagraSampleFilterT sample_filter) { RAFT_EXPECTS(vpq_dset->pq_bits() == 8, "Only pq_bits = 8 is supported for now"); - RAFT_EXPECTS(vpq_dset->pq_len() == 2, "Only pq_len 2 is supported for now"); + RAFT_EXPECTS(vpq_dset->pq_len() == 2 || vpq_dset->pq_len() == 4, + "Only pq_len 2 or 4 is supported for now"); RAFT_EXPECTS(vpq_dset->dim() % vpq_dset->pq_dim() == 0, "dim must be a multiple of pq_dim at the moment"); diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh b/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh index 0204addba7..e73d24bfb6 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance_vpq.cuh @@ -33,6 +33,8 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t::QUERY_T; + static_assert(std::is_same_v, "Only CODE_BOOK_T = `half` is supported now"); + const std::uint8_t* encoded_dataset_ptr; const std::uint32_t encoded_dataset_dim; const std::uint32_t n_subspace; @@ -53,18 +55,19 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t(smem_ptr); // Copy PQ table - if constexpr (std::is_same::value) { - for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) { - half2 buf2; - buf2.x = pq_code_book_ptr[i]; - buf2.y = pq_code_book_ptr[i + 1]; - (reinterpret_cast(smem_pq_code_book_ptr + i))[0] = buf2; - } - } else { - for (unsigned i = threadIdx.x; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x) { - // TODO: vectorize - smem_pq_code_book_ptr[i] = pq_code_book_ptr[i]; - } + for (unsigned i = threadIdx.x * 2; i < (1 << PQ_BITS) * PQ_LEN; i += blockDim.x * 2) { + half2 buf2; + buf2.x = pq_code_book_ptr[i]; + buf2.y = pq_code_book_ptr[i + 1]; + + // Change the order of PQ code book array to reduce the + // frequency of bank conflicts. + constexpr auto num_elements_per_bank = 4 / utils::size_of(); + constexpr auto num_banks_per_subspace = PQ_LEN / num_elements_per_bank; + const auto j = i / num_elements_per_bank; + const auto smem_index = + (j / num_banks_per_subspace) + (j % num_banks_per_subspace) * (1 << PQ_BITS); + reinterpret_cast(smem_pq_code_book_ptr)[smem_index] = buf2; } } @@ -136,7 +139,7 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t::value) && (PQ_LEN % 2 == 0)) { + if constexpr (PQ_LEN % 2 == 0) { // **** Use half2 for distance computation **** half2 norm2{0, 0}; #pragma unroll diff --git a/cpp/include/raft/neighbors/detail/vpq_dataset.cuh b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh index f1321ba343..f6cd2a1ceb 100644 --- a/cpp/include/raft/neighbors/detail/vpq_dataset.cuh +++ b/cpp/include/raft/neighbors/detail/vpq_dataset.cuh @@ -81,7 +81,7 @@ auto fill_missing_params_heuristics(const vpq_params& params, const DatasetT& da vpq_params r = params; double n_rows = dataset.extent(0); size_t dim = dataset.extent(1); - if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{2}); } + if (r.pq_dim == 0) { r.pq_dim = raft::div_rounding_up_safe(dim, size_t{4}); } if (r.pq_bits == 0) { r.pq_bits = 8; } if (r.vq_n_centers == 0) { r.vq_n_centers = raft::round_up_safe(std::sqrt(n_rows), 8); } if (r.vq_kmeans_trainset_fraction == 0) { diff --git a/cpp/test/neighbors/ann_cagra_vpq.cuh b/cpp/test/neighbors/ann_cagra_vpq.cuh old mode 100755 new mode 100644 index 503b1a413a..6b24bca921 --- a/cpp/test/neighbors/ann_cagra_vpq.cuh +++ b/cpp/test/neighbors/ann_cagra_vpq.cuh @@ -158,7 +158,7 @@ class AnnCagraVpqTest : public ::testing::TestWithParam { resource::sync_stream(handle_); } - const auto vpq_k = ps.k * 16; + const auto vpq_k = ps.k * 4; { rmm::device_uvector distances_dev(vpq_k * ps.n_queries, stream_); rmm::device_uvector indices_dev(vpq_k * ps.n_queries, stream_); @@ -319,7 +319,7 @@ const std::vector vpq_inputs = raft::util::itertools::product {1000, 10000}, // n_rows {128, 132, 192, 256, 512, 768}, // dim {8, 12}, // k - {2}, // pq_len + {2, 4}, // pq_len {8}, // pq_bits {graph_build_algo::NN_DESCENT}, // build_algo {search_algo::SINGLE_CTA, search_algo::MULTI_CTA}, // algo From 8a68518fd5a0ae0e750cd8f77b02f73efc111f5c Mon Sep 17 00:00:00 2001 From: Micka Date: Thu, 4 Apr 2024 19:26:04 +0200 Subject: [PATCH 3/4] Fix time computation in CAGRA notebook (#2231) Closes #2230. I am also adding `nn_descent` to the build parameters of cagra Authors: - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2231 --- .../VectorSearch_QuestionRetrieval.ipynb | 52 ++++++++++--------- 1 file changed, 28 insertions(+), 24 deletions(-) diff --git a/notebooks/VectorSearch_QuestionRetrieval.ipynb b/notebooks/VectorSearch_QuestionRetrieval.ipynb index b3a15d3a08..33a2f60228 100644 --- a/notebooks/VectorSearch_QuestionRetrieval.ipynb +++ b/notebooks/VectorSearch_QuestionRetrieval.ipynb @@ -89,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 2, "id": "eb1e81c3", "metadata": {}, "outputs": [ @@ -154,7 +154,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "id": "ee4c5cc0", "metadata": {}, "outputs": [ @@ -184,7 +184,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "id": "0a1a6307", "metadata": {}, "outputs": [ @@ -249,7 +249,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 5, "id": "ad90b4be", "metadata": {}, "outputs": [ @@ -292,7 +292,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 6, "id": "724dcacb", "metadata": { "scrolled": true @@ -320,7 +320,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 7, "id": "c27d4715", "metadata": {}, "outputs": [ @@ -347,7 +347,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 8, "id": "bc375518", "metadata": {}, "outputs": [ @@ -373,7 +373,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 9, "id": "ab154181", "metadata": {}, "outputs": [ @@ -399,7 +399,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 10, "id": "2d6017ed", "metadata": {}, "outputs": [ @@ -435,7 +435,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 11, "id": "f5cfb644", "metadata": {}, "outputs": [ @@ -462,7 +462,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 12, "id": "b5694d00", "metadata": {}, "outputs": [ @@ -489,7 +489,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 13, "id": "fcfc3c5b", "metadata": {}, "outputs": [ @@ -528,7 +528,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 14, "id": "50df1f43-c580-4019-949a-06bdc7185536", "metadata": {}, "outputs": [], @@ -538,7 +538,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 15, "id": "091cde52-4652-4230-af2b-75c35357f833", "metadata": {}, "outputs": [ @@ -546,21 +546,21 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 1min 23s, sys: 2min 7s, total: 3min 31s\n", - "Wall time: 4min 43s\n" + "CPU times: user 35.3 s, sys: 4.5 s, total: 39.8 s\n", + "Wall time: 2.16 s\n" ] } ], "source": [ "%%time\n", - "params = cagra.IndexParams(intermediate_graph_degree=128, graph_degree=64)\n", + "params = cagra.IndexParams(intermediate_graph_degree=32, graph_degree=16, build_algo=\"nn_descent\")\n", "cagra_index = cagra.build(params, corpus_embeddings)\n", - "search_params = cagra.SearchParams()" + "search_params = cagra.SearchParams(algo=\"multi_cta\")" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 16, "id": "df229e21-f6b6-4d6c-ad54-2724f8738934", "metadata": {}, "outputs": [], @@ -569,9 +569,12 @@ " # Encode the query using the bi-encoder and find potentially relevant passages\n", " question_embedding = bi_encoder.encode(query, convert_to_tensor=True)\n", "\n", + " start_time = time.time()\n", " hits = cagra.search(search_params, cagra_index, question_embedding[None], top_k)\n", + " end_time = time.time()\n", "\n", " # Output of top-k hits\n", + " print(\"Results (after {:.3f} seconds):\".format(end_time - start_time))\n", " print(\"Input question:\", query)\n", " for k in range(top_k):\n", " print(\"\\t{:.3f}\\t{}\".format(hits[0][0, k], passages[hits[1][0, k]]))" @@ -587,19 +590,20 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 16 µs, sys: 25 µs, total: 41 µs\n", - "Wall time: 83.7 µs\n", + "Results (after 0.005 seconds):\n", "Input question: Who was Grace Hopper?\n", "\t181.649\t['Grace Hopper', 'Hopper was born in New York, USA. Hopper graduated from Vassar College in 1928 and Yale University in 1934 with a Ph.D degree in mathematics. She joined the US Navy during the World War II in 1943. She worked on computers in the Navy for 43 years. She then worked in other private industry companies after 1949. She retired from the Navy in 1986 and died on January 1, 1992.']\n", "\t192.946\t['Leona Helmsley', 'Leona Helmsley (July 4, 1920 – August 20, 2007) was an American businesswoman. She was known for having a flamboyant personality. She had a reputation for tyrannical behavior; she was nicknamed the Queen of Mean.']\n", "\t194.951\t['Grace Hopper', 'Grace Murray Hopper (December 9 1906 – January 1 1992) was an American computer scientist and United States Navy officer.']\n", "\t202.192\t['Nellie Bly', 'Elizabeth Cochrane Seaman (born Elizabeth Jane Cochran; May 5, 1864 – January 27, 1922), better known by her pen name Nellie Bly, was an American journalist, novelist and inventor. She was a newspaper reporter, who worked at various jobs for exposing poor working conditions. Nellie Bly, also, fought for women\\'s right and was known for investigative reporting. She best known for her record-breaking trip around the world in 72 days, inspired by the adventure novel \"Around the World in Eighty Days\" by Jules Verne. In the 1880s, she went undercover as a mentally ill patient in a psychiatric hospital for ten days, with the report being made public in a book called \"\"Ten Days in a Mad-House\"\". She was added to the National Women\\'s Hall of Fame in 1998.']\n", - "\t205.038\t['Abbie Hoffman', 'Abbot Howard \"Abbie\" Hoffman (November 30, 1936 – April 12, 1989) was an American social and political activist.']\n" + "\t205.038\t['Abbie Hoffman', 'Abbot Howard \"Abbie\" Hoffman (November 30, 1936 – April 12, 1989) was an American social and political activist.']\n", + "CPU times: user 4.18 ms, sys: 3.88 ms, total: 8.07 ms\n", + "Wall time: 9.97 ms\n" ] } ], "source": [ - "%time \n", + "%%time \n", "search_raft_cagra(query=\"Who was Grace Hopper?\")" ] } @@ -620,7 +624,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.12" + "version": "3.10.13" } }, "nbformat": 4, From 4a20d03af7f6181e3083bc3b65522d7f2c3b6218 Mon Sep 17 00:00:00 2001 From: rhdong Date: Mon, 8 Apr 2024 09:36:00 -0700 Subject: [PATCH 4/4] [FEA] Add support for `select_k` on CSR matrix (#2140) - This PR is one part of the feature of #1969 - Add the API of 'select_k' accepting CSR as input Authors: - James Rong (https://github.com/rhdong) Approvers: - Ben Frederickson (https://github.com/benfred) - Micka (https://github.com/lowener) - Corey J. Nolet (https://github.com/cjnolet) Authors: - rhdong (https://github.com/rhdong) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2140 --- cpp/bench/prims/CMakeLists.txt | 1 + cpp/bench/prims/sparse/select_k_csr.cu | 287 ++++++++++++ .../raft/matrix/detail/select_radix.cuh | 427 ++++++++++-------- .../raft/matrix/detail/select_warpsort.cuh | 55 ++- .../sparse/matrix/detail/select_k-ext.cuh | 67 +++ .../sparse/matrix/detail/select_k-inl.cuh | 225 +++++++++ .../raft/sparse/matrix/detail/select_k.cuh | 24 + cpp/include/raft/sparse/matrix/select_k.cuh | 87 ++++ .../matrix/detail/select_k_double_int64_t.cu | 32 ++ .../matrix/detail/select_k_double_uint32_t.cu | 34 ++ .../matrix/detail/select_k_float_int32.cu | 32 ++ .../matrix/detail/select_k_float_int64_t.cu | 32 ++ .../matrix/detail/select_k_float_uint32_t.cu | 32 ++ .../matrix/detail/select_k_half_int64_t.cu | 32 ++ .../matrix/detail/select_k_half_uint32_t.cu | 32 ++ cpp/test/CMakeLists.txt | 1 + cpp/test/sparse/select_k_csr.cu | 398 ++++++++++++++++ 17 files changed, 1600 insertions(+), 198 deletions(-) create mode 100644 cpp/bench/prims/sparse/select_k_csr.cu create mode 100644 cpp/include/raft/sparse/matrix/detail/select_k-ext.cuh create mode 100644 cpp/include/raft/sparse/matrix/detail/select_k-inl.cuh create mode 100644 cpp/include/raft/sparse/matrix/detail/select_k.cuh create mode 100644 cpp/include/raft/sparse/matrix/select_k.cuh create mode 100644 cpp/src/sparse/matrix/detail/select_k_double_int64_t.cu create mode 100644 cpp/src/sparse/matrix/detail/select_k_double_uint32_t.cu create mode 100644 cpp/src/sparse/matrix/detail/select_k_float_int32.cu create mode 100644 cpp/src/sparse/matrix/detail/select_k_float_int64_t.cu create mode 100644 cpp/src/sparse/matrix/detail/select_k_float_uint32_t.cu create mode 100644 cpp/src/sparse/matrix/detail/select_k_half_int64_t.cu create mode 100644 cpp/src/sparse/matrix/detail/select_k_half_uint32_t.cu create mode 100644 cpp/test/sparse/select_k_csr.cu diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index 9f23c44a5c..0c5521d447 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -137,6 +137,7 @@ if(BUILD_PRIMS_BENCH) PATH bench/prims/sparse/bitmap_to_csr.cu bench/prims/sparse/convert_csr.cu + bench/prims/sparse/select_k_csr.cu bench/prims/main.cpp ) diff --git a/cpp/bench/prims/sparse/select_k_csr.cu b/cpp/bench/prims/sparse/select_k_csr.cu new file mode 100644 index 0000000000..a91e6c8514 --- /dev/null +++ b/cpp/bench/prims/sparse/select_k_csr.cu @@ -0,0 +1,287 @@ +/* + * Copyright (c) 2024, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include + +namespace raft::bench::sparse { + +template +struct bench_param { + index_t n_rows; + index_t n_cols; + index_t top_k; + float sparsity; + bool select_min = true; + bool customized_indices = false; +}; + +template +inline auto operator<<(std::ostream& os, const bench_param& params) -> std::ostream& +{ + os << params.n_rows << "#" << params.n_cols << "#" << params.top_k << "#" << params.sparsity; + return os; +} + +template +struct SelectKCsrTest : public fixture { + SelectKCsrTest(const bench_param& p) + : fixture(true), + params(p), + handle(stream), + values_d(0, stream), + indptr_d(0, stream), + indices_d(0, stream), + customized_indices_d(0, stream), + dst_values_d(0, stream), + dst_indices_d(0, stream) + { + std::vector dense_values_h(params.n_rows * params.n_cols); + nnz = create_sparse_matrix(params.n_rows, params.n_cols, params.sparsity, dense_values_h); + + std::vector indices_h(nnz); + std::vector customized_indices_h(nnz); + std::vector indptr_h(params.n_rows + 1); + + convert_to_csr(dense_values_h, params.n_rows, params.n_cols, indices_h, indptr_h); + + std::vector dst_values_h(params.n_rows * params.top_k, static_cast(2.0f)); + std::vector dst_indices_h(params.n_rows * params.top_k, + static_cast(params.n_rows * params.n_cols * 100)); + + dst_values_d.resize(params.n_rows * params.top_k, stream); + dst_indices_d.resize(params.n_rows * params.top_k, stream); + values_d.resize(nnz, stream); + + if (nnz) { + auto blobs_values = raft::make_device_matrix(handle, 1, nnz); + auto labels = raft::make_device_vector(handle, 1); + + raft::random::make_blobs(blobs_values.data_handle(), + labels.data_handle(), + 1, + nnz, + 1, + stream, + false, + nullptr, + nullptr, + value_t(1.0), + false, + value_t(-10.0f), + value_t(10.0f), + uint64_t(2024)); + raft::copy(values_d.data(), blobs_values.data_handle(), nnz, stream); + resource::sync_stream(handle); + } + + indices_d.resize(nnz, stream); + indptr_d.resize(params.n_rows + 1, stream); + + update_device(indices_d.data(), indices_h.data(), indices_h.size(), stream); + update_device(indptr_d.data(), indptr_h.data(), indptr_h.size(), stream); + + if (params.customized_indices) { + customized_indices_d.resize(nnz, stream); + update_device(customized_indices_d.data(), + customized_indices_h.data(), + customized_indices_h.size(), + stream); + } + } + + index_t create_sparse_matrix(index_t m, index_t n, value_t sparsity, std::vector& matrix) + { + index_t total_elements = static_cast(m * n); + index_t num_ones = static_cast((total_elements * 1.0f) * sparsity); + index_t res = num_ones; + + for (index_t i = 0; i < total_elements; ++i) { + matrix[i] = false; + } + + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_int_distribution<> dis_idx(0, total_elements - 1); + + while (num_ones > 0) { + size_t index = dis_idx(gen); + if (matrix[index] == false) { + matrix[index] = true; + num_ones--; + } + } + return res; + } + + void convert_to_csr(std::vector& matrix, + index_t rows, + index_t cols, + std::vector& indices, + std::vector& indptr) + { + index_t offset_indptr = 0; + index_t offset_values = 0; + indptr[offset_indptr++] = 0; + + for (index_t i = 0; i < rows; ++i) { + for (index_t j = 0; j < cols; ++j) { + if (matrix[i * cols + j]) { + indices[offset_values] = static_cast(j); + offset_values++; + } + } + indptr[offset_indptr++] = static_cast(offset_values); + } + } + + template + std::optional get_opt_var(data_t x) + { + if (params.customized_indices) { + return x; + } else { + return std::nullopt; + } + } + + void run_benchmark(::benchmark::State& state) override + { + std::ostringstream label_stream; + label_stream << params; + state.SetLabel(label_stream.str()); + + auto in_val_structure = raft::make_device_compressed_structure_view( + indptr_d.data(), + indices_d.data(), + params.n_rows, + params.n_cols, + static_cast(indices_d.size())); + + auto in_val = + raft::make_device_csr_matrix_view(values_d.data(), in_val_structure); + + std::optional> in_idx; + + in_idx = get_opt_var( + raft::make_device_vector_view(customized_indices_d.data(), nnz)); + + auto out_val = raft::make_device_matrix_view( + dst_values_d.data(), params.n_rows, params.top_k); + auto out_idx = raft::make_device_matrix_view( + dst_indices_d.data(), params.n_rows, params.top_k); + + raft::sparse::matrix::select_k(handle, in_val, in_idx, out_val, out_idx, params.select_min); + resource::sync_stream(handle); + loop_on_state(state, [this, &in_val, &in_idx, &out_val, &out_idx]() { + raft::sparse::matrix::select_k( + handle, in_val, in_idx, out_val, out_idx, params.select_min, false); + resource::sync_stream(handle); + }); + } + + protected: + const raft::device_resources handle; + + bench_param params; + index_t nnz; + + rmm::device_uvector values_d; + rmm::device_uvector indptr_d; + rmm::device_uvector indices_d; + rmm::device_uvector customized_indices_d; + + rmm::device_uvector dst_values_d; + rmm::device_uvector dst_indices_d; +}; // struct SelectKCsrTest + +template +const std::vector> getInputs() +{ + std::vector> param_vec; + struct TestParams { + index_t m; + index_t n; + index_t k; + }; + + const std::vector params_group{ + {20000, 500, 1}, {20000, 500, 2}, {20000, 500, 4}, {20000, 500, 8}, + {20000, 500, 16}, {20000, 500, 32}, {20000, 500, 64}, {20000, 500, 128}, + {20000, 500, 256}, + + {1000, 10000, 1}, {1000, 10000, 2}, {1000, 10000, 4}, {1000, 10000, 8}, + {1000, 10000, 16}, {1000, 10000, 32}, {1000, 10000, 64}, {1000, 10000, 128}, + {1000, 10000, 256}, + + {100, 100000, 1}, {100, 100000, 2}, {100, 100000, 4}, {100, 100000, 8}, + {100, 100000, 16}, {100, 100000, 32}, {100, 100000, 64}, {100, 100000, 128}, + {100, 100000, 256}, + + {10, 1000000, 1}, {10, 1000000, 2}, {10, 1000000, 4}, {10, 1000000, 8}, + {10, 1000000, 16}, {10, 1000000, 32}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, + + {10, 1000000, 1}, {10, 1000000, 2}, {10, 1000000, 4}, {10, 1000000, 8}, + {10, 1000000, 16}, {10, 1000000, 32}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, + + {10, 1000000, 1}, {10, 1000000, 16}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, + + {10, 1000000, 1}, {10, 1000000, 16}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, {1000, 10000, 1}, {1000, 10000, 16}, {1000, 10000, 64}, + {1000, 10000, 128}, {1000, 10000, 256}, + + {10, 1000000, 1}, {10, 1000000, 16}, {10, 1000000, 64}, {10, 1000000, 128}, + {10, 1000000, 256}, {1000, 10000, 1}, {1000, 10000, 16}, {1000, 10000, 64}, + {1000, 10000, 128}, {1000, 10000, 256}}; + + param_vec.reserve(params_group.size()); + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.k, 0.1})); + } + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.k, 0.2})); + } + for (TestParams params : params_group) { + param_vec.push_back(bench_param({params.m, params.n, params.k, 0.5})); + } + return param_vec; +} + +RAFT_BENCH_REGISTER((SelectKCsrTest), "", getInputs()); + +} // namespace raft::bench::sparse diff --git a/cpp/include/raft/matrix/detail/select_radix.cuh b/cpp/include/raft/matrix/detail/select_radix.cuh index 36a346fda3..83d4845c31 100644 --- a/cpp/include/raft/matrix/detail/select_radix.cuh +++ b/cpp/include/raft/matrix/detail/select_radix.cuh @@ -442,14 +442,76 @@ _RAFT_DEVICE void last_filter(const T* in_buf, } } -template +template +_RAFT_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + char* bufs, + IdxT buf_len, + int pass, + const T*& in_buf, + const IdxT*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) +{ + // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 + if (pass == 0) { + in_buf = in; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + in_buf = in; + in_idx_buf = in_idx; + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } else if (pass % 2 == 0) { + in_buf = reinterpret_cast(bufs); + in_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + out_buf = const_cast(in_buf + buf_len); + out_idx_buf = const_cast(in_idx_buf + buf_len); + } else { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + in_buf = out_buf + buf_len; + in_idx_buf = out_idx_buf + buf_len; + } +} + +template +_RAFT_DEVICE void set_buf_pointers(const T* in, + const IdxT* in_idx, + char* bufs, + IdxT buf_len, + const int pass, + const T*& out_buf, + const IdxT*& out_idx_buf) +{ + // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 + if (pass == 0) { + out_buf = nullptr; + out_idx_buf = nullptr; + } else if (pass == 1) { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } else if (pass % 2 == 0) { + out_buf = const_cast(reinterpret_cast(bufs) + buf_len); + out_idx_buf = + const_cast(reinterpret_cast(bufs + sizeof(T) * 2 * buf_len) + buf_len); + } else { + out_buf = reinterpret_cast(bufs); + out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); + } +} + +template RAFT_KERNEL last_filter_kernel(const T* in, const IdxT* in_idx, - const T* in_buf, - const IdxT* in_idx_buf, + char* bufs, + size_t offset, T* out, IdxT* out_idx, const IdxT len, + const IdxT* len_i, const IdxT k, Counter* counters, const bool select_min) @@ -458,22 +520,31 @@ RAFT_KERNEL last_filter_kernel(const T* in, Counter* counter = counters + batch_id; IdxT previous_len = counter->previous_len; + if (previous_len == 0) { return; } + + const IdxT l_len = len_or_indptr ? len : (len_i[batch_id + 1] - len_i[batch_id]); + const IdxT l_offset = len_or_indptr ? (offset + batch_id) * len : len_i[batch_id]; + const IdxT buf_len = calc_buf_len(len); - if (previous_len > buf_len || in_buf == in) { - in_buf = in + batch_id * len; - in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; - previous_len = len; - } else { - in_buf += batch_id * buf_len; - in_idx_buf += batch_id * buf_len; - } - out += batch_id * k; - out_idx += batch_id * k; + + const T* in_buf = nullptr; + const IdxT* in_idx_buf = nullptr; + bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); constexpr int pass = calc_num_passes() - 1; constexpr int start_bit = calc_start_bit(pass); + set_buf_pointers(in + l_offset, in_idx + l_offset, bufs, buf_len, pass, in_buf, in_idx_buf); + + if (previous_len > buf_len || in_buf == in + l_offset) { + in_buf = in + l_offset; + in_idx_buf = in_idx ? (in_idx + l_offset) : nullptr; + previous_len = l_len; + } + out += batch_id * k; + out_idx += batch_id * k; + const auto kth_value_bits = counter->kth_value_bits; const IdxT num_of_kth_needed = counter->k; IdxT* p_out_cnt = &counter->out_cnt; @@ -510,6 +581,29 @@ RAFT_KERNEL last_filter_kernel(const T* in, f); } +template +_RAFT_DEVICE _RAFT_FORCEINLINE void copy_in_val( + T* dest, const T* src, S len, IdxT k, const bool select_min) +{ + S idx = S(threadIdx.x); + S stride = S(blockDim.x); + const T default_val = select_min ? upper_bound() : lower_bound(); + for (S i = idx; i < k; i += stride) { + dest[i] = i < len ? src[i] : default_val; + } +} + +template +_RAFT_DEVICE _RAFT_FORCEINLINE void copy_in_idx(T* dest, const T* src, S len) +{ + S idx = S(threadIdx.x); + S stride = S(blockDim.x); + + for (S i = idx; i < len; i += stride) { + dest[i] = src ? src[i] : i; + } +} + /** * * It is expected to call this kernel multiple times (passes), in each pass we process a radix, @@ -545,13 +639,16 @@ RAFT_KERNEL last_filter_kernel(const T* in, * rather than from `in_buf`. The benefit is that we can save the cost of writing candidates and * their indices. */ -template +template RAFT_KERNEL radix_kernel(const T* in, const IdxT* in_idx, - const T* in_buf, - const IdxT* in_idx_buf, - T* out_buf, - IdxT* out_idx_buf, + char* bufs, + size_t offset, T* out, IdxT* out_idx, Counter* counters, @@ -567,21 +664,38 @@ RAFT_KERNEL radix_kernel(const T* in, IdxT current_k; IdxT previous_len; IdxT current_len; + + const IdxT l_len = len_or_indptr ? len : (len_i[batch_id + 1] - len_i[batch_id]); + const IdxT l_offset = len_or_indptr ? (offset + batch_id) * len : len_i[batch_id]; + if (pass == 0) { current_k = k; - previous_len = len; + previous_len = l_len; // Need to do this so setting counter->previous_len for the next pass is correct. // This value is meaningless for pass 0, but it's fine because pass 0 won't be the // last pass in this implementation so pass 0 won't hit the "if (pass == // num_passes - 1)" branch. // Maybe it's better to reload counter->previous_len and use it rather than // current_len in last_filter() - current_len = len; + current_len = l_len; } else { current_k = counter->k; current_len = counter->len; previous_len = counter->previous_len; } + if constexpr (!len_or_indptr) { + if (pass == 0 && l_len <= k) { + copy_in_val(out + batch_id * k, in + l_offset, l_len, k, select_min); + copy_in_idx(out_idx + batch_id * k, (in_idx ? (in_idx + l_offset) : nullptr), l_len); + if (threadIdx.x == 0) { + counter->previous_len = 0; + counter->len = 0; + } + __syncthreads(); + return; + } + } + if (current_len == 0) { return; } // When k=len, early_stop will be true at pass 0. It means filter_and_histogram() should handle @@ -590,20 +704,33 @@ RAFT_KERNEL radix_kernel(const T* in, const bool early_stop = (current_len == current_k); const IdxT buf_len = calc_buf_len(len); + const T* in_buf; + const IdxT* in_idx_buf; + T* out_buf; + IdxT* out_idx_buf; + bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + + set_buf_pointers(in + l_offset, + (in_idx ? (in_idx + l_offset) : nullptr), + bufs, + buf_len, + pass, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf); + // "previous_len > buf_len" means previous pass skips writing buffer if (pass == 0 || pass == 1 || previous_len > buf_len) { - in_buf = in + batch_id * len; - in_idx_buf = in_idx ? (in_idx + batch_id * len) : nullptr; - previous_len = len; - } else { - in_buf += batch_id * buf_len; - in_idx_buf += batch_id * buf_len; + in_buf = in + l_offset; + in_idx_buf = in_idx ? (in_idx + l_offset) : nullptr; + previous_len = l_len; } // in case we have individual len for each query defined we want to make sure // that we only iterate valid elements. if (len_i != nullptr) { - const IdxT max_len = max(len_i[batch_id], k); + const IdxT max_len = max(l_len, k); if (max_len < previous_len) previous_len = max_len; } @@ -611,9 +738,6 @@ RAFT_KERNEL radix_kernel(const T* in, if (pass == 0 || current_len > buf_len) { out_buf = nullptr; out_idx_buf = nullptr; - } else { - out_buf += batch_id * buf_len; - out_idx_buf += batch_id * buf_len; } out += batch_id * k; out_idx += batch_id * k; @@ -640,7 +764,6 @@ RAFT_KERNEL radix_kernel(const T* in, unsigned int finished = atomicInc(&counter->finished_block_cnt, gridDim.x - 1); isLastBlock = (finished == (gridDim.x - 1)); } - if (__syncthreads_or(isLastBlock)) { if (early_stop) { if (threadIdx.x == 0) { @@ -676,7 +799,7 @@ RAFT_KERNEL radix_kernel(const T* in, out_idx_buf ? out_idx_buf : in_idx_buf, out, out_idx, - out_buf ? current_len : len, + out_buf ? current_len : l_len, k, counter, select_min, @@ -726,7 +849,7 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) int active_blocks; RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &active_blocks, radix_kernel, BlockSize, 0)); + &active_blocks, radix_kernel, BlockSize, 0)); active_blocks *= sm_cnt; IdxT best_num_blocks = 0; @@ -757,78 +880,7 @@ unsigned calc_grid_dim(int batch_size, IdxT len, int sm_cnt) return best_num_blocks; } -template -_RAFT_HOST void set_buf_pointers(const T* in, - const IdxT* in_idx, - T* buf1, - IdxT* idx_buf1, - T* buf2, - IdxT* idx_buf2, - int pass, - const T*& in_buf, - const IdxT*& in_idx_buf, - T*& out_buf, - IdxT*& out_idx_buf) -{ - if (pass == 0) { - in_buf = in; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in; - in_idx_buf = in_idx; - out_buf = buf1; - out_idx_buf = idx_buf1; - } else if (pass % 2 == 0) { - in_buf = buf1; - in_idx_buf = idx_buf1; - out_buf = buf2; - out_idx_buf = idx_buf2; - } else { - in_buf = buf2; - in_idx_buf = idx_buf2; - out_buf = buf1; - out_idx_buf = idx_buf1; - } -} - -template -_RAFT_DEVICE void set_buf_pointers(const T* in, - const IdxT* in_idx, - char* bufs, - IdxT buf_len, - int pass, - const T*& in_buf, - const IdxT*& in_idx_buf, - T*& out_buf, - IdxT*& out_idx_buf) -{ - // bufs consists of 4 pieces in order: buf1, buf2, idx_buf1, idx_buf2 - if (pass == 0) { - in_buf = in; - in_idx_buf = nullptr; - out_buf = nullptr; - out_idx_buf = nullptr; - } else if (pass == 1) { - in_buf = in; - in_idx_buf = in_idx; - out_buf = reinterpret_cast(bufs); - out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); - } else if (pass % 2 == 0) { - in_buf = reinterpret_cast(bufs); - in_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); - out_buf = const_cast(in_buf + buf_len); - out_idx_buf = const_cast(in_idx_buf + buf_len); - } else { - out_buf = reinterpret_cast(bufs); - out_idx_buf = reinterpret_cast(bufs + sizeof(T) * 2 * buf_len); - in_buf = out_buf + buf_len; - in_idx_buf = out_idx_buf + buf_len; - } -} - -template +template void radix_topk(const T* in, const IdxT* in_idx, int batch_size, @@ -850,7 +902,7 @@ void radix_topk(const T* in, if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } - auto kernel = radix_kernel; + auto kernel = radix_kernel; const size_t max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel, false); if (max_chunk_size != static_cast(batch_size)) { @@ -862,55 +914,33 @@ void radix_topk(const T* in, rmm::device_uvector> counters(max_chunk_size, stream, mr); rmm::device_uvector histograms(max_chunk_size * num_buckets, stream, mr); - rmm::device_uvector buf1(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector idx_buf1(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector buf2(max_chunk_size * buf_len, stream, mr); - rmm::device_uvector idx_buf2(max_chunk_size * buf_len, stream, mr); + + rmm::device_uvector bufs( + max_chunk_size * buf_len * 2 * (sizeof(T) + sizeof(IdxT)), stream, mr); for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); RAFT_CUDA_TRY( cudaMemsetAsync(counters.data(), 0, counters.size() * sizeof(Counter), stream)); RAFT_CUDA_TRY(cudaMemsetAsync(histograms.data(), 0, histograms.size() * sizeof(IdxT), stream)); - auto kernel = radix_kernel; + auto kernel = radix_kernel; - const T* chunk_in = in + offset * len; - const IdxT* chunk_in_idx = in_idx ? (in_idx + offset * len) : nullptr; - T* chunk_out = out + offset * k; - IdxT* chunk_out_idx = out_idx + offset * k; - const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; - - const T* in_buf = nullptr; - const IdxT* in_idx_buf = nullptr; - T* out_buf = nullptr; - IdxT* out_idx_buf = nullptr; + T* chunk_out = out + offset * k; + IdxT* chunk_out_idx = out_idx + offset * k; + const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; dim3 blocks(grid_dim, chunk_size); constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { - set_buf_pointers(chunk_in, - chunk_in_idx, - buf1.data(), - idx_buf1.data(), - buf2.data(), - idx_buf2.data(), - pass, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf); - if (fused_last_filter && pass == num_passes - 1) { - kernel = radix_kernel; + kernel = radix_kernel; } - kernel<<>>(chunk_in, - chunk_in_idx, - in_buf, - in_idx_buf, - out_buf, - out_idx_buf, + kernel<<>>(in, + in_idx, + bufs.data(), + offset, chunk_out, chunk_out_idx, counters.data(), @@ -924,16 +954,18 @@ void radix_topk(const T* in, } if (!fused_last_filter) { - last_filter_kernel<<>>(chunk_in, - chunk_in_idx, - out_buf, - out_idx_buf, - chunk_out, - chunk_out_idx, - len, - k, - counters.data(), - select_min); + last_filter_kernel + <<>>(in, + in_idx, + bufs.data(), + offset, + chunk_out, + chunk_out_idx, + len, + chunk_len_i, + k, + counters.data(), + select_min); RAFT_CUDA_TRY(cudaPeekAtLastError()); } } @@ -1015,7 +1047,7 @@ _RAFT_DEVICE void filter_and_histogram_for_one_block(const T* in_buf, } } -template +template RAFT_KERNEL radix_topk_one_block_kernel(const T* in, const IdxT* in_idx, const IdxT len, @@ -1024,30 +1056,48 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, T* out, IdxT* out_idx, const bool select_min, - char* bufs) + char* bufs, + size_t offset) { constexpr int num_buckets = calc_num_buckets(); __shared__ Counter counter; __shared__ IdxT histogram[num_buckets]; + const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow + + IdxT l_len = len; + IdxT l_offset = (offset + batch_id) * len; + if constexpr (!len_or_indptr) { + l_offset = len_i[batch_id]; + l_len = len_i[batch_id + 1] - l_offset; + } + if (threadIdx.x == 0) { counter.k = k; - counter.len = len; - counter.previous_len = len; + counter.len = l_len; + counter.previous_len = l_len; counter.kth_value_bits = 0; counter.out_cnt = 0; counter.out_back_cnt = 0; } __syncthreads(); - const size_t batch_id = blockIdx.x; // size_t to avoid multiplication overflow - in += batch_id * len; - if (in_idx) { in_idx += batch_id * len; } + in += l_offset; + if (in_idx) { in_idx += l_offset; } out += batch_id * k; out_idx += batch_id * k; const IdxT buf_len = calc_buf_len(len); bufs += batch_id * buf_len * 2 * (sizeof(T) + sizeof(IdxT)); + if constexpr (!len_or_indptr) { + if (l_len <= k) { + copy_in_val(out, in, l_len, k, select_min); + copy_in_idx(out_idx, in_idx, l_len); + __syncthreads(); + return; + } + } + constexpr int num_passes = calc_num_passes(); for (int pass = 0; pass < num_passes; ++pass) { const T* in_buf; @@ -1073,7 +1123,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, // in case we have individual len for each query defined we want to make sure // that we only iterate valid elements. if (len_i != nullptr) { - const IdxT max_len = max(len_i[batch_id], k); + const IdxT max_len = max(l_len, k); if (max_len < previous_len) previous_len = max_len; } @@ -1102,7 +1152,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, out_buf ? out_idx_buf : in_idx, out, out_idx, - out_buf ? current_len : len, + out_buf ? current_len : l_len, k, &counter, select_min, @@ -1117,7 +1167,7 @@ RAFT_KERNEL radix_topk_one_block_kernel(const T* in, // counters and global histograms, can be kept in shared memory and cheap sync operations can be // used. It's used when len is relatively small or when the number of blocks per row calculated by // `calc_grid_dim()` is 1. -template +template void radix_topk_one_block(const T* in, const IdxT* in_idx, int batch_size, @@ -1133,7 +1183,7 @@ void radix_topk_one_block(const T* in, { static_assert(calc_num_passes() > 1); - auto kernel = radix_topk_one_block_kernel; + auto kernel = radix_topk_one_block_kernel; const IdxT buf_len = calc_buf_len(len); const size_t max_chunk_size = calc_chunk_size(batch_size, len, sm_cnt, kernel, true); @@ -1144,15 +1194,16 @@ void radix_topk_one_block(const T* in, for (size_t offset = 0; offset < static_cast(batch_size); offset += max_chunk_size) { int chunk_size = std::min(max_chunk_size, batch_size - offset); const IdxT* chunk_len_i = len_i ? (len_i + offset) : nullptr; - kernel<<>>(in + offset * len, - in_idx ? (in_idx + offset * len) : nullptr, + kernel<<>>(in, + in_idx, len, chunk_len_i, k, out + offset * k, out_idx + offset * k, select_min, - bufs.data()); + bufs.data(), + offset); } } @@ -1182,6 +1233,10 @@ void radix_topk_one_block(const T* in, * it affects the number of passes and number of buckets. * @tparam BlockSize * Number of threads in a kernel thread block. + * @tparam len_or_indptr + * Flag to interpret `len_i` as either direct row lengths (true) or CSR format + * index pointers (false). When true, each `len_i` element denotes the length of a row. When + * false, `len_i` represents the index pointers for a CSR matrix with shape of `batch_size + 1`. * * @param[in] res container of reusable resources * @param[in] in @@ -1212,9 +1267,12 @@ void radix_topk_one_block(const T* in, * same. That is, when the value range of input data is narrow. In such case, there could be a * large number of inputs for the last filter, hence using multiple thread blocks is beneficial. * @param len_i - * optional array of size (batch_size) providing lengths for each individual row + * Optional array used differently based on `len_or_indptr`: + * When `len_or_indptr` is true, `len_i` presents the lengths of each row, which is `batch_size`. + * When `len_or_indptr` is false, `len_i` works like a indptr for a CSR matrix. The length of each + * row would be (`len_i[row_id + 1] - len_i[row_id]`). `len_i` size is `batch_size + 1`. */ -template +template void select_k(raft::resources const& res, const T* in, const IdxT* in_idx, @@ -1227,9 +1285,12 @@ void select_k(raft::resources const& res, bool fused_last_filter, const IdxT* len_i) { + RAFT_EXPECTS(!(!len_or_indptr && (len_i == nullptr)), + "When `len_or_indptr` is false, `len_i` must not be nullptr!"); + auto stream = resource::get_cuda_stream(res); auto mr = resource::get_workspace_resource(res); - if (k == len) { + if (k == len && len_or_indptr) { RAFT_CUDA_TRY( cudaMemcpyAsync(out, in, sizeof(T) * batch_size * len, cudaMemcpyDeviceToDevice, stream)); if (in_idx) { @@ -1248,29 +1309,29 @@ void select_k(raft::resources const& res, constexpr int items_per_thread = 32; if (len <= BlockSize * items_per_thread) { - impl::radix_topk_one_block( + impl::radix_topk_one_block( in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { unsigned grid_dim = impl::calc_grid_dim(batch_size, len, sm_cnt); if (grid_dim == 1) { - impl::radix_topk_one_block( + impl::radix_topk_one_block( in, in_idx, batch_size, len, k, out, out_idx, select_min, len_i, sm_cnt, stream, mr); } else { - impl::radix_topk(in, - in_idx, - batch_size, - len, - k, - out, - out_idx, - select_min, - fused_last_filter, - len_i, - grid_dim, - sm_cnt, - stream, - mr); + impl::radix_topk(in, + in_idx, + batch_size, + len, + k, + out, + out_idx, + select_min, + fused_last_filter, + len_i, + grid_dim, + sm_cnt, + stream, + mr); } } } diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh index 572558153d..2cb32585d5 100644 --- a/cpp/include/raft/matrix/detail/select_warpsort.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -754,22 +754,32 @@ template