Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branch 23.12 merge 23.10 #1873

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 45 additions & 7 deletions cpp/bench/prims/neighbors/cagra_bench.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

#include <common/benchmark.hpp>
#include <raft/neighbors/cagra.cuh>
#include <raft/neighbors/sample_filter.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/itertools.hpp>
#include <thrust/sequence.h>

#include <optional>

Expand All @@ -40,6 +42,8 @@ struct params {
int block_size;
int search_width;
int max_iterations;
/** Ratio of removed indices. */
double removed_ratio;
};

template <typename T, typename IdxT>
Expand All @@ -49,7 +53,8 @@ struct CagraBench : public fixture {
params_(ps),
queries_(make_device_matrix<T, int64_t>(handle, ps.n_queries, ps.n_dims)),
dataset_(make_device_matrix<T, int64_t>(handle, ps.n_samples, ps.n_dims)),
knn_graph_(make_device_matrix<IdxT, int64_t>(handle, ps.n_samples, ps.degree))
knn_graph_(make_device_matrix<IdxT, int64_t>(handle, ps.n_samples, ps.degree)),
removed_indices_bitset_(handle, ps.n_samples)
{
// Generate random dataset and queriees
raft::random::RngState state{42};
Expand All @@ -74,6 +79,13 @@ struct CagraBench : public fixture {

auto metric = raft::distance::DistanceType::L2Expanded;

auto removed_indices =
raft::make_device_vector<IdxT, int64_t>(handle, ps.removed_ratio * ps.n_samples);
thrust::sequence(
resource::get_thrust_policy(handle),
thrust::device_pointer_cast(removed_indices.data_handle()),
thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0)));
removed_indices_bitset_.set(handle, removed_indices.view());
index_.emplace(raft::neighbors::cagra::index<T, IdxT>(
handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view())));
}
Expand All @@ -95,10 +107,18 @@ struct CagraBench : public fixture {
distances.data_handle(), params_.n_queries, params_.k);

auto queries_v = make_const_mdspan(queries_.view());
loop_on_state(state, [&]() {
raft::neighbors::cagra::search(
this->handle, search_params, *this->index_, queries_v, ind_v, dist_v);
});
if (params_.removed_ratio > 0) {
auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view());
loop_on_state(state, [&]() {
raft::neighbors::cagra::search_with_filtering(
this->handle, search_params, *this->index_, queries_v, ind_v, dist_v, filter);
});
} else {
loop_on_state(state, [&]() {
raft::neighbors::cagra::search(
this->handle, search_params, *this->index_, queries_v, ind_v, dist_v);
});
}

double data_size = params_.n_samples * params_.n_dims * sizeof(T);
double graph_size = params_.n_samples * params_.degree * sizeof(IdxT);
Expand All @@ -120,6 +140,7 @@ struct CagraBench : public fixture {
state.counters["block_size"] = params_.block_size;
state.counters["search_width"] = params_.search_width;
state.counters["iterations"] = iterations;
state.counters["removed_ratio"] = params_.removed_ratio;
}

private:
Expand All @@ -128,6 +149,7 @@ struct CagraBench : public fixture {
raft::device_matrix<T, int64_t, row_major> queries_;
raft::device_matrix<T, int64_t, row_major> dataset_;
raft::device_matrix<IdxT, int64_t, row_major> knn_graph_;
raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset_;
};

inline const std::vector<params> generate_inputs()
Expand All @@ -141,7 +163,8 @@ inline const std::vector<params> generate_inputs()
{64}, // itopk_size
{0}, // block_size
{1}, // search_width
{0} // max_iterations
{0}, // max_iterations
{0.0} // removed_ratio
);
auto inputs2 = raft::util::itertools::product<params>({2000000ull, 10000000ull}, // n_samples
{128}, // dataset dim
Expand All @@ -151,7 +174,22 @@ inline const std::vector<params> generate_inputs()
{64}, // itopk_size
{64, 128, 256, 512, 1024}, // block_size
{1}, // search_width
{0} // max_iterations
{0}, // max_iterations
{0.0} // removed_ratio
);
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

