From 6c42aee82dc2d4fa30529fbba3b96b013b332624 Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <4302519+difyrrwrzd@users.noreply.github.com> Date: Wed, 12 Jun 2024 18:17:18 +0200 Subject: [PATCH] Scaling workspace resources (#181) Use raft's large workspace resource for large temporary allocations during ANN index build. This is the port of https://github.com/rapidsai/raft/pull/2194, which didn't make into raft before the algorithms were ported to cuVS. Authors: - Artem M. Chirkin (https://github.com/achirkin) Approvers: - Tamas Bela Feher (https://github.com/tfeher) URL: https://github.com/rapidsai/cuvs/pull/181 --- .../neighbors/detail/cagra/cagra_build.cuh | 59 +++++++---- cpp/src/neighbors/detail/cagra/graph_core.cuh | 30 +++--- cpp/src/neighbors/detail/cagra/utils.hpp | 7 +- cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh | 13 ++- .../neighbors/ivf_flat/ivf_flat_search.cuh | 23 ++--- cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh | 97 ++++++++++++------- cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh | 5 +- 7 files changed, 149 insertions(+), 85 deletions(-) diff --git a/cpp/src/neighbors/detail/cagra/cagra_build.cuh b/cpp/src/neighbors/detail/cagra/cagra_build.cuh index 43e7200..a3e591b 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_build.cuh @@ -34,6 +34,7 @@ // TODO: Fixme- this needs to be migrated #include "../../ivf_pq/ivf_pq_build.cuh" +#include "../../ivf_pq/ivf_pq_search.cuh" #include "../../nn_descent.cuh" // TODO: This shouldn't be calling spatial/knn APIs @@ -162,42 +163,64 @@ void build_knn_graph( // search top (k + 1) neighbors // - const auto top_k = node_degree + 1; - uint32_t gpu_top_k = node_degree * pq.refinement_rate; - gpu_top_k = std::min(std::max(gpu_top_k, top_k), dataset.extent(0)); - const auto num_queries = dataset.extent(0); - const auto max_batch_size = 1024; + const auto top_k = node_degree + 1; + uint32_t gpu_top_k = node_degree * pq.refinement_rate; + gpu_top_k = std::min(std::max(gpu_top_k, top_k), dataset.extent(0)); + const auto num_queries = dataset.extent(0); + + // Use the same maximum batch size as the ivf_pq::search to avoid allocating more than needed. + using cuvs::neighbors::ivf_pq::detail::kMaxQueries; + // Heuristic: the build_knn_graph code should use only a fraction of the workspace memory; the + // rest should be used by the ivf_pq::search. Here we say that the workspace size should be a good + // multiple of what is required for the I/O batching below. + constexpr size_t kMinWorkspaceRatio = 5; + auto desired_workspace_size = kMaxQueries * kMinWorkspaceRatio * + (sizeof(DataT) * dataset.extent(1) // queries (dataset batch) + + sizeof(float) * gpu_top_k // distances + + sizeof(int64_t) * gpu_top_k // neighbors + + sizeof(float) * top_k // refined_distances + + sizeof(int64_t) * top_k // refined_neighbors + ); + + // If the workspace is smaller than desired, put the I/O buffers into the large workspace. + rmm::device_async_resource_ref workspace_mr = + desired_workspace_size <= raft::resource::get_workspace_free_bytes(res) + ? raft::resource::get_workspace_resource(res) + : raft::resource::get_large_workspace_resource(res); + RAFT_LOG_DEBUG( "IVF-PQ search node_degree: %d, top_k: %d, gpu_top_k: %d, max_batch_size:: %d, n_probes: %u", node_degree, top_k, gpu_top_k, - max_batch_size, + kMaxQueries, pq.search_params.n_probes); - auto distances = raft::make_device_matrix(res, max_batch_size, gpu_top_k); - auto neighbors = raft::make_device_matrix(res, max_batch_size, gpu_top_k); - auto refined_distances = raft::make_device_matrix(res, max_batch_size, top_k); - auto refined_neighbors = raft::make_device_matrix(res, max_batch_size, top_k); - auto neighbors_host = raft::make_host_matrix(max_batch_size, gpu_top_k); - auto queries_host = raft::make_host_matrix(max_batch_size, dataset.extent(1)); - auto refined_neighbors_host = raft::make_host_matrix(max_batch_size, top_k); - auto refined_distances_host = raft::make_host_matrix(max_batch_size, top_k); + auto distances = raft::make_device_mdarray( + res, workspace_mr, raft::make_extents(kMaxQueries, gpu_top_k)); + auto neighbors = raft::make_device_mdarray( + res, workspace_mr, raft::make_extents(kMaxQueries, gpu_top_k)); + auto refined_distances = raft::make_device_mdarray( + res, workspace_mr, raft::make_extents(kMaxQueries, top_k)); + auto refined_neighbors = raft::make_device_mdarray( + res, workspace_mr, raft::make_extents(kMaxQueries, top_k)); + auto neighbors_host = raft::make_host_matrix(kMaxQueries, gpu_top_k); + auto queries_host = raft::make_host_matrix(kMaxQueries, dataset.extent(1)); + auto refined_neighbors_host = raft::make_host_matrix(kMaxQueries, top_k); + auto refined_distances_host = raft::make_host_matrix(kMaxQueries, top_k); // TODO(tfeher): batched search with multiple GPUs std::size_t num_self_included = 0; bool first = true; const auto start_clock = std::chrono::system_clock::now(); - rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(res); - cuvs::spatial::knn::detail::utils::batch_load_iterator vec_batches( dataset.data_handle(), dataset.extent(0), dataset.extent(1), - (int64_t)max_batch_size, + static_cast(kMaxQueries), raft::resource::get_cuda_stream(res), - device_memory); + workspace_mr); size_t next_report_offset = 0; size_t d_report_offset = dataset.extent(0) / 100; // Report progress in 1% steps. diff --git a/cpp/src/neighbors/detail/cagra/graph_core.cuh b/cpp/src/neighbors/detail/cagra/graph_core.cuh index 2e90eed..e10b85d 100644 --- a/cpp/src/neighbors/detail/cagra/graph_core.cuh +++ b/cpp/src/neighbors/detail/cagra/graph_core.cuh @@ -41,8 +41,7 @@ #include #include -namespace cuvs::neighbors::cagra::detail { -namespace graph { +namespace cuvs::neighbors::cagra::detail::graph { // unnamed namespace to avoid multiple definition error namespace { @@ -251,7 +250,10 @@ void sort_knn_graph( const uint32_t input_graph_degree = knn_graph.extent(1); IdxT* const input_graph_ptr = knn_graph.data_handle(); - auto d_input_graph = raft::make_device_matrix(res, graph_size, input_graph_degree); + auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); + + auto d_input_graph = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, input_graph_degree)); // // Sorting kNN graph @@ -259,7 +261,8 @@ void sort_knn_graph( const double time_sort_start = cur_time(); RAFT_LOG_DEBUG("# Sorting kNN Graph on GPUs "); - auto d_dataset = raft::make_device_matrix(res, dataset_size, dataset_dim); + auto d_dataset = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(dataset_size, dataset_dim)); raft::copy(d_dataset.data_handle(), dataset_ptr, dataset_size * dataset_dim, @@ -332,6 +335,7 @@ void optimize( { RAFT_LOG_DEBUG( "# Pruning kNN graph (size=%lu, degree=%lu)\n", knn_graph.extent(0), knn_graph.extent(1)); + auto large_tmp_mr = raft::resource::get_large_workspace_resource(res); RAFT_EXPECTS(knn_graph.extent(0) == new_graph.extent(0), "Each input array is expected to have the same number of rows"); @@ -347,15 +351,16 @@ void optimize( // // Prune kNN graph // - auto d_detour_count = - raft::make_device_matrix(res, graph_size, input_graph_degree); + auto d_detour_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size, input_graph_degree)); RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(), 0xff, graph_size * input_graph_degree * sizeof(uint8_t), raft::resource::get_cuda_stream(res))); - auto d_num_no_detour_edges = raft::make_device_vector(res, graph_size); + auto d_num_no_detour_edges = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size)); RAFT_CUDA_TRY(cudaMemsetAsync(d_num_no_detour_edges.data_handle(), 0x00, graph_size * sizeof(uint32_t), @@ -475,14 +480,16 @@ void optimize( graph_size * output_graph_degree * sizeof(IdxT), raft::resource::get_cuda_stream(res))); - auto d_rev_graph_count = raft::make_device_vector(res, graph_size); + auto d_rev_graph_count = raft::make_device_mdarray( + res, large_tmp_mr, raft::make_extents(graph_size)); RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph_count.data_handle(), 0x00, graph_size * sizeof(uint32_t), raft::resource::get_cuda_stream(res))); - auto dest_nodes = raft::make_host_vector(graph_size); - auto d_dest_nodes = raft::make_device_vector(res, graph_size); + auto dest_nodes = raft::make_host_vector(graph_size); + auto d_dest_nodes = + raft::make_device_mdarray(res, large_tmp_mr, raft::make_extents(graph_size)); for (uint64_t k = 0; k < output_graph_degree; k++) { #pragma omp parallel for @@ -578,5 +585,4 @@ void optimize( } } -} // namespace graph -} // namespace cuvs::neighbors::cagra::detail +} // namespace cuvs::neighbors::cagra::detail::graph diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 395ec15..8ce20ec 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -186,8 +186,11 @@ class device_matrix_view_from_host { device_ptr = reinterpret_cast(attr.devicePointer); if (device_ptr == NULL) { // allocate memory and copy over - device_mem_.emplace( - raft::make_device_matrix(res, host_view.extent(0), host_view.extent(1))); + // NB: We use the temporary "large" workspace resource here; this structure is supposed to + // live on stack and not returned to a user. + // The user may opt to set this resource to managed memory to allow large allocations. + device_mem_.emplace(raft::make_device_mdarray( + res, raft::resource::get_large_workspace_resource(res), host_view.extents())); raft::copy(device_mem_->data_handle(), host_view.data_handle(), host_view.extent(0) * host_view.extent(1), diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh index 811f9c2..ff1f303 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_build.cuh @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -184,7 +185,8 @@ void extend(raft::resources const& handle, RAFT_EXPECTS(new_indices != nullptr || index->size() == 0, "You must pass data indices when the index is non-empty."); - auto new_labels = raft::make_device_vector(handle, n_rows); + auto new_labels = raft::make_device_mdarray( + handle, raft::resource::get_large_workspace_resource(handle), raft::make_extents(n_rows)); cuvs::cluster::kmeans::balanced_params kmeans_params; kmeans_params.metric = index->metric(); auto orig_centroids_view = @@ -215,7 +217,8 @@ void extend(raft::resources const& handle, } auto* list_sizes_ptr = index->list_sizes().data_handle(); - auto old_list_sizes_dev = raft::make_device_vector(handle, n_lists); + auto old_list_sizes_dev = raft::make_device_mdarray( + handle, raft::resource::get_workspace_resource(handle), raft::make_extents(n_lists)); raft::copy(old_list_sizes_dev.data_handle(), list_sizes_ptr, n_lists, stream); // Calculate the centers and sizes on the new data, starting from the original values @@ -371,7 +374,8 @@ inline auto build(raft::resources const& handle, auto trainset_ratio = std::max( 1, n_rows / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); auto n_rows_train = n_rows / trainset_ratio; - rmm::device_uvector trainset(n_rows_train * index.dim(), stream); + rmm::device_uvector trainset( + n_rows_train * index.dim(), stream, raft::resource::get_large_workspace_resource(handle)); // TODO: a proper sampling RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), sizeof(T) * index.dim(), @@ -431,7 +435,8 @@ inline void fill_refinement_index(raft::resources const& handle, common::nvtx::range fun_scope( "ivf_flat::fill_refinement_index(%zu, %u)", size_t(n_queries)); - rmm::device_uvector new_labels(n_queries * n_candidates, stream); + rmm::device_uvector new_labels( + n_queries * n_candidates, stream, raft::resource::get_workspace_resource(handle)); auto new_labels_view = raft::make_device_vector_view(new_labels.data(), n_queries * n_candidates); raft::linalg::map_offset( diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh index d5efdeb..43111a7 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_search.cuh @@ -280,17 +280,15 @@ void search_impl(raft::resources const& handle, template -inline void search_with_filtering( - raft::resources const& handle, - const search_params& params, - const index& index, - const T* queries, - uint32_t n_queries, - uint32_t k, - IdxT* neighbors, - float* distances, - rmm::device_async_resource_ref mr = rmm::mr::get_current_device_resource(), - IvfSampleFilterT sample_filter = IvfSampleFilterT()) +inline void search_with_filtering(raft::resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + IvfSampleFilterT sample_filter = IvfSampleFilterT()) { common::nvtx::range fun_scope( "ivf_flat::search(k = %u, n_queries = %u, dim = %zu)", k, n_queries, index.dim()); @@ -335,7 +333,7 @@ inline void search_with_filtering( cuvs::distance::is_min_close(index.metric()), neighbors + offset_q * k, distances + offset_q * k, - mr, + raft::resource::get_workspace_resource(handle), sample_filter); } } @@ -367,7 +365,6 @@ void search_with_filtering(raft::resources const& handle, static_cast(neighbors.extent(1)), neighbors.data_handle(), distances.data_handle(), - raft::resource::get_workspace_resource(handle), sample_filter); } diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh index 804f25d..8823bbb 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_build.cuh @@ -64,6 +64,7 @@ #include namespace cuvs::neighbors::ivf_pq::detail { +using raft::RAFT_NAME; // TODO: this is required for RAFT_LOG_XXX messages. using namespace cuvs::spatial::knn::detail; // NOLINT using internal_extents_t = int64_t; // The default mdspan extent type used internally. @@ -184,7 +185,7 @@ void flat_compute_residuals( raft::device_matrix_view centers, // [n_lists, dim_ext] const T* dataset, // [n_rows, dim] std::variant labels, // [n_rows] - rmm::mr::device_memory_resource* device_memory) + rmm::device_async_resource_ref device_memory) { auto stream = raft::resource::get_cuda_stream(handle); auto dim = rotation_matrix.extent(1); @@ -357,8 +358,7 @@ void train_per_subset(raft::resources const& handle, size_t n_rows, const float* trainset, // [n_rows, dim] const uint32_t* labels, // [n_rows] - uint32_t kmeans_n_iters, - rmm::mr::device_memory_resource* managed_memory) + uint32_t kmeans_n_iters) { auto stream = raft::resource::get_cuda_stream(handle); auto device_memory = raft::resource::get_workspace_resource(handle); @@ -435,11 +435,13 @@ void train_per_cluster(raft::resources const& handle, size_t n_rows, const float* trainset, // [n_rows, dim] const uint32_t* labels, // [n_rows] - uint32_t kmeans_n_iters, - rmm::mr::device_memory_resource* managed_memory) + uint32_t kmeans_n_iters) { auto stream = raft::resource::get_cuda_stream(handle); auto device_memory = raft::resource::get_workspace_resource(handle); + // NB: Managed memory is used for small arrays accessed from both device and host. There's no + // performance reasoning behind this, just avoiding the boilerplate of explicit copies. + rmm::mr::managed_memory_resource managed_memory; rmm::device_uvector pq_centers_tmp(index.pq_centers().size(), stream, device_memory); rmm::device_uvector cluster_sizes(index.n_lists(), stream, managed_memory); @@ -1303,7 +1305,7 @@ void process_and_fill_codes(raft::resources const& handle, std::variant src_offset_or_indices, const uint32_t* new_labels, IdxT n_rows, - rmm::mr::device_memory_resource* mr) + rmm::device_async_resource_ref mr) { auto new_vectors_residual = raft::make_device_mdarray(handle, mr, raft::make_extents(n_rows, index.rot_dim())); @@ -1500,7 +1502,9 @@ void extend(raft::resources const& handle, std::is_same_v, "Unsupported data type"); - rmm::mr::device_memory_resource* device_memory = raft::resource::get_workspace_resource(handle); + rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(handle); + rmm::device_async_resource_ref large_memory = + raft::resource::get_large_workspace_resource(handle); // The spec defines how the clusters look like auto spec = list_spec{ @@ -1515,12 +1519,22 @@ void extend(raft::resources const& handle, n_rows + (kIndexGroupSize - 1) * std::min(n_clusters, n_rows)); // Available device memory - size_t free_mem, total_mem; - RAFT_CUDA_TRY(cudaMemGetInfo(&free_mem, &total_mem)); - + size_t free_mem = raft::resource::get_workspace_free_bytes(handle); + + // We try to use the workspace memory by default here. + // If the workspace limit is too small, we change the resource for batch data to the + // `large_workspace_resource`, which does not have the explicit allocation limit. The user may opt + // to populate the `large_workspace_resource` memory resource with managed memory for easier + // scaling. + rmm::device_async_resource_ref labels_mr = device_memory; + rmm::device_async_resource_ref batches_mr = device_memory; + if (n_rows * (index->dim() * sizeof(T) + index->pq_dim() + sizeof(IdxT) + sizeof(uint32_t)) > + free_mem) { + labels_mr = large_memory; + } // Allocate a buffer for the new labels (classifying the new data) - rmm::device_uvector new_data_labels(n_rows, stream, device_memory); - free_mem -= sizeof(uint32_t) * n_rows; + rmm::device_uvector new_data_labels(n_rows, stream, labels_mr); + free_mem = raft::resource::get_workspace_free_bytes(handle); // Calculate the batch size for the input data if it's not accessible directly from the device constexpr size_t kReasonableMaxBatchSize = 65536; @@ -1549,13 +1563,19 @@ void extend(raft::resources const& handle, while (size_factor * max_batch_size > free_mem && max_batch_size > 128) { max_batch_size >>= 1; } - // If we're keeping the batches in device memory, update the available mem tracker. - free_mem -= size_factor * max_batch_size; + if (size_factor * max_batch_size > free_mem) { + // if that still doesn't fit, resort to the UVM + batches_mr = large_memory; + max_batch_size = kReasonableMaxBatchSize; + } else { + // If we're keeping the batches in device memory, update the available mem tracker. + free_mem -= size_factor * max_batch_size; + } } // Predict the cluster labels for the new data, in batches if necessary utils::batch_load_iterator vec_batches( - new_vectors, n_rows, index->dim(), max_batch_size, stream, device_memory); + new_vectors, n_rows, index->dim(), max_batch_size, stream, batches_mr); // Release the placeholder memory, because we don't intend to allocate any more long-living // temporary buffers before we allocate the index data. // This memory could potentially speed up UVM accesses, if any. @@ -1628,7 +1648,7 @@ void extend(raft::resources const& handle, // By this point, the index state is updated and valid except it doesn't contain the new data // Fill the extended index with the new data (possibly, in batches) utils::batch_load_iterator idx_batches( - new_indices, n_rows, 1, max_batch_size, stream, device_memory); + new_indices, n_rows, 1, max_batch_size, stream, batches_mr); for (const auto& vec_batch : vec_batches) { const auto& idx_batch = *idx_batches++; process_and_fill_codes(handle, @@ -1639,7 +1659,7 @@ void extend(raft::resources const& handle, : std::variant(IdxT(idx_batch.offset())), new_data_labels.data() + vec_batch.offset(), IdxT(vec_batch.size()), - device_memory); + batches_mr); } } @@ -1694,12 +1714,29 @@ auto build(raft::resources const& handle, size_t(n_rows) / std::max(params.kmeans_trainset_fraction * n_rows, index.n_lists())); size_t n_rows_train = n_rows / trainset_ratio; - auto* device_memory = raft::resource::get_workspace_resource(handle); - rmm::mr::managed_memory_resource managed_memory_upstream; + rmm::device_async_resource_ref device_memory = raft::resource::get_workspace_resource(handle); + + // If the trainset is small enough to comfortably fit into device memory, put it there. + // Otherwise, use the managed memory. + constexpr size_t kTolerableRatio = 4; + rmm::device_async_resource_ref big_memory_resource = + raft::resource::get_large_workspace_resource(handle); + if (sizeof(float) * n_rows_train * index.dim() * kTolerableRatio < + raft::resource::get_workspace_free_bytes(handle)) { + big_memory_resource = device_memory; + } // Besides just sampling, we transform the input dataset into floats to make it easier // to use gemm operations from cublas. - rmm::device_uvector trainset(n_rows_train * index.dim(), stream, device_memory); + rmm::device_uvector trainset(0, stream, big_memory_resource); + try { + trainset.resize(n_rows_train * index.dim(), stream); + } catch (raft::logic_error& e) { + RAFT_LOG_ERROR( + "Insufficient memory for kmeans training set allocation. Please decrease " + "kmeans_trainset_fraction, or set large_workspace_resource appropriately."); + throw; + } // TODO: a proper sampling if constexpr (std::is_same_v) { RAFT_CUDA_TRY(cudaMemcpy2DAsync(trainset.data(), @@ -1770,7 +1807,7 @@ auto build(raft::resources const& handle, handle, kmeans_params, trainset_const_view, centers_view, utils::mapping{}); // Trainset labels are needed for training PQ codebooks - rmm::device_uvector labels(n_rows_train, stream, device_memory); + rmm::device_uvector labels(n_rows_train, stream, big_memory_resource); auto centers_const_view = raft::make_device_matrix_view( cluster_centers, index.n_lists(), index.dim()); auto labels_view = @@ -1790,22 +1827,12 @@ auto build(raft::resources const& handle, // Train PQ codebooks switch (index.codebook_kind()) { case codebook_gen::PER_SUBSPACE: - train_per_subset(handle, - index, - n_rows_train, - trainset.data(), - labels.data(), - params.kmeans_n_iters, - &managed_memory_upstream); + train_per_subset( + handle, index, n_rows_train, trainset.data(), labels.data(), params.kmeans_n_iters); break; case codebook_gen::PER_CLUSTER: - train_per_cluster(handle, - index, - n_rows_train, - trainset.data(), - labels.data(), - params.kmeans_n_iters, - &managed_memory_upstream); + train_per_cluster( + handle, index, n_rows_train, trainset.data(), labels.data(), params.kmeans_n_iters); break; default: RAFT_FAIL("Unreachable code"); } diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh index e131b8f..5f812dc 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_search.cuh @@ -586,6 +586,9 @@ inline auto get_max_batch_size(raft::resources const& res, return max_batch_size; } +/** Maximum number of queries ivf_pq::search can process in one batch. */ +constexpr uint32_t kMaxQueries = 4096; + /** See raft::spatial::knn::ivf_pq::search docs */ template (std::max(n_queries, 1), 4096); + const auto max_queries = std::min(std::max(n_queries, 1), kMaxQueries); auto max_batch_size = get_max_batch_size(handle, k, n_probes, max_queries, max_samples); rmm::device_uvector float_queries(max_queries * dim_ext, stream, mr);