Skip to content

Commit

Permalink
Add col-major support for brute force knn (#217)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Tarang Jain (https://github.com/tarang-jain)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #217
  • Loading branch information
benfred authored Jul 17, 2024
1 parent 27f816c commit 90be484
Show file tree
Hide file tree
Showing 5 changed files with 364 additions and 42 deletions.
72 changes: 64 additions & 8 deletions cpp/include/cuvs/neighbors/brute_force.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ struct index : cuvs::neighbors::index {
*
* Constructs a brute force index from a dataset. This lets us precompute norms for
* the dataset, providing a speed benefit over doing this at query time.
* This index will store a non-owning reference to the dataset.
* This index will copy the host dataset onto the device, and take ownership of any
* precaculated norms.
*/
index(raft::resources const& res,
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset_view,
Expand All @@ -61,7 +62,8 @@ struct index : cuvs::neighbors::index {
*
* Constructs a brute force index from a dataset. This lets us precompute norms for
* the dataset, providing a speed benefit over doing this at query time.
* The dataset will be copied to the device and the index will own the device memory.
* This index will store a non-owning reference to the dataset, but will move
* any norms supplied.
*/
index(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset_view,
Expand All @@ -71,7 +73,7 @@ struct index : cuvs::neighbors::index {

/** Construct a brute force index from dataset
*
* This class stores a non-owning reference to the dataset and norms here.
* This class stores a non-owning reference to the dataset and norms.
* Having precomputed norms gives us a performance advantage at query time.
*/
index(raft::resources const& res,
Expand All @@ -80,6 +82,17 @@ struct index : cuvs::neighbors::index {
cuvs::distance::DistanceType metric,
T metric_arg = 0.0);

/** Construct a brute force index from dataset
*
* This class stores a non-owning reference to the dataset and norms, with
* the dataset being supplied on device in a col_major format
*/
index(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::col_major> dataset_view,
std::optional<raft::device_vector<T, int64_t>>&& norms,
cuvs::distance::DistanceType metric,
T metric_arg = 0.0);

/**
* Replace the dataset with a new dataset.
*/
Expand Down Expand Up @@ -152,12 +165,34 @@ struct index : cuvs::neighbors::index {
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
*
* @return the constructed ivf-flat index
* @return the constructed bruteforce index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::row_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float>;

/**
* @brief Build the index from the dataset for efficient search.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // create and fill the index from a [N, D] dataset
* auto index = brute_force::build(handle, dataset, metric);
* @endcode
*
* @param[in] handle
* @param[in] dataset a device pointer to a col-major matrix [n_rows, dim]
* @param[in] metric cuvs::distance::DistanceType
* @param[in] metric_arg metric argument
*
* @return the constructed bruteforce index
*/
auto build(raft::resources const& handle,
raft::device_matrix_view<const float, int64_t, raft::col_major> dataset,
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Unexpanded,
float metric_arg = 0) -> cuvs::neighbors::brute_force::index<float>;
/**
* @}
*/
Expand All @@ -169,7 +204,7 @@ auto build(raft::resources const& handle,
/**
* @brief Search ANN using the constructed index.
*
* See the [ivf_flat::build](#ivf_flat::build) documentation for a usage example.
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* Note, this function requires a temporary buffer to store intermediate results between cuda kernel
* calls, which may lead to undesirable allocations and slowdown. To alleviate the problem, you can
Expand All @@ -186,20 +221,41 @@ auto build(raft::resources const& handle,
* @endcode
*
* @param[in] handle
* @param[in] index ivf-flat constructed index
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a row-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter a optional device bitmap filter function that greenlights samples for a
* given
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
* given query
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<float>& index,
raft::device_matrix_view<const float, int64_t, raft::row_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
std::optional<cuvs::core::bitmap_view<const uint32_t, int64_t>> sample_filter);

/**
* @brief Search ANN using the constructed index.
*
* See the [brute_force::build](#brute_force::build) documentation for a usage example.
*
* @param[in] handle
* @param[in] index bruteforce constructed index
* @param[in] queries a device pointer to a col-major matrix [n_queries, index->dim()]
* @param[out] neighbors a device pointer to the indices of the neighbors in the source dataset
* [n_queries, k]
* @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k]
* @param[in] sample_filter an optional device bitmap filter function that greenlights samples for a
* given query
*/
void search(raft::resources const& handle,
const cuvs::neighbors::brute_force::index<float>& index,
raft::device_matrix_view<const float, int64_t, raft::col_major> queries,
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<float, int64_t, raft::row_major> distances,
std::optional<cuvs::core::bitmap_view<const uint32_t, int64_t>> sample_filter);
/**
* @}
*/
Expand Down
52 changes: 52 additions & 0 deletions cpp/src/neighbors/brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,36 @@ index<T>::index(raft::resources const& res,
{
}

template <typename T>
index<T>::index(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::col_major> dataset_view,
std::optional<raft::device_vector<T, int64_t>>&& norms,
cuvs::distance::DistanceType metric,
T metric_arg)
: cuvs::neighbors::index(),
metric_(metric),
dataset_(
raft::make_device_matrix<T, int64_t>(res, dataset_view.extent(0), dataset_view.extent(1))),
norms_(std::move(norms)),
metric_arg_(metric_arg)
{
// currently we don't support col_major inside tiled_brute_force_knn, because
// of limitations of the pairwise_distance API:
// 1) paiwise_distance takes a single 'isRowMajor' parameter - and we have
// multiple options here (both dataset and queries)
// 2) because of tiling, we need to be able to set a custom stride in the PW
// api, which isn't supported
// Instead, transpose the input matrices if they are passed as col-major.
// (note: we're doing the transpose here to avoid doing per query)
raft::linalg::transpose(res,
const_cast<T*>(dataset_view.data_handle()),
dataset_.data_handle(),
dataset_view.extent(0),
dataset_view.extent(1),
raft::resource::get_cuda_stream(res));
dataset_view_ = raft::make_const_mdspan(dataset_.view());
}

template <typename T>
void index<T>::update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset)
Expand All @@ -93,6 +123,14 @@ void index<T>::update_dataset(raft::resources const& res,
->cuvs::neighbors::brute_force::index<T> \
{ \
return detail::build<T>(res, dataset, metric, metric_arg); \
} \
auto build(raft::resources const& res, \
raft::device_matrix_view<const T, int64_t, raft::col_major> dataset, \
cuvs::distance::DistanceType metric, \
T metric_arg) \
->cuvs::neighbors::brute_force::index<T> \
{ \
return detail::build<T>(res, dataset, metric, metric_arg); \
} \
\
void search( \
Expand All @@ -109,6 +147,20 @@ void index<T>::update_dataset(raft::resources const& res,
detail::brute_force_search_filtered<T, int64_t>( \
res, idx, queries, *sample_filter, neighbors, distances); \
} \
} \
void search( \
raft::resources const& res, \
const cuvs::neighbors::brute_force::index<T>& idx, \
raft::device_matrix_view<const T, int64_t, raft::col_major> queries, \
raft::device_matrix_view<int64_t, int64_t, raft::row_major> neighbors, \
raft::device_matrix_view<T, int64_t, raft::row_major> distances, \
std::optional<cuvs::core::bitmap_view<const uint32_t, int64_t>> sample_filter = std::nullopt) \
{ \
if (!sample_filter.has_value()) { \
detail::brute_force_search<T, int64_t>(res, idx, queries, neighbors, distances); \
} else { \
RAFT_FAIL("filtered search isn't available with col_major queries yet"); \
} \
} \
\
template struct cuvs::neighbors::brute_force::index<T>;
Expand Down
27 changes: 8 additions & 19 deletions cpp/src/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -513,11 +513,11 @@ void brute_force_knn_impl(
if (translations == nullptr) delete id_ranges;
};

template <typename T, typename IdxT>
template <typename T, typename IdxT, typename QueryLayoutT = raft::row_major>
void brute_force_search(
raft::resources const& res,
const cuvs::neighbors::brute_force::index<T>& idx,
raft::device_matrix_view<const T, int64_t, raft::row_major> queries,
raft::device_matrix_view<const T, int64_t, QueryLayoutT> queries,
raft::device_matrix_view<IdxT, int64_t, raft::row_major> neighbors,
raft::device_matrix_view<T, int64_t, raft::row_major> distances,
std::optional<raft::device_vector_view<const T, int64_t>> query_norms = std::nullopt)
Expand All @@ -544,7 +544,7 @@ void brute_force_search(
distances.data_handle(),
k,
true,
true,
std::is_same_v<QueryLayoutT, raft::row_major>,
nullptr,
idx.metric(),
idx.metric_arg(),
Expand Down Expand Up @@ -719,43 +719,32 @@ void brute_force_search_filtered(
return;
}

template <typename T>
template <typename T, typename LayoutT = raft::row_major>
cuvs::neighbors::brute_force::index<T> build(
raft::resources const& res,
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset,
raft::device_matrix_view<const T, int64_t, LayoutT> dataset,
cuvs::distance::DistanceType metric,
T metric_arg)
{
// certain distance metrics can benefit by pre-calculating the norms for the index dataset
// which lets us avoid calculating these at query time
std::optional<raft::device_vector<T, int64_t>> norms;
auto dataset_storage = std::optional<raft::device_matrix<T, int64_t>>{};
auto dataset_view = [&res, &dataset_storage, dataset]() {
if constexpr (std::is_same_v<decltype(dataset),
raft::device_matrix_view<const T, int64_t, raft::row_major>>) {
return dataset;
} else {
dataset_storage =
raft::make_device_matrix<T, int64_t>(res, dataset.extent(0), dataset.extent(1));
raft::copy(res, dataset_storage->view(), dataset);
return raft::make_const_mdspan(dataset_storage->view());
}
}();

if (metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
metric == cuvs::distance::DistanceType::CosineExpanded) {
norms = raft::make_device_vector<T, int64_t>(res, dataset.extent(0));
// cosine needs the l2norm, where as l2 distances needs the squared norm
if (metric == cuvs::distance::DistanceType::CosineExpanded) {
raft::linalg::norm(res,
dataset_view,
dataset,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS,
raft::sqrt_op{});
} else {
raft::linalg::norm(res,
dataset_view,
dataset,
norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS);
Expand Down
Loading

0 comments on commit 90be484

Please sign in to comment.