Skip to content

Commit

Permalink
Merge pull request #1873 from AyodeAwe/branch-23.12-merge-23.10
Browse files Browse the repository at this point in the history
Branch 23.12 merge 23.10
  • Loading branch information
raydouglass authored Oct 5, 2023
2 parents 2a7a869 + 90bc2b1 commit de21b85
Show file tree
Hide file tree
Showing 11 changed files with 289 additions and 22 deletions.
52 changes: 45 additions & 7 deletions cpp/bench/prims/neighbors/cagra_bench.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@

#include <common/benchmark.hpp>
#include <raft/neighbors/cagra.cuh>
#include <raft/neighbors/sample_filter.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/itertools.hpp>
#include <thrust/sequence.h>

#include <optional>

Expand All @@ -40,6 +42,8 @@ struct params {
int block_size;
int search_width;
int max_iterations;
/** Ratio of removed indices. */
double removed_ratio;
};

template <typename T, typename IdxT>
Expand All @@ -49,7 +53,8 @@ struct CagraBench : public fixture {
params_(ps),
queries_(make_device_matrix<T, int64_t>(handle, ps.n_queries, ps.n_dims)),
dataset_(make_device_matrix<T, int64_t>(handle, ps.n_samples, ps.n_dims)),
knn_graph_(make_device_matrix<IdxT, int64_t>(handle, ps.n_samples, ps.degree))
knn_graph_(make_device_matrix<IdxT, int64_t>(handle, ps.n_samples, ps.degree)),
removed_indices_bitset_(handle, ps.n_samples)
{
// Generate random dataset and queries
raft::random::RngState state{42};
Expand All @@ -74,6 +79,13 @@ struct CagraBench : public fixture {

auto metric = raft::distance::DistanceType::L2Expanded;

auto removed_indices =
raft::make_device_vector<IdxT, int64_t>(handle, ps.removed_ratio * ps.n_samples);
thrust::sequence(
resource::get_thrust_policy(handle),
thrust::device_pointer_cast(removed_indices.data_handle()),
thrust::device_pointer_cast(removed_indices.data_handle() + removed_indices.extent(0)));
removed_indices_bitset_.set(handle, removed_indices.view());
index_.emplace(raft::neighbors::cagra::index<T, IdxT>(
handle, metric, make_const_mdspan(dataset_.view()), make_const_mdspan(knn_graph_.view())));
}
Expand All @@ -95,10 +107,18 @@ struct CagraBench : public fixture {
distances.data_handle(), params_.n_queries, params_.k);

auto queries_v = make_const_mdspan(queries_.view());
loop_on_state(state, [&]() {
raft::neighbors::cagra::search(
this->handle, search_params, *this->index_, queries_v, ind_v, dist_v);
});
if (params_.removed_ratio > 0) {
auto filter = raft::neighbors::filtering::bitset_filter(removed_indices_bitset_.view());
loop_on_state(state, [&]() {
raft::neighbors::cagra::search_with_filtering(
this->handle, search_params, *this->index_, queries_v, ind_v, dist_v, filter);
});
} else {
loop_on_state(state, [&]() {
raft::neighbors::cagra::search(
this->handle, search_params, *this->index_, queries_v, ind_v, dist_v);
});
}

double data_size = params_.n_samples * params_.n_dims * sizeof(T);
double graph_size = params_.n_samples * params_.degree * sizeof(IdxT);
Expand All @@ -120,6 +140,7 @@ struct CagraBench : public fixture {
state.counters["block_size"] = params_.block_size;
state.counters["search_width"] = params_.search_width;
state.counters["iterations"] = iterations;
state.counters["removed_ratio"] = params_.removed_ratio;
}

private:
Expand All @@ -128,6 +149,7 @@ struct CagraBench : public fixture {
raft::device_matrix<T, int64_t, row_major> queries_;
raft::device_matrix<T, int64_t, row_major> dataset_;
raft::device_matrix<IdxT, int64_t, row_major> knn_graph_;
raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset_;
};

inline const std::vector<params> generate_inputs()
Expand All @@ -141,7 +163,8 @@ inline const std::vector<params> generate_inputs()
{64}, // itopk_size
{0}, // block_size
{1}, // search_width
{0} // max_iterations
{0}, // max_iterations
{0.0} // removed_ratio
);
auto inputs2 = raft::util::itertools::product<params>({2000000ull, 10000000ull}, // n_samples
{128}, // dataset dim
Expand All @@ -151,7 +174,22 @@ inline const std::vector<params> generate_inputs()
{64}, // itopk_size
{64, 128, 256, 512, 1024}, // block_size
{1}, // search_width
{0} // max_iterations
{0}, // max_iterations
{0.0} // removed_ratio
);
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

inputs2 = raft::util::itertools::product<params>(
{2000000ull, 10000000ull}, // n_samples
{128}, // dataset dim
{1, 10, 10000}, // n_queries
{255}, // k
{64}, // knn graph degree
{300}, // itopk_size
{256}, // block_size
{2}, // search_width
{0}, // max_iterations
{0.0, 0.02, 0.04, 0.08, 0.16, 0.32, 0.64} // removed_ratio
);
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());
return inputs;
Expand Down
25 changes: 23 additions & 2 deletions cpp/include/raft/neighbors/brute_force_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -66,11 +66,11 @@ struct index : ann::index {
/** Dataset norms */
[[nodiscard]] inline auto norms() const -> device_vector_view<const T, int64_t, row_major>
{
return make_const_mdspan(norms_.value().view());
return norms_view_.value();
}

/** Whether or not this index has dataset norms */
[[nodiscard]] inline bool has_norms() const noexcept { return norms_.has_value(); }
[[nodiscard]] inline bool has_norms() const noexcept { return norms_view_.has_value(); }

[[nodiscard]] inline T metric_arg() const noexcept { return metric_arg_; }

Expand Down Expand Up @@ -102,10 +102,30 @@ struct index : ann::index {
norms_(std::move(norms)),
metric_arg_(metric_arg)
{
if (norms_) { norms_view_ = make_const_mdspan(norms_.value().view()); }
update_dataset(res, dataset);
resource::sync_stream(res);
}

/**
 * @brief Construct a brute force index over caller-provided views.
 *
 * This overload stores only non-owning views of the dataset and, optionally, of
 * its precomputed norms; nothing is copied. The owning `dataset_` member is
 * initialized to an empty 0x0 matrix and the owning `norms_` member is left
 * disengaged. Because only views are kept, the memory they refer to has to
 * remain valid for as long as the index is used.
 *
 * Having precomputed norms gives us a performance advantage at query time.
 *
 * @param[in] res          resources used to create the empty placeholder matrix
 * @param[in] dataset_view non-owning view of the row-major dataset
 * @param[in] norms_view   optional non-owning view of the dataset norms
 * @param[in] metric       distance metric associated with this index
 * @param[in] metric_arg   extra argument stored alongside the metric (defaults to 0)
 */
index(raft::resources const& res,
      raft::device_matrix_view<const T, int64_t, row_major> dataset_view,
      std::optional<raft::device_vector_view<const T, int64_t>> norms_view,
      raft::distance::DistanceType metric,
      T metric_arg = 0.0)
  : ann::index(),
    metric_(metric),
    // Nothing is owned by this overload, so the owning matrix stays empty.
    dataset_(make_device_matrix<T, int64_t>(res, 0, 0)),
    // Initializers follow the member declaration order (norms_view_ is declared
    // before dataset_view_), which is the order in which they actually run.
    norms_view_(norms_view),
    dataset_view_(dataset_view),
    metric_arg_(metric_arg)
{
}

private:
/**
* Replace the dataset with a new dataset.
Expand Down Expand Up @@ -135,6 +155,7 @@ struct index : ann::index {
raft::distance::DistanceType metric_;
raft::device_matrix<T, int64_t, row_major> dataset_;
std::optional<raft::device_vector<T, int64_t>> norms_;
std::optional<raft::device_vector_view<const T, int64_t>> norms_view_;
raft::device_matrix_view<const T, int64_t, row_major> dataset_view_;
T metric_arg_;
};
Expand Down
20 changes: 19 additions & 1 deletion cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,25 @@ void search(raft::resources const& res,
/**
* @brief Search ANN using the constructed index with the given sample filter.
*
* See the [cagra::build](#cagra::build) documentation for a usage example.
* Usage example:
* @code{.cpp}
* using namespace raft::neighbors;
* // use default index parameters
* cagra::index_params index_params;
* // create and fill the index from a [N, D] dataset
* auto index = cagra::build(res, index_params, dataset);
* // use default search parameters
* cagra::search_params search_params;
* // create a bitset to filter the search
* auto removed_indices = raft::make_device_vector<IdxT>(res, n_removed_indices);
* raft::core::bitset<std::uint32_t, IdxT> removed_indices_bitset(
* res, removed_indices.view(), dataset.extent(0));
* // search K nearest neighbours according to a bitset
* auto neighbors = raft::make_device_matrix<uint32_t>(res, n_queries, k);
* auto distances = raft::make_device_matrix<float>(res, n_queries, k);
* cagra::search_with_filtering(res, search_params, index, queries, neighbors, distances,
* filtering::bitset_filter(removed_indices_bitset.view()));
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices
Expand Down
23 changes: 16 additions & 7 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -478,13 +478,15 @@ __global__ void apply_filter_kernel(INDEX_T* const result_indices_ptr,
const INDEX_T query_id_offset,
SAMPLE_FILTER_T sample_filter)
{
const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
const auto tid = threadIdx.x + blockIdx.x * blockDim.x;
if (tid >= result_buffer_size * num_queries) { return; }
const auto i = tid % result_buffer_size;
const auto j = tid / result_buffer_size;
const auto index = i + j * lds;

if (!sample_filter(query_id_offset + j, result_indices_ptr[index])) {
if (result_indices_ptr[index] != ~index_msb_1_mask &&
!sample_filter(query_id_offset + j, result_indices_ptr[index])) {
result_indices_ptr[index] = utils::get_max_value<INDEX_T>();
result_distances_ptr[index] = utils::get_max_value<DISTANCE_T>();
}
Expand Down Expand Up @@ -788,12 +790,15 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
auto result_indices_ptr = result_indices.data() + (iter & 0x1) * result_buffer_size;
auto result_distances_ptr = result_distances.data() + (iter & 0x1) * result_buffer_size;

// Remove parent bit in search results
remove_parent_bit(
num_queries, itopk_size, result_indices_ptr, result_buffer_allocation_size, stream);
if constexpr (!std::is_same<SAMPLE_FILTER_T,
raft::neighbors::filtering::none_cagra_sample_filter>::value) {
// Remove parent bit in search results
remove_parent_bit(num_queries,
result_buffer_size,
result_indices.data() + (iter & 0x1) * itopk_size,
result_buffer_allocation_size,
stream);

if (!std::is_same<SAMPLE_FILTER_T,
raft::neighbors::filtering::none_cagra_sample_filter>::value) {
apply_filter<INDEX_T, DISTANCE_T, SAMPLE_FILTER_T>(
result_indices.data() + (iter & 0x1) * itopk_size,
result_distances.data() + (iter & 0x1) * itopk_size,
Expand Down Expand Up @@ -821,6 +826,10 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T> {
true,
topk_hint.data(),
stream);
} else {
// Remove parent bit in search results
remove_parent_bit(
num_queries, itopk_size, result_indices_ptr, result_buffer_allocation_size, stream);
}

// Copy results from working buffer to final buffer
Expand Down
8 changes: 8 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/search_plan.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,14 @@ struct search_plan_impl : public search_plan_impl_base {
"`hashmap_max_fill_rate` must be equal to or greater than 0.1 and smaller than 0.9. " +
std::to_string(hashmap_max_fill_rate) + " has been given.";
}
if constexpr (!std::is_same<SAMPLE_FILTER_T,
raft::neighbors::filtering::none_cagra_sample_filter>::value) {
if (hashmap_mode == hash_mode::SMALL) {
error_message += "`SMALL` hash is not available when filtering";
} else {
hashmap_mode = hash_mode::HASH;
}
}
if (algo == search_algo::MULTI_CTA) {
if (hashmap_mode == hash_mode::SMALL) {
error_message += "`small_hash` is not available when 'search_mode' is \"multi-cta\"";
Expand Down
3 changes: 1 addition & 2 deletions cpp/include/raft/neighbors/detail/nn_descent.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -1278,8 +1278,7 @@ void GNND<Data_t, Index_t>::build(Data_t* data, const Index_t nrow, Index_t* out
std::thread update_and_sample_thread(update_and_sample, it);
std::cout << "# GNND iteraton: " << it + 1 << "/" << build_config_.max_iterations << "\r";
std::fflush(stdout);
RAFT_LOG_DEBUG("# GNND iteraton: %lu / %lu", it + 1, build_config_.max_iterations);
// Reuse dists_buffer_ to save GPU memory. graph_buffer_ cannot be reused, because it
// contains some information for local_join.
Expand Down
48 changes: 48 additions & 0 deletions cpp/include/raft/neighbors/sample_filter.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cstddef>
#include <cstdint>

#include <raft/core/bitset.cuh>

namespace raft::neighbors::filtering {
/**
 * @brief Sample filter backed by a bitset view.
 *
 * The functor forwards the sample index to `bitset_view_.test()` and returns
 * its result unchanged; the query index is accepted but never consulted, so
 * the same decision applies to every query.
 * NOTE(review): presumably a set bit means "keep this sample" - confirm
 * against the raft::core::bitset documentation.
 *
 * @tparam bitset_t Element type of the underlying bitset storage
 * @tparam index_t  Indexing type
 */
template <typename bitset_t, typename index_t>
struct bitset_filter {
  /**
   * @brief Wrap an existing bitset view.
   *
   * @param bitset_for_filtering view that is copied into the filter
   */
  bitset_filter(const raft::core::bitset_view<bitset_t, index_t> bitset_for_filtering)
    : bitset_view_{bitset_for_filtering}
  {
  }

  /**
   * @brief Evaluate the filter for one (query, sample) pair.
   *
   * @param query_ix  query index (unused)
   * @param sample_ix the index of the current sample
   * @return the result of testing `sample_ix` in the bitset view
   */
  inline _RAFT_HOST_DEVICE bool operator()(const uint32_t query_ix, const uint32_t sample_ix) const
  {
    const bool sample_bit = bitset_view_.test(sample_ix);
    return sample_bit;
  }

  // View of the bitset to use as a filter
  const raft::core::bitset_view<bitset_t, index_t> bitset_view_;
};
} // namespace raft::neighbors::filtering
Loading

0 comments on commit de21b85

Please sign in to comment.