Skip to content

Commit

Permalink
Support for fp16 in CAGRA and IVF-PQ (#2085)
Browse files Browse the repository at this point in the history
Add fp16 (CUDA half) support to CAGRA and its dependencies.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)
  - tsuki (https://github.com/enp1s0)

URL: #2085
  • Loading branch information
achirkin authored Jan 19, 2024
1 parent 7cab0c3 commit 72f48ae
Show file tree
Hide file tree
Showing 44 changed files with 1,813 additions and 153 deletions.
16 changes: 15 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# =============================================================================
# Copyright (c) 2020-2023, NVIDIA CORPORATION.
# Copyright (c) 2020-2024, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
# in compliance with the License. You may obtain a copy of the License at
Expand Down Expand Up @@ -338,6 +338,10 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_multi_cta_float_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_multi_cta_half_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_multi_cta_int8_uint32_dim512_t32.cu
Expand All @@ -350,6 +354,10 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_single_cta_float_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_single_cta_half_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_single_cta_half_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_single_cta_half_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_single_cta_half_uint32_dim1024_t32.cu
src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim128_t8.cu
src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim256_t16.cu
src/neighbors/detail/cagra/search_single_cta_int8_uint32_dim512_t32.cu
Expand All @@ -359,6 +367,7 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim512_t32.cu
src/neighbors/detail/cagra/search_single_cta_uint8_uint32_dim1024_t32.cu
src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu
src/neighbors/detail/ivf_flat_interleaved_scan_half_half_int64_t.cu
src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu
src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu
src/neighbors/detail/ivf_flat_search.cu
Expand All @@ -370,6 +379,7 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu
src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu
src/neighbors/detail/refine_host_float_float.cpp
src/neighbors/detail/refine_host_half_float.cpp
src/neighbors/detail/refine_host_int8_t_float.cpp
src/neighbors/detail/refine_host_uint8_t_float.cpp
src/neighbors/ivf_flat_build_float_int64_t.cu
Expand All @@ -382,15 +392,19 @@ if(RAFT_COMPILE_LIBRARY)
src/neighbors/ivf_flat_search_int8_t_int64_t.cu
src/neighbors/ivf_flat_search_uint8_t_int64_t.cu
src/neighbors/ivfpq_build_float_int64_t.cu
src/neighbors/ivfpq_build_half_int64_t.cu
src/neighbors/ivfpq_build_int8_t_int64_t.cu
src/neighbors/ivfpq_build_uint8_t_int64_t.cu
src/neighbors/ivfpq_extend_float_int64_t.cu
src/neighbors/ivfpq_extend_half_int64_t.cu
src/neighbors/ivfpq_extend_int8_t_int64_t.cu
src/neighbors/ivfpq_extend_uint8_t_int64_t.cu
src/neighbors/ivfpq_search_float_int64_t.cu
src/neighbors/ivfpq_search_half_int64_t.cu
src/neighbors/ivfpq_search_int8_t_int64_t.cu
src/neighbors/ivfpq_search_uint8_t_int64_t.cu
src/neighbors/refine_float_float.cu
src/neighbors/refine_half_float.cu
src/neighbors/refine_int8_t_float.cu
src/neighbors/refine_uint8_t_float.cu
src/raft_runtime/cluster/cluster_cost.cuh
Expand Down
16 changes: 14 additions & 2 deletions cpp/include/raft/core/detail/mdspan_numpy_serializer.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,10 @@
#include <raft/core/host_mdspan.hpp>
#include <raft/core/resources.hpp>

#if defined(_RAFT_HAS_CUDA)
#include <cuda_fp16.h>
#endif

#include <algorithm>
#include <complex>
#include <cstdint>
Expand Down Expand Up @@ -121,6 +125,14 @@ inline dtype_t get_numpy_dtype()
return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'f', sizeof(T)};
}

#if defined(_RAFT_HAS_CUDA)
template <typename T, typename std::enable_if_t<std::is_same_v<T, half>, bool> = true>
inline dtype_t get_numpy_dtype()
{
return {RAFT_NUMPY_HOST_ENDIAN_CHAR, 'e', sizeof(T)};
}
#endif

