Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support uint64_t in CAGRA index data type #1514

Merged
merged 28 commits into rapidsai:branch-23.06 from enp1s0:cagra-64bit-index
May 19, 2023
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
a4ca4a5
Support 64-bit index data type in CAGRA search
enp1s0 May 12, 2023
fd0291f
Add CAGRA test for 64-bit index data type
enp1s0 May 12, 2023
ad6b1e2
Merge branch 'rapidsai:branch-23.06' into cagra-64bit-index
enp1s0 May 12, 2023
4dec143
Fix a bug in CAGRA topk_cta_11_core
enp1s0 May 12, 2023
566f96d
Add const
enp1s0 May 12, 2023
76194ae
Fix the data type of seed index in CAGRA::random_pickup
enp1s0 May 15, 2023
89c6522
Merge branch 'cagra-64bit-index' of github.com:enp1s0/raft into cagra…
enp1s0 May 15, 2023
3cbd624
Fix indent
enp1s0 May 15, 2023
8a0cbfe
Update hashmap to support uint64
enp1s0 May 15, 2023
fb2cd93
Update "mottainai" bit to support uint64
enp1s0 May 15, 2023
57791b7
Fix cagra::prune
enp1s0 May 15, 2023
9448cba
Update CAGRA tests for uint64 index data type
enp1s0 May 15, 2023
ef5381d
Remove a comment
enp1s0 May 15, 2023
e790e82
Merge branch 'branch-23.06' into cagra-64bit-index
enp1s0 May 16, 2023
681c7d1
Add gen_index_msb_1_mask
enp1s0 May 17, 2023
9fa8c9e
Merge branch 'cagra-64bit-index' of github.com:enp1s0/raft into cagra…
enp1s0 May 17, 2023
4549430
Merge branch 'cagra-64bit-index' of github.com:enp1s0/raft into cagra…
enp1s0 May 17, 2023
0a68b43
Use gen_index_msb_1_mask in multi_cta::search_kernel
enp1s0 May 17, 2023
b2bd012
Use gen_index_msb_1_mask in multi_cta::search_kernel
enp1s0 May 17, 2023
1c26109
Fix code format
enp1s0 May 17, 2023
01473b6
Remove some CAGRA IdxT=uint64 tests
enp1s0 May 17, 2023
c91541d
Add support for int64 using uint64 kernels
enp1s0 May 18, 2023
98c3389
Merge branch 'rapidsai:branch-23.06' into cagra-64bit-index
enp1s0 May 18, 2023
c540fa2
Merge branch 'rapidsai:branch-23.06' into cagra-64bit-index
enp1s0 May 18, 2023
18927b6
Merge branch 'branch-23.06' into cagra-64bit-index
benfred May 18, 2023
cfbb29e
Update memory_type of knn graph in sort_knn_graph
enp1s0 May 19, 2023
6bd9548
Merge branch 'rapidsai:branch-23.06' into cagra-64bit-index
enp1s0 May 19, 2023
63228ad
Merge branch 'branch-23.06' into cagra-64bit-index
tfeher May 19, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 54 additions & 5 deletions cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,17 @@ void build_knn_graph(raft::resources const& res,
std::optional<ivf_pq::index_params> build_params = std::nullopt,
std::optional<ivf_pq::search_params> search_params = std::nullopt)
{
detail::build_knn_graph(res, dataset, knn_graph, refine_rate, build_params, search_params);
using internal_IdxT = typename std::make_unsigned<IdxT>::type;

auto knn_graph_internal = make_host_matrix_view<internal_IdxT, internal_IdxT>(
reinterpret_cast<internal_IdxT*>(knn_graph.data_handle()),
knn_graph.extent(0),
knn_graph.extent(1));
auto dataset_internal = mdspan<const DataT, matrix_extent<internal_IdxT>, row_major, accessor>(
dataset.data_handle(), dataset.extent(0), dataset.extent(1));

detail::build_knn_graph(
res, dataset_internal, knn_graph_internal, refine_rate, build_params, search_params);
}

