Skip to content

Commit

Permalink
[FEA] Use CAGRA in C++ template (#1730)
Browse files Browse the repository at this point in the history
Proposal to change the C++ template from a distance computation to a vector search application using CAGRA.

Authors:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1730
  • Loading branch information
lowener authored Aug 15, 2023
1 parent 25b6916 commit 2c85c0b
Show file tree
Hide file tree
Showing 4 changed files with 65 additions and 40 deletions.
6 changes: 5 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,11 @@ endif()
rapids_cuda_init_runtime(USE_STATIC ${CUDA_STATIC_RUNTIME})

if(NOT DISABLE_OPENMP)
find_package(OpenMP)
rapids_find_package(
OpenMP REQUIRED
BUILD_EXPORT_SET raft-exports
INSTALL_EXPORT_SET raft-exports
)
if(OPENMP_FOUND)
message(VERBOSE "RAFT: OpenMP found in ${OpenMP_CXX_INCLUDE_DIRS}")
endif()
Expand Down
2 changes: 1 addition & 1 deletion cpp/template/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -34,5 +34,5 @@ rapids_cpm_init()
include(cmake/thirdparty/get_raft.cmake)

# -------------- compile tasks ----------------- #
add_executable(TEST_RAFT src/test_distance.cu)
add_executable(TEST_RAFT src/test_vector_search.cu)
target_link_libraries(TEST_RAFT PRIVATE raft::raft raft::compiled)
38 changes: 0 additions & 38 deletions cpp/template/src/test_distance.cu

This file was deleted.

59 changes: 59 additions & 0 deletions cpp/template/src/test_vector_search.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
/*
* Copyright (c) 2022-2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cstdint>
#include <raft/core/device_mdarray.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/neighbors/cagra.cuh>
#include <raft/random/make_blobs.cuh>

int main()
{
using namespace raft::neighbors;
raft::device_resources dev_resources;
// Use 5 GB of pool memory
raft::resource::set_workspace_to_pool_resource(
dev_resources, std::make_optional<std::size_t>(5 * 1024 * 1024 * 1024ull));

int64_t n_samples = 50000;
int64_t n_dim = 90;
int64_t topk = 12;
int64_t n_queries = 1;

// create input and output arrays
auto input = raft::make_device_matrix<float>(dev_resources, n_samples, n_dim);
auto labels = raft::make_device_vector<int64_t>(dev_resources, n_samples);
auto queries = raft::make_device_matrix<float>(dev_resources, n_queries, n_dim);
auto neighbors = raft::make_device_matrix<int64_t>(dev_resources, n_queries, topk);
auto distances = raft::make_device_matrix<float>(dev_resources, n_queries, topk);

raft::random::make_blobs(dev_resources, input.view(), labels.view());

// use default index parameters
cagra::index_params index_params;
// create and fill the index from a [n_samples, n_dim] input
auto index = cagra::build<float, int64_t>(
dev_resources, index_params, raft::make_const_mdspan(input.view()));
// use default search parameters
cagra::search_params search_params;
// search K nearest neighbors
cagra::search<float, int64_t>(dev_resources,
search_params,
index,
raft::make_const_mdspan(queries.view()),
neighbors.view(),
distances.view());
}

0 comments on commit 2c85c0b

Please sign in to comment.