template <typename T,
typename std::enable_if_t<std::is_integral_v<T> && std::is_signed_v<T>, bool> = true>
inline dtype_t get_numpy_dtype()
Expand Down Expand Up @@ -273,7 +285,7 @@ inline dtype_t parse_descr(std::string typestr)

const char endian_chars[] = {
RAFT_NUMPY_LITTLE_ENDIAN_CHAR, RAFT_NUMPY_BIG_ENDIAN_CHAR, RAFT_NUMPY_NO_ENDIAN_CHAR};
const char numtype_chars[] = {'f', 'i', 'u', 'c'};
const char numtype_chars[] = {'f', 'i', 'u', 'c', 'e'};

RAFT_EXPECTS(std::find(std::begin(endian_chars), std::end(endian_chars), byteorder_c) !=
std::end(endian_chars),
Expand Down
238 changes: 124 additions & 114 deletions cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh
Original file line number Diff line number Diff line change
@@ -1,114 +1,124 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/neighbors/sample_filter_types.hpp> // none_cagra_sample_filter
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

namespace raft::neighbors::cagra::detail {
namespace multi_cta_search {

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
class DATA_T,
class INDEX_T,
class DISTANCE_T,
class SAMPLE_FILTER_T>
void select_and_run(raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset,
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph,
INDEX_T* const topk_indices_ptr,
DISTANCE_T* const topk_distances_ptr,
const DATA_T* const queries_ptr,
const uint32_t num_queries,
const INDEX_T* dev_seed_ptr,
uint32_t* const num_executed_iterations,
uint32_t topk,
uint32_t block_size,
uint32_t result_buffer_size,
uint32_t smem_size,
int64_t hash_bitlen,
INDEX_T* hashmap_ptr,
uint32_t num_cta_per_query,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cudaStream_t stream) RAFT_EXPLICIT;
#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_kernel_selection( \
TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \
extern template void \
select_and_run<TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T>( \
raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset, \
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph, \
INDEX_T* const topk_indices_ptr, \
DISTANCE_T* const topk_distances_ptr, \
const DATA_T* const queries_ptr, \
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cudaStream_t stream);

instantiate_kernel_selection(
32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);

#undef instantiate_kernel_selection
} // namespace multi_cta_search
} // namespace raft::neighbors::cagra::detail
/*
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once

#include <raft/neighbors/sample_filter_types.hpp> // none_cagra_sample_filter
#include <raft/util/raft_explicit.hpp> // RAFT_EXPLICIT

#include <cuda_fp16.h>

namespace raft::neighbors::cagra::detail {
namespace multi_cta_search {

#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
class DATA_T,
class INDEX_T,
class DISTANCE_T,
class SAMPLE_FILTER_T>
void select_and_run(raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset,
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph,
INDEX_T* const topk_indices_ptr,
DISTANCE_T* const topk_distances_ptr,
const DATA_T* const queries_ptr,
const uint32_t num_queries,
const INDEX_T* dev_seed_ptr,
uint32_t* const num_executed_iterations,
uint32_t topk,
uint32_t block_size,
uint32_t result_buffer_size,
uint32_t smem_size,
int64_t hash_bitlen,
INDEX_T* hashmap_ptr,
uint32_t num_cta_per_query,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cudaStream_t stream) RAFT_EXPLICIT;
#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY

#define instantiate_kernel_selection( \
TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T) \
extern template void \
select_and_run<TEAM_SIZE, MAX_DATASET_DIM, DATA_T, INDEX_T, DISTANCE_T, SAMPLE_FILTER_T>( \
raft::device_matrix_view<const DATA_T, int64_t, layout_stride> dataset, \
raft::device_matrix_view<const INDEX_T, int64_t, row_major> graph, \
INDEX_T* const topk_indices_ptr, \
DISTANCE_T* const topk_distances_ptr, \
const DATA_T* const queries_ptr, \
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cudaStream_t stream);

instantiate_kernel_selection(
32, 1024, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, float, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, half, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, int8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 1024, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
8, 128, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
16, 256, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);
instantiate_kernel_selection(
32, 512, uint8_t, uint32_t, float, raft::neighbors::filtering::none_cagra_sample_filter);

#undef instantiate_kernel_selection
} // namespace multi_cta_search
} // namespace raft::neighbors::cagra::detail
Loading

0 comments on commit 72f48ae

Please sign in to comment.