Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for fp16 in CAGRA and IVF-PQ #2085

Merged
merged 22 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
94ddf0c
Initial support and specializations
achirkin Jan 10, 2024
ffa0be9
Add missing device_loads_stores overloads
achirkin Jan 10, 2024
756d03e
Tweak the uniformInt range to a valid value in tests
achirkin Jan 10, 2024
adf8e09
Merge branch 'branch-24.02' into fea-cagra-fp16
achirkin Jan 10, 2024
5ab5304
Undo copyright-only changes introduced by generator scripts
achirkin Jan 10, 2024
bda13b1
Undo copyright-only changes introduced by generator scripts
achirkin Jan 10, 2024
d75fccc
Add -v flag for more verbose build to debug the linker error
achirkin Jan 10, 2024
3895754
Update build_libraft.sh
achirkin Jan 10, 2024
a776329
Make the tests PIE and remove object files that are already present i…
achirkin Jan 11, 2024
ec4c1e6
Add PIC to the raft_lib and raft_lib_static
achirkin Jan 11, 2024
7bbff83
Mirror the previous CAGRA float tests more closely with the new half …
achirkin Jan 11, 2024
1af6ebb
Revert debug changes in build_libraft.sh
achirkin Jan 11, 2024
23bad7e
Merge branch 'branch-24.02' into fea-cagra-fp16
achirkin Jan 11, 2024
96bf1a6
Merge branch 'branch-24.02' into fea-cagra-fp16
achirkin Jan 13, 2024
b28a9b7
Merge branch 'branch-24.02' into fea-cagra-fp16
achirkin Jan 17, 2024
6947fb5
Merge branch 'branch-24.02' into fea-cagra-fp16
achirkin Jan 17, 2024
1a60963
Update cpp/test/neighbors/ann_cagra.cuh
achirkin Jan 18, 2024
602516f
Update cpp/test/neighbors/ann_cagra.cuh
achirkin Jan 18, 2024
bdaecbb
Update cpp/test/neighbors/ann_cagra.cuh
achirkin Jan 18, 2024
48c1454
Skip the tests that may fail due to fp-rounding errors
achirkin Jan 18, 2024
274cd40
Add a few more sts overloads
achirkin Jan 18, 2024
3aa2a75
Merge branch 'branch-24.02' into fea-cagra-fp16
achirkin Jan 18, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# =============================================================================
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
Expand Down Expand Up @@ -338,6 +338,10 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu
Expand All @@ -350,6 +354,10 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_single_cta_half_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_single_cta_half_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_single_cta_half_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_single_cta_half_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu
Expand All @@ -359,6 +367,7 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu
src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu
src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu
src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu
src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu
src/neighbors/detail/ivf_flat_search.cu
Expand All @@ -370,6 +379,7 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu
src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu
src/neighbors/detail/refine_host_float_float.cpp
src/neighbors/detail/refine_host_half_float.cpp
src/neighbors/detail/refine_host_int8_t_float.cpp
src/neighbors/detail/refine_host_uint8_t_float.cpp
src/neighbors/ivf_flat_build_float_int64_t.cu
Expand All @@ -382,15 +392,19 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/ivf_flat_search_int8_t_int64_t.cu
src/neighbors/ivf_flat_search_uint8_t_int64_t.cu
src/neighbors/ivfpq_build_float_int64_t.cu
src/neighbors/ivfpq_build_half_int64_t.cu
src/neighbors/ivfpq_build_int8_t_int64_t.cu
src/neighbors/ivfpq_build_uint8_t_int64_t.cu
src/neighbors/ivfpq_extend_float_int64_t.cu
src/neighbors/ivfpq_extend_half_int64_t.cu
src/neighbors/ivfpq_extend_int8_t_int64_t.cu
src/neighbors/ivfpq_extend_uint8_t_int64_t.cu
src/neighbors/ivfpq_search_float_int64_t.cu
src/neighbors/ivfpq_search_half_int64_t.cu
src/neighbors/ivfpq_search_int8_t_int64_t.cu
src/neighbors/ivfpq_search_uint8_t_int64_t.cu
src/neighbors/refine_float_float.cu
src/neighbors/refine_half_float.cu
src/neighbors/refine_int8_t_float.cu
src/neighbors/refine_uint8_t_float.cu
src/raft_runtime/cluster/cluster_cost.cuh
Expand Down
16 changes: 14 additions & 2 deletions cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,10 @@
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resources.hpp>