inputs2 = raft::util::itertools::product<params>(
{2000000ull, 10000000ull}, // n_samples
{128}, // dataset dim
{1, 10, 10000}, // n_queries
{255}, // k
{64}, // knn graph degree
{300}, // itopk_size
{256}, // block_size
{2}, // search_width
{0}, // max_iterations
{0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio
);
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());
return inputs;
Expand Down
25 changes: 23 additions & 2 deletions cpp/include/raft/neighbors/brute_force_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ struct index : ann::index {
/** Dataset norms */
[[nodiscard]] inline auto norms() const -> device_vector_view<const T, int64_t, row_major>
{
return make_const_mdspan(norms_.value().view());
return norms_view_.value();
}

/** Whether ot not this index has dataset norms */
[[nodiscard]] inline bool has_norms() const noexcept { return norms_.has_value(); }
[[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); }

[[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; }

Expand Down Expand Up @@ -102,10 +102,30 @@ struct index : ann::index {
norms_(std::move(norms)),
metric_arg_(metric_arg)
{
if (norms_) { norms_view_ = make_const_mdspan(norms_.value().view()); }
update_dataset(res, dataset);
resource::sync_stream(res);
}

/** Construct a brute force index from dataset
*
* This class stores a non-owning reference to the dataset and norms here.
* Having precomputed norms gives us a performance advantage at query time.
*/
index(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, row_major> dataset_view,
std::optional<raft::device_vector_view<const T, int64_t>> norms_view,
raft::distance::DistanceType metric,
T metric_arg = 0.0)
: ann::index(),
metric_(metric),
dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
dataset_view_(dataset_view),
norms_view_(norms_view),
metric_arg_(metric_arg)
{
}

private:
/**
* Replace the dataset with a new dataset.
Expand Down Expand Up @@ -135,6 +155,7 @@ struct index : ann::index {
raft::distance::DistanceType metric_;
raft::device_matrix<T, int64_t, row_major> dataset_;
std::optional<raft::device_vector<T, int64_t>> norms_;
std::optional<raft::device_vector_view<const T, int64_t>> norms_view_;
raft::device_matrix_view<const T, int64_t, row_major> dataset_view_;
T metric_arg_;
};
Expand Down
20 changes: 19 additions & 1 deletion cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,25 @@ void search(raft::resources const& res,
/**
* @brief Search ANN using the constructed index with the given sample filter.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
* Usage example:
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
* cagra::index_params index_params;
* // create and fill the index from a [N, D] dataset
* auto index = cagra::build(res, index_params, dataset);
* // use default search parameters
* cagra::search_params search_params;
* // create a bitset to filter the search
* auto removed_indices = raft::make_device_vector<IdxT>(res, n_removed_indices);
* raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset(
* res, removed_indices.view(), dataset.extent(0));
* // search K nearest neighbours according to a bitset
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
* auto distances = raft::make_device_matrix<float>(res, n_queries, k);
* cagra::search_with_filtering(res, search_params, index, queries, neighbors, distances,
* filtering::bitset_filter(removed_indices_bitset.view()));
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices
Expand Down
23 changes: 16 additions & 7 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -478,13 +478,15 @@ __global__ void apply_filter_kernel(INDEX_T* const result_indices_ptr,
const INDEX_T query_id_offset,
SAMPLE_FILTER_T sample_filter)
{
const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= result_buffer_size * num_queries) { return; }
const auto i = tid % result_buffer_size;
const auto j = tid / result_buffer_size;
const auto index = i + j * lds;

if (!sample_filter(query_id_offset + j, result_indices_ptr[index])) {
if (result_indices_ptr[index] != ~index_msb_1_mask &&
!sample_filter(query_id_offset + j, result_indices_ptr[index])) {
result_indices_ptr[index] = utils::get_max_value<INDEX_T>();
result_distances_ptr[index] = utils::get_max_value<DISTANCE_T>();
}
Expand Down Expand Up @@ -788,12 +790,15 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
auto result_indices_ptr = result_indices.data() + (iter & 0x1) * result_buffer_size;
auto result_distances_ptr = result_distances.data() + (iter & 0x1) * result_buffer_size;

// Remove parent bit in search results
remove_parent_bit(
num_queries, itopk_size, result_indices_ptr, result_buffer_allocation_size, stream);
if constexpr (!std::is_same<SAMPLE_FILTER_T,
raft::neighbors::filtering::none_cagra_sample_filter>::value) {
// Remove parent bit in search results
remove_parent_bit(num_queries,
result_buffer_size,
result_indices.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
stream);

if (!std::is_same<SAMPLE_FILTER_T,
raft::neighbors::filtering::none_cagra_sample_filter>::value) {
apply_filter<INDEX_T, DISTANCE_T, SAMPLE_FILTER_T>(
result_indices.data() + (iter & 0x1) * itopk_size,
result_distances.data() + (iter & 0x1) * itopk_size,
Expand Down Expand Up @@ -821,6 +826,10 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
true,
topk_hint.data(),
stream);
} else {
// Remove parent bit in search results
remove_parent_bit(
num_queries, itopk_size, result_indices_ptr, result_buffer_allocation_size, stream);
}

// Copy results from working buffer to final buffer
Expand Down
8 changes: 8 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,14 @@ struct search_plan_impl : public search_plan_impl_base {
"`hashmap_max_fill_rate` must be equal to or greater than 0.1 and smaller than 0.9. " +
std::to_string(hashmap_max_fill_rate) + " has been given.";
}
if constexpr (!std::is_same<SAMPLE_FILTER_T,
raft::neighbors::filtering::none_cagra_sample_filter>::value) {
if (hashmap_mode == hash_mode::SMALL) {
error_message += "`SMALL` hash is not available when filtering";
} else {
hashmap_mode = hash_mode::HASH;
}
}
if (algo == search_algo::MULTI_CTA) {
if (hashmap_mode == hash_mode::SMALL) {
error_message += "`small_hash` is not available when 'search_mode' is \"multi-cta\"";
Expand Down
3 changes: 1 addition & 2 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1278,8 +1278,7 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out

std::thread update_and_sample_thread(update_and_sample, it);

std::cout << "# GNND iteraton: " << it + 1 << "/" << build_config_.max_iterations << "\r";
std::fflush(stdout);
RAFT_LOG_DEBUG("# GNND iteraton: %lu / %lu", it + 1, build_config_.max_iterations);

// Reuse dists_buffer_ to save GPU memory. graph_buffer_ cannot be reused, because it
// contains some information for local_join.
Expand Down
48 changes: 48 additions & 0 deletions cpp/include/raft/neighbors/sample_filter.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cstddef>
#include <cstdint>

#include <raft/core/bitset.cuh>

namespace raft::neighbors::filtering {
/**
* @brief Filter an index with a bitset
*
* @tparam index_t Indexing type
*/
template <typename bitset_t, typename index_t>
struct bitset_filter {
// View of the bitset to use as a filter
const raft::core::bitset_view<bitset_t, index_t> bitset_view_;

bitset_filter(const raft::core::bitset_view<bitset_t, index_t> bitset_for_filtering)
: bitset_view_{bitset_for_filtering}
{
}
inline _RAFT_HOST_DEVICE bool operator()(
// query index
const uint32_t query_ix,
// the index of the current sample
const uint32_t sample_ix) const
{
return bitset_view_.test(sample_ix);
}
};
} // namespace raft::neighbors::filtering
Loading