Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for iterating over batches in bfknn #1947

Merged
merged 14 commits into from
Nov 17, 2023
5 changes: 3 additions & 2 deletions cpp/bench/ann/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ if(RAFT_ANN_BENCH_USE_RAFT_IVF_PQ
OR RAFT_ANN_BENCH_USE_RAFT_BRUTE_FORCE
OR RAFT_ANN_BENCH_USE_RAFT_IVF_FLAT
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB
OR RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB
)
set(RAFT_ANN_BENCH_USE_RAFT ON)
endif()
Expand Down Expand Up @@ -263,7 +263,8 @@ if(RAFT_ANN_BENCH_USE_RAFT_CAGRA_HNSWLIB)
${CMAKE_CURRENT_BINARY_DIR}/_deps/hnswlib-src/hnswlib
LINKS
raft::compiled
CXXFLAGS "${HNSW_CXX_FLAGS}"
CXXFLAGS
"${HNSW_CXX_FLAGS}"
)
endif()

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct l2_exp_cutlass_op {

__device__ l2_exp_cutlass_op() noexcept : sqrt(false) {}
__device__ l2_exp_cutlass_op(bool isSqrt) noexcept : sqrt(isSqrt) {}
__device__ AccT operator()(DataT& aNorm, const DataT& bNorm, DataT& accVal) const noexcept
inline __device__ AccT operator()(DataT aNorm, DataT bNorm, DataT accVal) const noexcept
{
AccT outVal = aNorm + bNorm - DataT(2.0) * accVal;

Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/brute_force-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ void search(raft::resources const& res,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances) RAFT_EXPLICIT;
raft::device_matrix_view<T, int64_t, row_major> distances) RAFT_EXPLICIT;

template <typename idx_t,
typename value_t,
Expand Down
31 changes: 2 additions & 29 deletions cpp/include/raft/neighbors/brute_force-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -346,36 +346,9 @@ void search(raft::resources const& res,
const index<T>& idx,
raft::device_matrix_view<const T, int64_t, row_major> queries,
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors,
raft::device_matrix_view<float, int64_t, row_major> distances)
raft::device_matrix_view<T, int64_t, row_major> distances)
{
RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Value of k must match for outputs");
RAFT_EXPECTS(idx.dataset().extent(1) == queries.extent(1),
"Number of columns in queries must match brute force index");

auto k = neighbors.extent(1);
auto d = idx.dataset().extent(1);

std::vector<T*> dataset = {const_cast<T*>(idx.dataset().data_handle())};
std::vector<int64_t> sizes = {idx.dataset().extent(0)};
std::vector<T*> norms;
if (idx.has_norms()) { norms.push_back(const_cast<T*>(idx.norms().data_handle())); }

detail::brute_force_knn_impl<int64_t, IdxT, T>(res,
dataset,
sizes,
d,
const_cast<T*>(queries.data_handle()),
queries.extent(0),
neighbors.data_handle(),
distances.data_handle(),
k,
true,
true,
nullptr,
idx.metric(),
idx.metric_arg(),
raft::identity_op(),
norms.size() ? &norms : nullptr);
raft::neighbors::detail::brute_force_search<T, IdxT>(res, idx, queries, neighbors, distances);
}
/** @} */ // end group brute_force_knn
} // namespace raft::neighbors::brute_force
68 changes: 68 additions & 0 deletions cpp/include/raft/neighbors/brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/
#pragma once
#include <memory>

#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY
#include "brute_force-inl.cuh"
Expand All @@ -22,3 +23,70 @@
#ifdef RAFT_COMPILED
#include "brute_force-ext.cuh"
#endif

#include <raft/neighbors/detail/knn_brute_force_batch_k_query.cuh>

namespace raft::neighbors::brute_force {
/**
* @brief Make a brute force query over batches of k
*
* This lets you query for batches of k. For example, you can get
* the first 100 neighbors, then the next 100 neighbors etc.
*
* Example usage:
* @code{.cpp}
* #include <raft/neighbors/brute_force.cuh>
* #include <raft/core/device_mdarray.hpp>
* #include <raft/random/make_blobs.cuh>

* // create a random dataset
* int n_rows = 10000;
* int n_cols = 10000;

* raft::device_resources res;
* auto dataset = raft::make_device_matrix<float, int>(res, n_rows, n_cols);
* auto labels = raft::make_device_vector<float, int>(res, n_rows);

* raft::make_blobs(res, dataset.view(), labels.view());
*
* // create a brute_force knn index from the dataset
* auto index = raft::neighbors::brute_force::build(res,
* raft::make_const_mdspan(dataset.view()));
*
* // search the index in batches of 128 nearest neighbors
* auto search = raft::make_const_mdspan(dataset.view());
* auto query = make_batch_k_query<float, int>(res, index, search, 128);
* for (auto & batch: *query) {
* // batch.indices() and batch.distances() contain the information on the current batch
* }
*
* // we can also support variable sized batches - loaded up a different number
* // of neighbors at each iteration through the ::advance method
* int64_t batch_size = 128;
* query = make_batch_k_query<float, int>(res, index, search, batch_size);
* for (auto it = query->begin(); it != query->end(); it.advance(batch_size)) {
* // batch.indices() and batch.distances() contain the information on the current batch
*
* batch_size += 16; // load up an extra 16 items in the next batch
* }
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
* @param[in] res
* @param[in] index The index to query
* @param[in] query A device matrix view to query for [n_queries, index->dim()]
* @param[in] batch_size The size of each batch
*/

template <typename T, typename IdxT>
std::shared_ptr<batch_k_query<T, IdxT>> make_batch_k_query(
const raft::resources& res,
const raft::neighbors::brute_force::index<T>& index,
raft::device_matrix_view<const T, int64_t, row_major> query,
int64_t batch_size)
{
return std::shared_ptr<batch_k_query<T, IdxT>>(
new detail::gpu_batch_k_query<T, IdxT>(res, index, query, batch_size));
}
} // namespace raft::neighbors::brute_force
119 changes: 118 additions & 1 deletion cpp/include/raft/neighbors/brute_force_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/neighbors/neighbors_types.hpp>

