Skip to content

Commit

Permalink
Add usage example for brute_force::build (#2029)
Browse files Browse the repository at this point in the history
Add a usage example for using the brute_force index api for building and searching.

Also fix some minor compile time errors in the vector search tutorial

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2029
  • Loading branch information
benfred authored Nov 30, 2023
1 parent ed272c1 commit 04fa426
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 17 deletions.
40 changes: 40 additions & 0 deletions cpp/include/raft/neighbors/brute_force-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,44 @@ void fused_l2_knn(raft::resources const& handle,
/**
* @brief Build the index from the dataset for efficient search.
*
* This function builds a brute force index for the given dataset. This lets you re-use
* precalculated norms for the dataset, leading to a speedup over calling
* raft::neighbors::brute_force::knn repeatedly.
*
* Example usage:
* @code{.cpp}
* #include <raft/neighbors/brute_force.cuh>
* #include <raft/core/device_mdarray.hpp>
* #include <raft/random/make_blobs.cuh>
*
* // create a random dataset
* int n_rows = 10000;
* int n_cols = 10000;
*
* raft::device_resources res;
* auto dataset = raft::make_device_matrix<float, int64_t>(res, n_rows, n_cols);
* auto labels = raft::make_device_vector<int64_t, int64_t>(res, n_rows);
*
* raft::random::make_blobs(res, dataset.view(), labels.view());
*
* // create a brute_force knn index from the dataset
* auto index = raft::neighbors::brute_force::build(res,
* raft::make_const_mdspan(dataset.view()));
*
* // Use the constructed index to search for the nearest 128 neighbors
* int k = 128;
* auto search = raft::make_const_mdspan(dataset.view());
*
* auto indices= raft::make_device_matrix<int, int64_t>(res, search.extent(0), k);
* auto distances = raft::make_device_matrix<float, int64_t>(res, search.extent(0), k);
*
* raft::neighbors::brute_force::search(res,
* index,
* search,
* indices.view(),
* distances.view());
* @endcode
*
* @tparam T data element type
*
* @param[in] res
Expand Down Expand Up @@ -330,6 +368,8 @@ index<T> build(raft::resources const& res,
/**
* @brief Brute Force search using the constructed index.
*
* See raft::neighbors::brute_force::build for a usage example
*
* @tparam T data element type
* @tparam IdxT type of the indices
*
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/neighbors/brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ namespace raft::neighbors::brute_force {
* int n_cols = 10000;
* raft::device_resources res;
* auto dataset = raft::make_device_matrix<float, int>(res, n_rows, n_cols);
* auto labels = raft::make_device_vector<float, int>(res, n_rows);
* auto dataset = raft::make_device_matrix<float, int64_t>(res, n_rows, n_cols);
* auto labels = raft::make_device_vector<int64_t, int64_t>(res, n_rows);
* raft::make_blobs(res, dataset.view(), labels.view());
* raft::random::make_blobs(res, dataset.view(), labels.view());
*
* // create a brute_force knn index from the dataset
* auto index = raft::neighbors::brute_force::build(res,
Expand Down
28 changes: 14 additions & 14 deletions docs/source/vector_search_tutorial.md
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,10 @@ raft::device_resources res;
int n_rows = 10000;
int n_cols = 10000;

auto dataset = raft::make_device_matrix<float, int>(res, n_rows, n_cols);
auto labels = raft::make_device_vector<float, int>(res, n_rows);
auto dataset = raft::make_device_matrix<float, int64_t>(res, n_rows, n_cols);
auto labels = raft::make_device_vector<int64_t, int64_t>(res, n_rows);

raft::make_blobs(res, dataset.view(), labels.view());
raft::random::make_blobs(res, dataset.view(), labels.view());
```
That's it. We've now generated a random 10kx10k matrix with points that cleanly separate into Gaussian clusters, along with a vector of cluster labels for each of the data points. Notice the `cuh` extension in the header file include for `make_blobs`. This signifies to us that this file contains CUDA device functions like kernel code so the CUDA compiler, `nvcc` is needed in order to compile any code that uses it. Generally, any source files that include headers with a `cuh` extension use the `.cu` extension instead of `.cpp`. The rule here is that `cpp` source files contain code which can be compiled with a C++ compiler like `g++` while `cu` files require the CUDA compiler.
Expand Down Expand Up @@ -125,14 +125,14 @@ auto search = raft::make_const_mdspan(dataset.view());

// Indices and Distances are of dimensions (n, k)
// where n is number of rows in the search matrix
auto reference_indices = raft::make_device_matrix<int, int>(search.extent(0), k); // stores index of neighbors
auto reference_distances = raft::make_device_matrix<float, int>(search.extent(0), k); // stores distance to neighbors
auto reference_indices = raft::make_device_matrix<int, int64_t>(res, search.extent(0), k); // stores index of neighbors
auto reference_distances = raft::make_device_matrix<float, int64_t>(res, search.extent(0), k); // stores distance to neighbors

raft::neighbors::brute_force::search(res,
bfknn_index,
search,
raft::make_const_mdspan(indices.view()),
raft::make_const_mdspan(distances.view()));
reference_indices.view(),
reference_distances.view());
```
We have established several things here by building a flat index. Now we know the exact 64 neighbors of all points in the matrix, and this algorithm can be generally useful in several ways:
Expand All @@ -152,9 +152,9 @@ Next we'll train an ANN index. We'll use our graph-based CAGRA algorithm for thi
raft::device_resources res;
// use default index parameters
cagra::index_params index_params;
raft::neighbors::cagra::index_params index_params;
auto index = cagra::build<float, uint32_t>(res, index_params, dataset);
auto index = raft::neighbors::cagra::build<float, uint32_t>(res, index_params, raft::make_const_mdspan(dataset.view()));
```

### Query the CAGRA index
Expand All @@ -167,10 +167,10 @@ auto indices = raft::make_device_matrix<uint32_t>(res, n_rows, k);
auto distances = raft::make_device_matrix<float>(res, n_rows, k);

// use default search parameters
cagra::search_params search_params;
raft::neighbors::cagra::search_params search_params;

// search K nearest neighbors
cagra::search<float, uint32_t>(
raft::neighbors::cagra::search<float, uint32_t>(
res, search_params, index, search, indices.view(), distances.view());
```

Expand All @@ -197,8 +197,8 @@ raft::stats::neighborhood_recall(res,
raft::make_const_mdspan(indices.view()),
raft::make_const_mdspan(reference_indices.view()),
recall_value.view(),
raft::make_const_mdspan(distances),
raft::make_const_mdspan(reference_distances));
raft::make_const_mdspan(distances.view()),
raft::make_const_mdspan(reference_distances.view()));

res.sync_stream();
```
Expand Down Expand Up @@ -340,4 +340,4 @@ The below example specifies the total number of bytes that RAFT can use for temp

std::shared_ptr<rmm::mr::managed_memory_resource> managed_resource;
raft::device_resource res(managed_resource, std::make_optional<std::size_t>(3 * 1024^3));
```
```

0 comments on commit 04fa426

Please sign in to comment.