/**
Expand Down Expand Up @@ -124,7 +134,20 @@ void sort_knn_graph(raft::resources const& res,
mdspan<const DataT, matrix_extent<IdxT>, row_major, d_accessor> dataset,
mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph)
{
detail::graph::sort_knn_graph(res, dataset, knn_graph);
using internal_IdxT = typename std::make_unsigned<IdxT>::type;

using g_accessor_internal =
host_device_accessor<std::experimental::default_accessor<internal_IdxT>, memory_type::host>;
enp1s0 marked this conversation as resolved.
Show resolved Hide resolved
auto knn_graph_internal =
mdspan<internal_IdxT, matrix_extent<internal_IdxT>, row_major, g_accessor_internal>(
reinterpret_cast<internal_IdxT*>(knn_graph.data_handle()),
knn_graph.extent(0),
knn_graph.extent(1));

auto dataset_internal = mdspan<const DataT, matrix_extent<internal_IdxT>, row_major, d_accessor>(
dataset.data_handle(), dataset.extent(0), dataset.extent(1));

detail::graph::sort_knn_graph(res, dataset_internal, knn_graph_internal);
}

/**
Expand All @@ -148,7 +171,22 @@ void prune(raft::resources const& res,
mdspan<IdxT, matrix_extent<IdxT>, row_major, g_accessor> knn_graph,
raft::host_matrix_view<IdxT, IdxT, row_major> new_graph)
{
detail::graph::prune(res, knn_graph, new_graph);
using internal_IdxT = typename std::make_unsigned<IdxT>::type;

auto new_graph_internal = raft::make_host_matrix_view<internal_IdxT, internal_IdxT>(
reinterpret_cast<internal_IdxT*>(new_graph.data_handle()),
new_graph.extent(0),
new_graph.extent(1));

using g_accessor_internal =
host_device_accessor<std::experimental::default_accessor<internal_IdxT>, memory_type::host>;
auto knn_graph_internal =
mdspan<internal_IdxT, matrix_extent<internal_IdxT>, row_major, g_accessor_internal>(
reinterpret_cast<internal_IdxT*>(knn_graph.data_handle()),
knn_graph.extent(0),
knn_graph.extent(1));

detail::graph::prune(res, knn_graph_internal, new_graph_internal);
}

/**
Expand Down Expand Up @@ -200,7 +238,7 @@ index<T, IdxT> build(raft::resources const& res,
mdspan<const T, matrix_extent<IdxT>, row_major, Accessor> dataset)
{
size_t degree = params.intermediate_graph_degree;
if (degree >= dataset.extent(0)) {
if (degree >= static_cast<size_t>(dataset.extent(0))) {
RAFT_LOG_WARN(
"Intermediate graph degree cannot be larger than dataset size, reducing it to %lu",
dataset.extent(0));
Expand Down Expand Up @@ -256,7 +294,18 @@ void search(raft::resources const& res,
RAFT_EXPECTS(queries.extent(1) == idx.dim(),
"Number of query dimensions should equal number of dimensions in the index.");

detail::search_main(res, params, idx, queries, neighbors, distances);
using internal_IdxT = typename std::make_unsigned<IdxT>::type;
auto queries_internal = raft::make_device_matrix_view<const T, internal_IdxT, row_major>(
queries.data_handle(), queries.extent(0), queries.extent(1));
auto neighbors_internal = raft::make_device_matrix_view<internal_IdxT, internal_IdxT, row_major>(
reinterpret_cast<internal_IdxT*>(neighbors.data_handle()),
neighbors.extent(0),
neighbors.extent(1));
auto distances_internal = raft::make_device_matrix_view<float, internal_IdxT, row_major>(
distances.data_handle(), distances.extent(0), distances.extent(1));

detail::search_main<T, internal_IdxT, IdxT>(
res, params, idx, queries_internal, neighbors_internal, distances_internal);
}
/** @} */ // end group cagra

Expand Down
36 changes: 24 additions & 12 deletions cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,13 @@ namespace raft::neighbors::experimental::cagra::detail {
* k]
*/