#include <raft/core/logger.hpp>

Expand Down Expand Up @@ -69,7 +70,7 @@ struct index : ann::index {
return norms_view_.value();
}

/** Whether ot not this index has dataset norms */
/** Whether or not this index has dataset norms */
[[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); }

[[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; }
Expand Down Expand Up @@ -160,6 +161,122 @@ struct index : ann::index {
T metric_arg_;
};

/**
* @brief Interface for performing queries over values of k
*
* This interface lets you iterate over batches of k from a brute_force::index.
* This lets you do things like retrieve the first 100 neighbors for a query,
* apply post processing to remove any unwanted items and then if needed get the
* next 100 closest neighbors for the query.
*
* This query interface exposes C++ iterators through the ::begin and ::end, and
* is compatible with range based for loops.
*
* Note that this class is an abstract class without any cuda dependencies, meaning
* that it doesn't require a cuda compiler to use - but also means it can't be directly
* instantiated. See the raft::neighbors::brute_force::make_batch_k_query
* function for usage examples.
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*/
template <typename T, typename IdxT = int64_t>
class batch_k_query {
public:
batch_k_query(const raft::resources& res,
int64_t index_size,
int64_t query_size,
int64_t batch_size)
: res(res), index_size(index_size), query_size(query_size), batch_size(batch_size)
{
}
virtual ~batch_k_query() {}

using value_type = raft::neighbors::batch<T, IdxT>;

class iterator {
public:
using value_type = raft::neighbors::batch<T, IdxT>;
using reference = const value_type&;
using pointer = const value_type*;

iterator(const batch_k_query<T, IdxT>* query, int64_t offset = 0)
: current(query->res, 0, 0), batches(query->res, 0, 0), query(query), offset(offset)
{
query->load_batch(offset, query->batch_size, &batches);
query->slice_batch(batches, offset, query->batch_size, &current);
}

reference operator*() const { return current; }

pointer operator->() const { return &current; }

iterator& operator++()
{
advance(query->batch_size);
return *this;
}

iterator operator++(int)
{
iterator previous(*this);
operator++();
return previous;
}

/**
* @brief Advance the iterator, using a custom size for the next batch
*
* Using operator++ means that we will load up the same batch_size for each
* batch. This method allows us to get around this restriction, and load up
* arbitrary batch sizes on each iteration.
* See raft::neighbors::brute_force::make_batch_k_query for a usage example.
*
* @param[in] next_batch_size: size of the next batch to load up
*/
void advance(int64_t next_batch_size)
{
offset = std::min(offset + current.batch_size(), query->index_size);
if (offset + next_batch_size > batches.batch_size()) {
query->load_batch(offset, next_batch_size, &batches);
}
query->slice_batch(batches, offset, next_batch_size, &current);
}

friend bool operator==(const iterator& lhs, const iterator& rhs)
{
return (lhs.query == rhs.query) && (lhs.offset == rhs.offset);
};
friend bool operator!=(const iterator& lhs, const iterator& rhs) { return !(lhs == rhs); };

protected:
// the current batch of data
value_type current;

// the currently loaded group of data (containing multiple batches of data that we can iterate
// through)
value_type batches;

const batch_k_query<T, IdxT>* query;
int64_t offset, current_batch_size;
};

iterator begin() const { return iterator(this); }
iterator end() const { return iterator(this, index_size); }

protected:
// these two methods need cuda code, and are implemented in the subclass
virtual void load_batch(int64_t offset,
int64_t next_batch_size,
batch<T, IdxT>* output) const = 0;
virtual void slice_batch(const value_type& input,
int64_t offset,
int64_t batch_size,
value_type* output) const = 0;

const raft::resources& res;
int64_t index_size, query_size, batch_size;
};
/** @} */

} // namespace raft::neighbors::brute_force
Loading