Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for iterating over batches in bfknn #1947

Merged
merged 14 commits into from
Nov 17, 2023
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct l2_exp_cutlass_op {

__device__ l2_exp_cutlass_op() noexcept : sqrt(false) {}
__device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {}
__device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept
inline __device__ AccT operator()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept
{
AccT outVal = aNorm + bNorm - DataT(2.0) * accVal;

Expand Down
10 changes: 7 additions & 3 deletions cpp/include/raft/neighbors/brute_force-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,9 @@ void search(raft::resources const& res,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances) RAFT_EXPLICIT;
raft::device_matrix_view<T, int64_t, row_major> distances,
std::optional<raft::device_vector_view<const T, int64_t>> query_norms = std::nullopt)
RAFT_EXPLICIT;

template <typename idx_t,
typename value_t,
Expand Down Expand Up @@ -114,14 +116,16 @@ extern template void search<float, int>(
const raft::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);
raft::device_matrix_view<float, int64_t, row_major> distances,
std::optional<raft::device_vector_view<const float, int64_t>> query_norms);

extern template void search<float, int64_t>(
raft::resources const& res,
const raft::neighbors::brute_force::index<float>& idx,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<int64_t, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances);
raft::device_matrix_view<float, int64_t, row_major> distances,
std::optional<raft::device_vector_view<const float, int64_t>> query_norms);

extern template raft::neighbors::brute_force::index<float> build<float>(
raft::resources const& res,
Expand Down
38 changes: 21 additions & 17 deletions cpp/include/raft/neighbors/brute_force-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -340,13 +340,15 @@ index<T> build(raft::resources const& res,
* [n_queries, k]
* @param[out] distances a device matrix view to the distances to the selected neighbors [n_queries,
* k]
* @param[in] query_norms Optional device_vector_view of precomputed query norms
*/
template <typename T, typename IdxT>
void search(raft::resources const& res,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances)
raft::device_matrix_view<T, int64_t, row_major> distances,
std::optional<raft::device_vector_view<const T, int64_t>> query_norms = std::nullopt)
benfred marked this conversation as resolved.
Show resolved Hide resolved
{
RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs");
RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1),
Expand All @@ -360,22 +362,24 @@ void search(raft::resources const& res,
std::vector<T*> norms;
if (idx.has_norms()) { norms.push_back(const_cast<T*>(idx.norms().data_handle())); }

detail::brute_force_knn_impl<int64_t, IdxT, T>(res,
dataset,
sizes,
d,
const_cast<T*>(queries.data_handle()),
queries.extent(0),
neighbors.data_handle(),
distances.data_handle(),
k,
true,
true,
nullptr,
idx.metric(),
idx.metric_arg(),
raft::identity_op(),
norms.size() ? &norms : nullptr);
detail::brute_force_knn_impl<int64_t, IdxT, T>(
res,
dataset,
sizes,
d,
const_cast<T*>(queries.data_handle()),
queries.extent(0),
neighbors.data_handle(),
distances.data_handle(),
k,
true,
true,
nullptr,
idx.metric(),
idx.metric_arg(),
raft::identity_op(),
norms.size() ? &norms : nullptr,
query_norms ? query_norms->data_handle() : nullptr);
}
/** @} */ // end group brute_force_knn
} // namespace raft::neighbors::brute_force
240 changes: 240 additions & 0 deletions cpp/include/raft/neighbors/brute_force_batch_k_query.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,240 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
benfred marked this conversation as resolved.
Show resolved Hide resolved

#include <raft/core/device_mdarray.hpp>
#include <raft/core/error.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/norm.cuh>
#include <raft/matrix/slice.cuh>
#include <raft/neighbors/brute_force.cuh>

#include <raft/core/logger.hpp>