template <typename T, typename IdxT = uint32_t, typename DistanceT = float>
template <typename T, typename internal_IdxT, typename IdxT = uint32_t, typename DistanceT = float>
void search_main(raft::resources const& res,
search_params params,
const index<T, IdxT>& index,
raft::device_matrix_view<const T, IdxT, row_major> queries,
raft::device_matrix_view<IdxT, IdxT, row_major> neighbors,
raft::device_matrix_view<DistanceT, IdxT, row_major> distances)
raft::device_matrix_view<const T, internal_IdxT, row_major> queries,
raft::device_matrix_view<internal_IdxT, internal_IdxT, row_major> neighbors,
raft::device_matrix_view<DistanceT, internal_IdxT, row_major> distances)
{
RAFT_LOG_DEBUG("# dataset size = %lu, dim = %lu\n",
static_cast<size_t>(index.dataset().extent(0)),
Expand All @@ -69,8 +69,9 @@ void search_main(raft::resources const& res,
RAFT_EXPECTS(queries.extent(1) == index.dim(), "Querise and index dim must match");
uint32_t topk = neighbors.extent(1);

std::unique_ptr<search_plan_impl<T, IdxT, DistanceT>> plan =
factory<T, IdxT, DistanceT>::create(res, params, index.dim(), index.graph_degree(), topk);
std::unique_ptr<search_plan_impl<T, internal_IdxT, DistanceT>> plan =
factory<T, internal_IdxT, DistanceT>::create(
res, params, index.dim(), index.graph_degree(), topk);

plan->check(neighbors.extent(1));

Expand All @@ -79,18 +80,29 @@ void search_main(raft::resources const& res,
uint32_t query_dim = queries.extent(1);

for (unsigned qid = 0; qid < queries.extent(0); qid += max_queries) {
const uint32_t n_queries = std::min<std::size_t>(max_queries, queries.extent(0) - qid);
IdxT* _topk_indices_ptr = neighbors.data_handle() + (topk * qid);
const uint32_t n_queries = std::min<std::size_t>(max_queries, queries.extent(0) - qid);
internal_IdxT* _topk_indices_ptr =
reinterpret_cast<internal_IdxT*>(neighbors.data_handle()) + (topk * qid);
DistanceT* _topk_distances_ptr = distances.data_handle() + (topk * qid);
// todo(tfeher): one could keep distances optional and pass nullptr
const T* _query_ptr = queries.data_handle() + (query_dim * qid);
const IdxT* _seed_ptr =
plan->num_seeds > 0 ? plan->dev_seed.data() + (plan->num_seeds * qid) : nullptr;
const internal_IdxT* _seed_ptr =
plan->num_seeds > 0
? reinterpret_cast<const internal_IdxT*>(plan->dev_seed.data()) + (plan->num_seeds * qid)
: nullptr;
uint32_t* _num_executed_iterations = nullptr;

auto dataset_internal = raft::make_device_matrix_view<const T, internal_IdxT, row_major>(
index.dataset().data_handle(), index.dataset().extent(0), index.dataset().extent(1));
auto graph_internal =
raft::make_device_matrix_view<const internal_IdxT, internal_IdxT, row_major>(
reinterpret_cast<const internal_IdxT*>(index.graph().data_handle()),
index.graph().extent(0),
index.graph().extent(1));

(*plan)(res,
index.dataset(),
index.graph(),
dataset_internal,
graph_internal,
_topk_indices_ptr,
_topk_distances_ptr,
_query_ptr,
Expand Down
4 changes: 3 additions & 1 deletion cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@

#include <raft/util/cuda_rt_essentials.hpp>

#include "utils.hpp"

namespace raft::neighbors::experimental::cagra::detail {
namespace graph {

Expand Down Expand Up @@ -115,7 +117,7 @@ __global__ void kern_sort(const DATA_T* const dataset, // [dataset_chunk_size,
my_vals[i] = smem_vals[k];
} else {
my_keys[i] = FLT_MAX;
my_vals[i] = ~static_cast<IdxT>(0);
my_vals[i] = utils::get_max_value<IdxT>();
}
}
__syncthreads();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ __device__ void pickup_next_parents(INDEX_T* const next_parent_indices, // [num
const size_t num_itopk,
uint32_t* const terminate_flag)
{
constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
<< (utils::size_of<INDEX_T>() * 8 - 1);
const unsigned warp_id = threadIdx.x / 32;
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
const unsigned warp_id = threadIdx.x / 32;
if (warp_id > 0) { return; }
const unsigned lane_id = threadIdx.x % 32;
for (uint32_t i = lane_id; i < num_parents; i += 32) {
Expand Down Expand Up @@ -295,8 +294,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel(
uint32_t j = i + (itopk_size * (cta_id + (num_cta_per_query * query_id)));
if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[i]; }

constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
<< (utils::size_of<INDEX_T>() * 8 - 1);
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;

result_indices_ptr[j] =
result_indices_buffer[i] & ~index_msb_1_mask; // clear most significant bit
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,8 +212,8 @@ __global__ void pickup_next_parents_kernel(
const std::size_t parent_list_size, //
std::uint32_t* const terminate_flag)
{
constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
<< (utils::size_of<INDEX_T>() * 8 - 1);
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;

const std::size_t ldb = hashmap::get_size(hash_bitlen);
const uint32_t query_id = blockIdx.x;
if (threadIdx.x < 32) {
Expand Down Expand Up @@ -407,8 +407,7 @@ __global__ void remove_parent_bit_kernel(const std::uint32_t num_queries,
INDEX_T* const topk_indices_ptr, // [ld, num_queries]
const std::uint32_t ld)
{
constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
<< (utils::size_of<INDEX_T>() * 8 - 1);
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;

uint32_t i_query = blockIdx.x;
if (i_query >= num_queries) return;
Expand Down
13 changes: 6 additions & 7 deletions cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,7 @@ __device__ void pickup_next_parents(std::uint32_t* const terminate_flag,
const std::size_t dataset_size,
const std::uint32_t num_parents)
{
constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
<< (utils::size_of<INDEX_T>() * 8 - 1);
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;
// if (threadIdx.x >= 32) return;

for (std::uint32_t i = threadIdx.x; i < num_parents; i += 32) {
Expand Down Expand Up @@ -505,8 +504,8 @@ __device__ inline void hashmap_restore(INDEX_T* const hashmap_ptr,
const INDEX_T* itopk_indices,
uint32_t itopk_size)
{
constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
<< (utils::size_of<INDEX_T>() * 8 - 1);
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;

if (threadIdx.x < FIRST_TID || threadIdx.x >= LAST_TID) return;
for (unsigned i = threadIdx.x - FIRST_TID; i < itopk_size; i += LAST_TID - FIRST_TID) {
auto key = itopk_indices[i] & ~index_msb_1_mask; // clear most significant bit
Expand Down Expand Up @@ -776,8 +775,8 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__
if (TOPK_BY_BITONIC_SORT) { ii = device::swizzling(i); }
if (result_distances_ptr != nullptr) { result_distances_ptr[j] = result_distances_buffer[ii]; }

constexpr INDEX_T index_msb_1_mask = static_cast<INDEX_T>(1)
<< (utils::size_of<INDEX_T>() * 8 - 1);
constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask<INDEX_T>::value;

result_indices_ptr[j] =
result_indices_buffer[ii] & ~index_msb_1_mask; // clear most significant bit
}
Expand Down Expand Up @@ -1124,7 +1123,7 @@ struct search : search_plan_impl<DATA_T, INDEX_T, DISTANCE_T> {
hashmap_size = 0;
if (small_hash_bitlen == 0) {
hashmap_size = sizeof(INDEX_T) * max_queries * hashmap::get_size(hash_bitlen);
hashmap.resize(hashmap_size, res.get_stream());
hashmap.resize(hashmap_size, resource::get_cuda_stream(res));
}
RAFT_LOG_DEBUG("# hashmap_size: %lu", hashmap_size);
}
Expand Down
5 changes: 5 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ template <int A, int B>
struct constexpr_max<A, B, std::enable_if_t<(B > A), bool>> {
static const int value = B;
};

/**
 * @brief Compile-time bit mask for an index type in which only the most
 *        significant bit is set.
 *
 * `value` is 1 shifted left by (bit width of IdxT - 1), where the bit width is
 * taken as `utils::size_of<IdxT>() * 8` (presumably size_of yields the size in
 * bytes -- confirm against its definition in this header).
 *
 * Typical use is `idx | value` to tag an index and `idx & ~value` to strip the
 * tag again.
 *
 * @tparam IdxT index type the mask is generated for
 */
template <class IdxT>
struct gen_index_msb_1_mask {
  // Shift amount is the position of the top bit: (8 bits per byte * size) - 1.
  static constexpr IdxT value = static_cast<IdxT>(1) << (8 * utils::size_of<IdxT>() - 1);
};
} // namespace utils

} // namespace raft::neighbors::experimental::cagra::detail
4 changes: 1 addition & 3 deletions cpp/test/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,7 @@ if(BUILD_TESTS)
test/neighbors/ann_cagra/test_float_uint32_t.cu
test/neighbors/ann_cagra/test_int8_t_uint32_t.cu
test/neighbors/ann_cagra/test_uint8_t_uint32_t.cu
test/neighbors/ann_cagra/test_float_uint64_t.cu
test/neighbors/ann_cagra/test_int8_t_uint64_t.cu
test/neighbors/ann_cagra/test_uint8_t_uint64_t.cu
test/neighbors/ann_cagra/test_float_int64_t.cu
test/neighbors/ann_ivf_flat/test_float_int64_t.cu
test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu
test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

namespace raft::neighbors::experimental::cagra {

typedef AnnCagraTest<float, float, std::uint64_t> AnnCagraTestF_I64;
typedef AnnCagraTest<float, float, std::int64_t> AnnCagraTestF_I64;
TEST_P(AnnCagraTestF_I64, AnnCagra) { this->testCagra(); }

INSTANTIATE_TEST_CASE_P(AnnCagraTest, AnnCagraTestF_I64, ::testing::ValuesIn(inputs));
Expand Down
32 changes: 0 additions & 32 deletions cpp/test/neighbors/ann_cagra/test_int8_t_uint64_t.cu

This file was deleted.

33 changes: 0 additions & 33 deletions cpp/test/neighbors/ann_cagra/test_uint8_t_uint64_t.cu

This file was deleted.