Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CAGRA-Q search #2206

Merged
merged 36 commits into from
Mar 21, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
35f2971
Rebase the PR to enable CI
enp1s0 Mar 18, 2024
dab54e0
Fix broken usage of index.dataset() and the like after the merge of t…
achirkin Mar 18, 2024
22709e2
Add explicit instantiations for IVF-PQ search kernels used in tests (…
tfeher Mar 18, 2024
80c45f1
Fix style errors
achirkin Mar 18, 2024
49b6b61
Add explicit instantiations for IVF-PQ search kernels used in tests (…
tfeher Mar 18, 2024
2c02881
Merge branch 'branch-24.04' into cagra-q
achirkin Mar 18, 2024
a69f66d
Add uint32 VPQ test
enp1s0 Mar 19, 2024
2d35446
Update VPQ test
enp1s0 Mar 19, 2024
711d3d8
Merge branch 'branch-24.04' into cagra-q
achirkin Mar 19, 2024
678f767
Cleanup the tests a little bit, add a sanity check for the index type
achirkin Mar 19, 2024
09db075
Update cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
achirkin Mar 19, 2024
4eca385
Output an error message if multi_kernel and vpq are specified
enp1s0 Mar 19, 2024
d321873
Use a raft helper for ceiling division
achirkin Mar 19, 2024
11f9350
Move the instance macro to a separate header to reduce the codesize
achirkin Mar 19, 2024
3ff9382
Merge branch 'branch-24.04' into cagra-q
achirkin Mar 19, 2024
38a8bf2
Use TxN_t
enp1s0 Mar 19, 2024
cda2cb8
Fix incorrect addressing using TxN_t
achirkin Mar 19, 2024
ff7d3b2
Merge branch 'branch-24.04' into cagra-q
achirkin Mar 19, 2024
103b9c0
Fix typo
enp1s0 Mar 20, 2024
1fb7c36
Fix VPQ search params validation
enp1s0 Mar 20, 2024
89aa91e
Add dim size validation
enp1s0 Mar 20, 2024
daf4f08
Fix VPQ similarity computation for large dim
enp1s0 Mar 20, 2024
38ab2bd
Update CAGRA VPQ test
enp1s0 Mar 20, 2024
15afe26
Merge branch 'branch-24.04' into cagra-q
enp1s0 Mar 20, 2024
5174811
Update cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh
enp1s0 Mar 20, 2024
16ddb13
Remove redundant team-size and dataset-block-dim parameters from the …
achirkin Mar 20, 2024
317c67f
Mark the strided_dataset::view as deleted (pure virtual) to avoid lin…
achirkin Mar 20, 2024
59033c7
Fix the instances in the tests as well
achirkin Mar 20, 2024
6567186
Fix a bug in VPQ similarity compute
enp1s0 Mar 20, 2024
ecb896c
Disable implicit template instantiations for vpq tests
tfeher Mar 20, 2024
1308c61
cagra-vpq enable instantiation of int64 kernels
tfeher Mar 20, 2024
6d663ae
Correct copyright year
tfeher Mar 20, 2024
0e29876
Update query copy from dmem to smem
enp1s0 Mar 20, 2024
31b6982
Merge branch 'cagra-q' of github.com:enp1s0/raft into cagra-q
enp1s0 Mar 20, 2024
6ebb99e
Fix query mapping type and usage of a macro that is not available on …
achirkin Mar 20, 2024
b2cdb6d
Set pq_len=2 as default, do not allow different pq_len for search
tfeher Mar 21, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,21 @@ struct standard_dataset_descriptor_t
static const std::uint32_t smem_buffer_size_in_byte = 0;
__device__ void set_smem_ptr(void* const){};

/**
 * Copy a query vector into the query buffer used by the search kernel.
 *
 * Each thread starts at element `threadIdx.x` and advances by `blockDim.x`
 * until `query_smem_buffer_length` is reached. The destination slot of source
 * element `i` is `device::swizzling(i)`. Source elements are passed through
 * `spatial::knn::detail::utils::mapping<float>` before being stored; slots
 * whose source index is `>= dim` are filled with zero.
 *
 * @tparam DATASET_BLOCK_DIM not referenced by this implementation.
 * @param dmem_query_ptr source query; only elements `[0, dim)` are read.
 * @param smem_query_ptr destination buffer (presumably shared memory, judging
 *        by the name -- confirm against the caller).
 * @param query_smem_buffer_length number of destination elements to write.
 */
template <uint32_t DATASET_BLOCK_DIM>
__device__ void copy_query(const DATA_T* const dmem_query_ptr,
                           QUERY_T* const smem_query_ptr,
                           const std::uint32_t query_smem_buffer_length)
{
  for (unsigned elem = threadIdx.x; elem < query_smem_buffer_length; elem += blockDim.x) {
    const unsigned slot = device::swizzling(elem);
    if (elem >= dim) {
      // Past the end of the query vector: pad the buffer with zeros.
      smem_query_ptr[slot] = 0.0;
      continue;
    }
    smem_query_ptr[slot] = spatial::knn::detail::utils::mapping<float>{}(dmem_query_ptr[elem]);
  }
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T dataset_i,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,26 @@ struct cagra_q_dataset_descriptor_t : public dataset_descriptor_base_t<half, DIS
{
}

/**
 * Copy a query vector into the query buffer as packed half-precision pairs.
 *
 * Each thread handles two consecutive source elements per iteration: it starts
 * at element `threadIdx.x * 2` and advances by `blockDim.x * 2` while the
 * index is below `dim`. Every element is converted to `float` and then to
 * `half`, and the pair is stored as one `half2`. When `dim` is odd, the second
 * lane of the last pair keeps its zero initial value.
 *
 * If `PQ_BITS == 8` and `PQ_LEN` is even, the pair is written to the swizzled
 * `half2` slot `device::swizzling<std::uint32_t, DATASET_BLOCK_DIM / 2>(i / 2)`;
 * otherwise it is written unswizzled at element offset `i`.
 *
 * NOTE(review): `query_smem_buffer_length` is not used here, so nothing beyond
 * the last written pair is initialized -- confirm that no consumer reads the
 * padding of the buffer.
 *
 * @tparam DATASET_BLOCK_DIM used only to parameterize the swizzling function.
 * @param dmem_query_ptr source query; only elements `[0, dim)` are read.
 * @param smem_query_ptr destination buffer, accessed as `half2` (presumably
 *        shared memory holding `half` values -- confirm against the caller).
 * @param query_smem_buffer_length unused in this implementation.
 */
template <uint32_t DATASET_BLOCK_DIM>
__device__ void copy_query(const DATA_T* const dmem_query_ptr,
                           QUERY_T* const smem_query_ptr,
                           const std::uint32_t query_smem_buffer_length)
{
  for (unsigned i = threadIdx.x * 2; i < dim; i += blockDim.x * 2) {
    half2 buf2 = {CUDART_ZERO_FP16, CUDART_ZERO_FP16};
    // The loop condition already guarantees i < dim, so the first lane is always valid.
    buf2.x = static_cast<half>(static_cast<float>(dmem_query_ptr[i]));
    // The second lane may fall past the end when dim is odd; it then stays zero.
    if (i + 1 < dim) { buf2.y = static_cast<half>(static_cast<float>(dmem_query_ptr[i + 1])); }
    if ((PQ_BITS == 8) && (PQ_LEN % 2 == 0)) {
      // Use swizzling in the condition to reduce bank conflicts in shared
      // memory, which are likely to occur when pq_code_book_dim is large.
      (reinterpret_cast<half2*>(
        smem_query_ptr))[device::swizzling<std::uint32_t, DATASET_BLOCK_DIM / 2>(i / 2)] = buf2;
    } else {
      (reinterpret_cast<half2*>(smem_query_ptr + i))[0] = buf2;
    }
  }
}

template <uint32_t DATASET_BLOCK_DIM, uint32_t TEAM_SIZE>
__device__ DISTANCE_T compute_similarity(const QUERY_T* const query_ptr,
const INDEX_T node_id,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -213,14 +213,9 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel(
}
#endif
const DATA_T* const query_ptr = queries_ptr + (dataset_desc.dim * query_id);
for (unsigned i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) {
unsigned j = device::swizzling(i);
if (i < dataset_desc.dim) {
query_buffer[j] = spatial::knn::detail::utils::mapping<float>{}(query_ptr[i]);
} else {
query_buffer[j] = 0.0;
}
}
dataset_desc.template copy_query<DATASET_BLOCK_DIM>(
query_ptr, query_buffer, query_smem_buffer_length);

if (threadIdx.x == 0) { terminate_flag[0] = 0; }
INDEX_T* const local_visited_hashmap_ptr =
visited_hashmap_ptr + (hashmap::get_size(hash_bitlen) * query_id);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -549,14 +549,9 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel(
auto filter_flag = terminate_flag;

const DATA_T* const query_ptr = queries_ptr + query_id * dataset_desc.dim;
for (unsigned i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) {
unsigned j = device::swizzling(i);
if (i < dataset_desc.dim) {
query_buffer[j] = spatial::knn::detail::utils::mapping<float>{}(query_ptr[i]);
} else {
query_buffer[j] = 0.0;
}
}
dataset_desc.template copy_query<DATASET_BLOCK_DIM>(
query_ptr, query_buffer, query_smem_buffer_length);

if (threadIdx.x == 0) {
terminate_flag[0] = 0;
topk_ws[0] = ~0u;
Expand Down