Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reapply: Support for fp16 in CAGRA and IVF-PQ #2172

Merged
merged 3 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,10 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu
Expand All @@ -358,6 +362,10 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_single_cta_half_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_single_cta_half_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_single_cta_half_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_single_cta_half_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu
Expand All @@ -367,6 +375,7 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu
src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu
src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu
src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu
src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu
src/neighbors/detail/ivf_flat_search.cu
Expand All @@ -378,6 +387,7 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu
src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu
src/neighbors/detail/refine_host_float_float.cpp
src/neighbors/detail/refine_host_half_float.cpp
src/neighbors/detail/refine_host_int8_t_float.cpp
src/neighbors/detail/refine_host_uint8_t_float.cpp
src/neighbors/ivf_flat_build_float_int64_t.cu
Expand All @@ -390,15 +400,19 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/ivf_flat_search_int8_t_int64_t.cu
src/neighbors/ivf_flat_search_uint8_t_int64_t.cu
src/neighbors/ivfpq_build_float_int64_t.cu
src/neighbors/ivfpq_build_half_int64_t.cu
src/neighbors/ivfpq_build_int8_t_int64_t.cu
src/neighbors/ivfpq_build_uint8_t_int64_t.cu
src/neighbors/ivfpq_extend_float_int64_t.cu
src/neighbors/ivfpq_extend_half_int64_t.cu
src/neighbors/ivfpq_extend_int8_t_int64_t.cu
src/neighbors/ivfpq_extend_uint8_t_int64_t.cu
src/neighbors/ivfpq_search_float_int64_t.cu
src/neighbors/ivfpq_search_half_int64_t.cu
src/neighbors/ivfpq_search_int8_t_int64_t.cu
src/neighbors/ivfpq_search_uint8_t_int64_t.cu
src/neighbors/refine_float_float.cu
src/neighbors/refine_half_float.cu
src/neighbors/refine_int8_t_float.cu
src/neighbors/refine_uint8_t_float.cu
src/raft_runtime/cluster/cluster_cost.cuh
Expand Down
14 changes: 13 additions & 1 deletion cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resources.hpp>

#if defined(_RAFT_HAS_CUDA)
#include <cuda_fp16.h>
#endif

#include <algorithm>
#include <complex>
#include <cstdint>
Expand Down Expand Up @@ -121,6 +125,14 @@ inline dtype_t get_numpy_dtype()
return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'f', sizeof(T)};
}

#if defined(_RAFT_HAS_CUDA)
template <typename T, typename std::enable_if_t<std::is_same_v<T, half>, bool> = true>
inline dtype_t get_numpy_dtype()
{
return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'e', sizeof(T)};
}
#endif

template <typename T,
typename std::enable_if_t<std::is_integral_v<T> && std::is_signed_v<T>, bool> = true>
inline dtype_t get_numpy_dtype()
Expand Down Expand Up @@ -273,7 +285,7 @@ inline dtype_t parse_descr(std::string typestr)

const char endian_chars[] = {
RAFT_NUMPY_LITTLE_ENDIAN_CHAR, RAFT_NUMPY_BIG_ENDIAN_CHAR, RAFT_NUMPY_NO_ENDIAN_CHAR};
const char numtype_chars[] = {'f', 'i', 'u', 'c'};
const char numtype_chars[] = {'f', 'i', 'u', 'c', 'e'};

RAFT_EXPECTS(std::find(std::begin(endian_chars), std::end(endian_chars), byteorder_c) !=
std::end(endian_chars),
Expand Down
238 changes: 124 additions & 114 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh
Original file line number Diff line number Diff line change
@@ -1,114 +1,124 @@
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/neighbors/sample_filter_types.hpp> // none_cagra_sample_filter
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

namespace raft::neighbors::cagra::detail {
namespace multi_cta_search {

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
class DATA_T,
class INDEX_T,
class DISTANCE_T,
class SAMPLE_FILTER_T>
void select_and_run(raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset,
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph,
INDEX_T* const topk_indices_ptr,
DISTANCE_T* const topk_distances_ptr,
const DATA_T* const queries_ptr,
const uint32_t num_queries,
const INDEX_T* dev_seed_ptr,
uint32_t* const num_executed_iterations,
uint32_t topk,
uint32_t block_size,
uint32_t result_buffer_size,
uint32_t smem_size,
int64_t hash_bitlen,
INDEX_T* hashmap_ptr,
uint32_t num_cta_per_query,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cudaStream_t stream) RAFT_EXPLICIT;
#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_kernel_selection( \
TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \
extern template void \
select_and_run<TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T>( \
raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset, \
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph, \
INDEX_T* const topk_indices_ptr, \
DISTANCE_T* const topk_distances_ptr, \
const DATA_T* const queries_ptr, \
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cudaStream_t stream);

instantiate_kernel_selection(
32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);

#undef instantiate_kernel_selection
} // namespace multi_cta_search
} // namespace raft::neighbors::cagra::detail
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/neighbors/sample_filter_types.hpp> // none_cagra_sample_filter
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

#include <cuda_fp16.h>

namespace raft::neighbors::cagra::detail {
namespace multi_cta_search {

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
class DATA_T,
class INDEX_T,
class DISTANCE_T,
class SAMPLE_FILTER_T>
void select_and_run(raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset,
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph,
INDEX_T* const topk_indices_ptr,
DISTANCE_T* const topk_distances_ptr,
const DATA_T* const queries_ptr,
const uint32_t num_queries,
const INDEX_T* dev_seed_ptr,
uint32_t* const num_executed_iterations,
uint32_t topk,
uint32_t block_size,
uint32_t result_buffer_size,
uint32_t smem_size,
int64_t hash_bitlen,
INDEX_T* hashmap_ptr,
uint32_t num_cta_per_query,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cudaStream_t stream) RAFT_EXPLICIT;
#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_kernel_selection( \
TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \
extern template void \
select_and_run<TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T>( \
raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset, \
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph, \
INDEX_T* const topk_indices_ptr, \
DISTANCE_T* const topk_distances_ptr, \
const DATA_T* const queries_ptr, \
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cudaStream_t stream);

instantiate_kernel_selection(
32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);

#undef instantiate_kernel_selection
} // namespace multi_cta_search
} // namespace raft::neighbors::cagra::detail
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
#include <raft/neighbors/sample_filter_types.hpp>
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

#include <cuda_fp16.h>

namespace raft::neighbors::cagra::detail {
namespace single_cta_search {

Expand Down Expand Up @@ -96,6 +98,14 @@ instantiate_single_cta_select_and_run(
16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_single_cta_select_and_run(
32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_single_cta_select_and_run(
32, 1024, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_single_cta_select_and_run(
8, 128, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_single_cta_select_and_run(
16, 256, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_single_cta_select_and_run(
32, 512, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_single_cta_select_and_run(
32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_single_cta_select_and_run(
Expand Down
Loading
Loading