namespace raft::neighbors::brute_force {
/**
* @addtogroup brute_force
benfred marked this conversation as resolved.
Show resolved Hide resolved
* @{
*/

/**
* @brief Brute force query over batches of k
*
* This class lets you query for batches of k. For example, you can get
* the first 100 neighbors, then the next 100 neighbors etc.
*
* Example usage:
* @code{.cpp}
* // create a brute_force knn index
* raft::device_resources res;
* auto index = raft::neighbors::brute_force::build(res,
* raft::make_const_mdspan(dataset.view()));
*
* // search the index in batches of 128 nearest neighbors
* auto search = raft::make_const_mdspan(dataset.view());
* for (auto & batch: batch_k_query(res, index, query, 128)) {
* // batch.indices() and batch.distances() contain the information on the current batch
* }
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*/
template <typename T, typename IdxT = int64_t>
class batch_k_query {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please define the interface in a file called raft/neighbors/neighbors_types.hpp (for generalized / non-brute-force types) and/or raft/neighbors/brute_force_types.hpp, define the implementation in raft::neighbors::detail namespace, and then create a stateless factory function like raft::neighbors::brute_force::search_batch_k() to create an instance of it.

public:
/**
* @brief Construct a brute force batch k query object
*
* Constructs the batch_k_query - which lets you iterate over batches of
* nearest neighbors.
*
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need copy/paste usage examples for all public API functions.

* @param[in] res
* @param[in] index The index to query
* @param[in] query A device matrix view to query for [n_queries, index->dim()]
* @param[in] batch_size The size of each batch
*/
batch_k_query(const raft::resources& res,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We've structured RAFT's computational APIs as stateless functions that operate on objects which hold trivial state. I'm not totally against having an iterator that knows how to load the next batch, but to be consistent with other APIs, we should have a stateless public API function that creates an iterator instance and then arm the iterator implementation with whatever state it's needed to iterate from beginning to end.

This way, the user can call something like:

auto iter = raft::neighbors::brute_force::make_batch_k_iter(...);

for( ...) {
   
}

Then the batched iterator itself is a mere implementation detail.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done in the last commit - the batch_k_query is now an abstract class containing no cuda code, and there is a detail::gpu_batch_k_query class with the cuda definitions that is created via a make_batch_k_query factory function

const raft::neighbors::brute_force::index<T>& index,
raft::device_matrix_view<const T, int64_t, row_major> query,
int64_t batch_size)
: res(res), index(index), query(query), batch_size(batch_size)
{
auto metric = index.metric();

// precompute query norms, and re-use across batches
if (metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded ||
metric == raft::distance::DistanceType::CosineExpanded) {
query_norms = make_device_vector<T, int64_t>(res, query.extent(0));

if (metric == raft::distance::DistanceType::CosineExpanded) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do this conditional in a few places in the code- perhaps we should consolidate them into a function in detail.

raft::linalg::norm(res,
benfred marked this conversation as resolved.
Show resolved Hide resolved
query,
query_norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS,
raft::sqrt_op{});
} else {
raft::linalg::norm(res,
query,
query_norms->view(),
raft::linalg::NormType::L2Norm,
raft::linalg::Apply::ALONG_ROWS);
}
}
}

/** a single batch of nearest neighbors in device memory */
class batch {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This batch class could be re-used for batching IVF and CAGRA searches. Can it be implemented in a separate file?

benfred marked this conversation as resolved.
Show resolved Hide resolved
public:
/** Create a new empty batch of data */
batch(raft::resources const& res, int64_t rows, int64_t cols)
: indices_(make_device_matrix<IdxT, int64_t>(res, rows, cols)),
distances_(make_device_matrix<T, int64_t>(res, rows, cols)),
start_k(0),
end_k(cols)
{
}

/** Returns the indices for the batch */
device_matrix_view<const IdxT, int64_t> indices() const
{
return raft::make_const_mdspan(indices_.view());
}

/** Returns the distances for the batch */
device_matrix_view<const T, int64_t> distances() const
{
return raft::make_const_mdspan(distances_.view());
}

benfred marked this conversation as resolved.
Show resolved Hide resolved
friend class iterator;
benfred marked this conversation as resolved.
Show resolved Hide resolved
friend class batch_k_query<T, IdxT>;

protected:
raft::device_matrix<IdxT, int64_t> indices_;
raft::device_matrix<T, int64_t> distances_;
int64_t start_k;
int64_t end_k;
};

class iterator {
benfred marked this conversation as resolved.
Show resolved Hide resolved
public:
using value_type = batch;
using reference = const value_type&;
using pointer = const value_type*;

iterator(const batch_k_query<T, IdxT>* query, int64_t offset = 0)
: current(query->res, 0, 0), batches(query->res, 0, 0), query(query), offset(offset)
{
load_batches(query->batch_size);
slice_current_batch(offset, query->batch_size);
}

reference operator*() const { return current; }

pointer operator->() const { return &current; }

iterator& operator++()
{
advance(query->batch_size);
return *this;
}

iterator operator++(int)
{
iterator previous(*this);
operator++();
return previous;
}

void advance(int64_t next_batch_size)
benfred marked this conversation as resolved.
Show resolved Hide resolved
{
offset = std::min(offset + current.indices().extent(1), query->index.size());
if (offset + next_batch_size > current_batch_size) { load_batches(next_batch_size); }
slice_current_batch(offset, next_batch_size);
}

friend bool operator==(const iterator& lhs, const iterator& rhs)
{
return (lhs.query == rhs.query) && (lhs.offset == rhs.offset);
};
friend bool operator!=(const iterator& lhs, const iterator& rhs) { return !(lhs == rhs); };

protected:
void load_batches(int64_t next_batch_size)
{
if (offset >= query->index.size()) { return; }

// we're aiming to load multiple batches here - since we don't know the max iteration
// grow the size we're loading exponentially
int64_t batch_size = std::min(std::max(offset * 2, next_batch_size * 2), query->index.size());
batches = batch(query->res, query->query.extent(0), batch_size);
query->load_batch(batches);
current_batch_size = batch_size;
}

void slice_current_batch(int64_t offset, int64_t batch_size)
{
auto num_queries = batches.indices_.extent(0);
batch_size = std::min(batch_size, query->index.size() - offset);
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So the batch size changes as we iterate through k?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The batch sizes can change as we iterate through K - if we want to support the redis VecSimBatchIterator interface, we will need this to handle the getNextResults method https://github.com/RedisAI/VectorSimilarity/blob/22954489d9184c9eba55f477463439a3532ca04e/src/VecSim/batch_iterator.h#L40-L42

current = batch(query->res, num_queries, batch_size);

if (!num_queries || !batch_size) { return; }

matrix::slice_coordinates<int64_t> coords{0, offset, num_queries, offset + batch_size};
matrix::slice(query->res, batches.indices(), current.indices_.view(), coords);
matrix::slice(query->res, batches.distances(), current.distances_.view(), coords);
}

// the current batch of data
batch current;

// the currently loaded group of data (containing multiple batches of data that we can iterate
// through)
batch batches;

const batch_k_query<T, IdxT>* query;
int64_t offset, current_batch_size;
};

iterator begin() const { return iterator(this); }
iterator end() const { return iterator(this, index.size()); }

protected:
void load_batch(batch& output) const
{
std::optional<raft::device_vector_view<const float, int64_t>> query_norms_view;
if (query_norms) { query_norms_view = query_norms->view(); }

brute_force::search<T, IdxT>(
res, index, query, output.indices_.view(), output.distances_.view(), query_norms_view);
}

const raft::resources& res;
benfred marked this conversation as resolved.
Show resolved Hide resolved
const raft::neighbors::brute_force::index<T>& index;
raft::device_matrix_view<const T, int64_t, row_major> query;
int64_t batch_size;
std::optional<device_vector<T, int64_t>> query_norms;
};

/** @} */
} // namespace raft::neighbors::brute_force
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/brute_force_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ struct index : ann::index {
return norms_view_.value();
}

/** Whether ot not this index has dataset norms */
/** Whether or not this index has dataset norms */
[[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); }

[[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; }
Expand Down
Loading