diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index 68d9dda1a3..d511704682 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -100,6 +100,19 @@ inline void launcher(const raft::handle_t& handle, const ML::UMAPParams* params, cudaStream_t stream) { + raft::resources tmp_handle(handle); + auto mr = raft::resource::get_workspace_resource(tmp_handle); + size_t free_size = raft::resource::get_workspace_free_bytes(tmp_handle); + + double factor = 4.0; + size_t index_batch_size = inputsA.n; + size_t query_batch_size = inputsB.n; + size_t requirements = factor * sizeof(float) * index_batch_size * query_batch_size; + + if (requirements > free_size) { + index_batch_size = free_size / (query_batch_size * factor * sizeof(float)); + } + raft::sparse::selection::brute_force_knn(inputsA.indptr, inputsA.indices, inputsA.data, @@ -115,9 +128,9 @@ inline void launcher(const raft::handle_t& handle, out.knn_indices, out.knn_dists, n_neighbors, - handle, - ML::Sparse::DEFAULT_BATCH_SIZE, - ML::Sparse::DEFAULT_BATCH_SIZE, + tmp_handle, + index_batch_size, + query_batch_size, params->metric, params->p); }