#if defined(_RAFT_HAS_CUDA)
#include <cuda_fp16.h>
#endif

#include <algorithm>
#include <complex>
#include <cstdint>
Expand Down Expand Up @@ -121,6 +125,14 @@ inline dtype_t get_numpy_dtype()
return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'f', sizeof(T)};
}

#if defined(_RAFT_HAS_CUDA)
template <typename T, typename std::enable_if_t<std::is_same_v<T, half>, bool> = true>
inline dtype_t get_numpy_dtype()
{
return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'e', sizeof(T)};
}
#endif

template <typename T,
typename std::enable_if_t<std::is_integral_v<T> && std::is_signed_v<T>, bool> = true>
inline dtype_t get_numpy_dtype()
Expand Down Expand Up @@ -273,7 +285,7 @@ inline dtype_t parse_descr(std::string typestr)

const char endian_chars[] = {
RAFT_NUMPY_LITTLE_ENDIAN_CHAR, RAFT_NUMPY_BIG_ENDIAN_CHAR, RAFT_NUMPY_NO_ENDIAN_CHAR};
const char numtype_chars[] = {'f', 'i', 'u', 'c'};
const char numtype_chars[] = {'f', 'i', 'u', 'c', 'e'};

RAFT_EXPECTS(std::find(std::begin(endian_chars), std::end(endian_chars), byteorder_c) !=
std::end(endian_chars),
Expand Down
238 changes: 124 additions & 114 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh
Original file line number Diff line number Diff line change
@@ -1,114 +1,124 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/neighbors/sample_filter_types.hpp> // none_cagra_sample_filter
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

namespace raft::neighbors::cagra::detail {
namespace multi_cta_search {

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
class DATA_T,
class INDEX_T,
class DISTANCE_T,
class SAMPLE_FILTER_T>
void select_and_run(raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset,
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph,
INDEX_T* const topk_indices_ptr,
DISTANCE_T* const topk_distances_ptr,
const DATA_T* const queries_ptr,
const uint32_t num_queries,
const INDEX_T* dev_seed_ptr,
uint32_t* const num_executed_iterations,
uint32_t topk,
uint32_t block_size,
uint32_t result_buffer_size,
uint32_t smem_size,
int64_t hash_bitlen,
INDEX_T* hashmap_ptr,
uint32_t num_cta_per_query,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cudaStream_t stream) RAFT_EXPLICIT;
#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_kernel_selection( \
TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \
extern template void \
select_and_run<TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T>( \
raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset, \
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph, \
INDEX_T* const topk_indices_ptr, \
DISTANCE_T* const topk_distances_ptr, \
const DATA_T* const queries_ptr, \
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cudaStream_t stream);

instantiate_kernel_selection(
32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);

#undef instantiate_kernel_selection
} // namespace multi_cta_search
} // namespace raft::neighbors::cagra::detail
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/neighbors/sample_filter_types.hpp> // none_cagra_sample_filter
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

#include <cuda_fp16.h>

namespace raft::neighbors::cagra::detail {
namespace multi_cta_search {

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
class DATA_T,
class INDEX_T,
class DISTANCE_T,
class SAMPLE_FILTER_T>
void select_and_run(raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset,
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph,
INDEX_T* const topk_indices_ptr,
DISTANCE_T* const topk_distances_ptr,
const DATA_T* const queries_ptr,
const uint32_t num_queries,
const INDEX_T* dev_seed_ptr,
uint32_t* const num_executed_iterations,
uint32_t topk,
uint32_t block_size,
uint32_t result_buffer_size,
uint32_t smem_size,
int64_t hash_bitlen,
INDEX_T* hashmap_ptr,
uint32_t num_cta_per_query,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cudaStream_t stream) RAFT_EXPLICIT;
#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_kernel_selection( \
TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \
extern template void \
select_and_run<TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T>( \
raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset, \
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph, \
INDEX_T* const topk_indices_ptr, \
DISTANCE_T* const topk_distances_ptr, \
const DATA_T* const queries_ptr, \
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cudaStream_t stream);

instantiate_kernel_selection(
32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);

#undef instantiate_kernel_selection
} // namespace multi_cta_search
} // namespace raft::neighbors::cagra::detail
Loading