From 6ec78e9ce291737b247e1712c0dc311f8ae4062c Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Fri, 9 Jun 2023 23:29:37 +0200 Subject: [PATCH] CAGRA pad dataset for 128bit vectorized load (#1505) This PR adds padding to the dataset (if necessary) to make reading any of its rows compatible with 128bit vectorized loads. This change also enables handling arbitrary number of input features (before this PR each row had to be at least 64bit aligned, which constrained the acceptable number of input features). Fixes #1458. With this change, it is sufficient to keep a single "load type" specialization for the search kernels, which shall cut the binary size by half (#1459). Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - tsuki (https://github.com/enp1s0) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1505 --- cpp/include/raft/neighbors/cagra.cuh | 22 +++-- cpp/include/raft/neighbors/cagra_types.hpp | 48 ++++++++--- .../neighbors/detail/cagra/cagra_build.cuh | 4 - .../neighbors/detail/cagra/cagra_search.cuh | 9 +- .../detail/cagra/cagra_serialize.cuh | 20 ++++- .../detail/cagra/compute_distance.hpp | 6 +- .../detail/cagra/search_multi_cta.cuh | 66 +++++---------- .../detail/cagra/search_multi_kernel.cuh | 16 ++-- .../neighbors/detail/cagra/search_plan.cuh | 9 +- .../detail/cagra/search_single_cta.cuh | 83 ++++++------------- cpp/test/neighbors/ann_cagra.cuh | 47 +++++------ 11 files changed, 160 insertions(+), 170 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh index 9905f2abae..214e963f79 100644 --- a/cpp/include/raft/neighbors/cagra.cuh +++ b/cpp/include/raft/neighbors/cagra.cuh @@ -237,21 +237,28 @@ index build(raft::resources const& res, const index_params& params, mdspan, row_major, Accessor> dataset) { - size_t degree = params.intermediate_graph_degree; - if (degree >= static_cast(dataset.extent(0))) { + size_t intermediate_degree = params.intermediate_graph_degree; + size_t graph_degree = params.graph_degree; + if (intermediate_degree >= static_cast(dataset.extent(0))) { RAFT_LOG_WARN( "Intermediate graph degree cannot be larger than dataset size, reducing it to %lu", dataset.extent(0)); - degree = dataset.extent(0) - 1; + intermediate_degree = dataset.extent(0) - 1; + } + if (intermediate_degree < graph_degree) { + RAFT_LOG_WARN( + "Graph degree (%lu) cannot be larger than intermediate graph degree (%lu), reducing " + "graph_degree.", + graph_degree, + intermediate_degree); + graph_degree = intermediate_degree; } - RAFT_EXPECTS(degree >= params.graph_degree, - "Intermediate graph degree cannot be smaller than final graph degree"); - auto knn_graph = raft::make_host_matrix(dataset.extent(0), degree); + auto knn_graph = raft::make_host_matrix(dataset.extent(0), intermediate_degree); build_knn_graph(res, dataset, knn_graph.view()); - auto cagra_graph = raft::make_host_matrix(dataset.extent(0), params.graph_degree); + auto cagra_graph = raft::make_host_matrix(dataset.extent(0), graph_degree); prune(res, knn_graph.view(), cagra_graph.view()); @@ -290,7 +297,6 @@ void search(raft::resources const& res, RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1), "Number of columns in output neighbors and distances matrices must equal k"); - RAFT_EXPECTS(queries.extent(1) == idx.dim(), "Number of query dimensions should equal number of dimensions in the index."); diff --git a/cpp/include/raft/neighbors/cagra_types.hpp b/cpp/include/raft/neighbors/cagra_types.hpp index 87405ae9fb..a88a449a68 100644 --- a/cpp/include/raft/neighbors/cagra_types.hpp +++ b/cpp/include/raft/neighbors/cagra_types.hpp @@ -26,6 +26,7 @@ #include #include #include +#include #include #include @@ -82,8 +83,6 @@ struct search_params : ann::search_params { /** Lower limit of search iterations. */ size_t min_iterations = 0; - /** Bit length for reading the dataset vectors. 0, 64 or 128. Auto selection when 0. */ - size_t load_bit_length = 0; /** Thread block size. 0, 64, 128, 256, 512, 1024. Auto selection when 0. */ size_t thread_block_size = 0; /** Hashmap type. Auto selection when AUTO. */ @@ -113,6 +112,7 @@ static_assert(std::is_aggregate_v); */ template struct index : ann::index { + using AlignDim = raft::Pow2<16 / sizeof(T)>; static_assert(!raft::is_narrowing_v, "IdxT must be able to represent all values of uint32_t"); @@ -124,12 +124,15 @@ struct index : ann::index { } // /** Total length of the index. */ - [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT { return dataset_.extent(0); } + [[nodiscard]] constexpr inline auto size() const noexcept -> IdxT + { + return dataset_view_.extent(0); + } /** Dimensionality of the data. */ [[nodiscard]] constexpr inline auto dim() const noexcept -> uint32_t { - return dataset_.extent(1); + return dataset_view_.extent(1); } /** Graph degree */ [[nodiscard]] constexpr inline auto graph_degree() const noexcept -> uint32_t @@ -138,9 +141,10 @@ struct index : ann::index { } /** Dataset [size, dim] */ - [[nodiscard]] inline auto dataset() const noexcept -> device_matrix_view + [[nodiscard]] inline auto dataset() const noexcept + -> device_matrix_view { - return dataset_.view(); + return dataset_view_; } /** neighborhood graph [size, graph-degree] */ @@ -179,15 +183,36 @@ struct index : ann::index { mdspan, row_major, graph_accessor> knn_graph) : ann::index(), metric_(metric), - dataset_(make_device_matrix(res, dataset.extent(0), dataset.extent(1))), + dataset_( + make_device_matrix(res, dataset.extent(0), AlignDim::roundUp(dataset.extent(1)))), graph_(make_device_matrix(res, knn_graph.extent(0), knn_graph.extent(1))) { RAFT_EXPECTS(dataset.extent(0) == knn_graph.extent(0), "Dataset and knn_graph must have equal number of rows"); - raft::copy(dataset_.data_handle(), - dataset.data_handle(), - dataset.size(), - resource::get_cuda_stream(res)); + if (dataset_.extent(1) == dataset.extent(1)) { + raft::copy(dataset_.data_handle(), + dataset.data_handle(), + dataset.size(), + resource::get_cuda_stream(res)); + } else { + // copy with padding + RAFT_CUDA_TRY(cudaMemsetAsync( + dataset_.data_handle(), 0, dataset_.size() * sizeof(T), resource::get_cuda_stream(res))); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(dataset_.data_handle(), + sizeof(T) * dataset_.extent(1), + dataset.data_handle(), + sizeof(T) * dataset.extent(1), + sizeof(T) * dataset.extent(1), + dataset.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + } + dataset_view_ = make_device_strided_matrix_view( + dataset_.data_handle(), dataset_.extent(0), dataset.extent(1), dataset_.extent(1)); + RAFT_LOG_DEBUG("CAGRA dataset strided matrix view %zux%zu, stride %zu", + static_cast(dataset_view_.extent(0)), + static_cast(dataset_view_.extent(1)), + static_cast(dataset_view_.stride(0))); raft::copy(graph_.data_handle(), knn_graph.data_handle(), knn_graph.size(), @@ -199,6 +224,7 @@ struct index : ann::index { raft::distance::DistanceType metric_; raft::device_matrix dataset_; raft::device_matrix graph_; + raft::device_matrix_view dataset_view_; }; /** @} */ diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh index 693ab9029d..5c196471aa 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh @@ -46,10 +46,6 @@ void build_knn_graph(raft::resources const& res, std::optional build_params = std::nullopt, std::optional search_params = std::nullopt) { - RAFT_EXPECTS( - dataset.extent(1) * sizeof(DataT) % 8 == 0, - "Dataset rows are expected to have at least 8 bytes alignment. Try padding feature dims."); - RAFT_EXPECTS(!build_params || build_params->metric == distance::DistanceType::L2Expanded, "Currently only L2Expanded metric is supported"); diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh index d3b24dc861..7b35af4417 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_search.cuh @@ -27,8 +27,6 @@ #include #include "factory.cuh" -#include "search_multi_cta.cuh" -#include "search_multi_kernel.cuh" #include "search_plan.cuh" #include "search_single_cta.cuh" @@ -92,8 +90,11 @@ void search_main(raft::resources const& res, : nullptr; uint32_t* _num_executed_iterations = nullptr; - auto dataset_internal = raft::make_device_matrix_view( - index.dataset().data_handle(), index.dataset().extent(0), index.dataset().extent(1)); + auto dataset_internal = make_device_strided_matrix_view( + index.dataset().data_handle(), + index.dataset().extent(0), + index.dataset().extent(1), + index.dataset().stride(0)); auto graph_internal = raft::make_device_matrix_view( reinterpret_cast(index.graph().data_handle()), diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 04d0bb350f..7632318b88 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -25,7 +25,7 @@ namespace raft::neighbors::experimental::cagra::detail { // Serialization version 1. -constexpr int serialization_version = 1; +constexpr int serialization_version = 2; // NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error // message. @@ -36,7 +36,8 @@ struct check_index_layout { "paste in the new size and consider updating the serialization logic"); }; -template struct check_index_layout), 136>; +constexpr size_t expected_size = 176; +template struct check_index_layout), expected_size>; /** * Save the index to file. @@ -59,7 +60,19 @@ void serialize(raft::resources const& res, std::ostream& os, const index(dataset.extent(0), dataset.extent(1)); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), + sizeof(T) * host_dataset.extent(1), + dataset.data_handle(), + sizeof(T) * dataset.stride(0), + sizeof(T) * host_dataset.extent(1), + dataset.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + resource::sync_stream(res); + serialize_mdspan(res, os, host_dataset.view()); serialize_mdspan(res, os, index_.graph()); } @@ -100,7 +113,6 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index auto dataset = raft::make_host_matrix(n_rows, dim); auto graph = raft::make_host_matrix(n_rows, graph_degree); - deserialize_mdspan(res, is, dataset.view()); deserialize_mdspan(res, is, graph.view()); diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp index fd66735cf6..f67e110fc6 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -56,6 +56,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] const std::size_t dataset_dim, const std::size_t dataset_size, + const std::size_t dataset_ld, const std::size_t num_pickup, const unsigned num_distilation, const uint64_t rand_xor_mask, @@ -93,7 +94,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( for (uint32_t e = 0; e < nelem; e++) { const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen; if (k >= dataset_dim) break; - dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_dim * seed_index)))[0]; + dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_ld * seed_index)))[0]; } #pragma unroll for (uint32_t e = 0; e < nelem; e++) { @@ -146,6 +147,7 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in // [dataset_dim, dataset_size] const DATA_T* const dataset_ptr, const std::size_t dataset_dim, + const std::size_t dataset_ld, // [knn_k, dataset_size] const INDEX_T* const knn_graph, const std::uint32_t knn_k, @@ -215,7 +217,7 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in for (unsigned e = 0; e < nelem; e++) { const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; if (k >= dataset_dim) break; - dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_dim * child_id)))[0]; + dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_ld * child_id)))[0]; } #pragma unroll for (unsigned e = 0; e < nelem; e++) { diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index f9a0fef2fe..8ab6b19b98 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -138,6 +138,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] const size_t dataset_dim, const size_t dataset_size, + const size_t dataset_ld, const DATA_T* const queries_ptr, // [num_queries, dataset_dim] const INDEX_T* const knn_graph, // [dataset_size, graph_degree] const uint32_t graph_degree, @@ -231,6 +232,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( dataset_ptr, dataset_dim, dataset_size, + dataset_ld, result_buffer_size, num_distilation, rand_xor_mask, @@ -278,6 +280,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( query_buffer, dataset_ptr, dataset_dim, + dataset_ld, knn_graph, graph_degree, local_visited_hashmap_ptr, @@ -326,38 +329,31 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( #endif } -#define SET_MC_KERNEL_3(BLOCK_SIZE, BLOCK_COUNT, MAX_ELEMENTS, LOAD_T) \ - kernel = search_kernel; - -#define SET_MC_KERNEL_2(BLOCK_SIZE, BLOCK_COUNT, MAX_ELEMENTS) \ - if (load_bit_length == 128) { \ - SET_MC_KERNEL_3(BLOCK_SIZE, BLOCK_COUNT, MAX_ELEMENTS, device::LOAD_128BIT_T) \ - } else if (load_bit_length == 64) { \ - SET_MC_KERNEL_3(BLOCK_SIZE, BLOCK_COUNT, MAX_ELEMENTS, device::LOAD_64BIT_T) \ - } +#define SET_MC_KERNEL_3(BLOCK_SIZE, BLOCK_COUNT, MAX_ELEMENTS) \ + kernel = search_kernel; #define SET_MC_KERNEL_1(MAX_ELEMENTS) \ /* if ( block_size == 32 ) { \ - SET_MC_KERNEL_2( 32, 32, MAX_ELEMENTS ) \ + SET_MC_KERNEL_3( 32, 32, MAX_ELEMENTS ) \ } else */ \ if (block_size == 64) { \ - SET_MC_KERNEL_2(64, 16, MAX_ELEMENTS) \ + SET_MC_KERNEL_3(64, 16, MAX_ELEMENTS) \ } else if (block_size == 128) { \ - SET_MC_KERNEL_2(128, 8, MAX_ELEMENTS) \ + SET_MC_KERNEL_3(128, 8, MAX_ELEMENTS) \ } else if (block_size == 256) { \ - SET_MC_KERNEL_2(256, 4, MAX_ELEMENTS) \ + SET_MC_KERNEL_3(256, 4, MAX_ELEMENTS) \ } else if (block_size == 512) { \ - SET_MC_KERNEL_2(512, 2, MAX_ELEMENTS) \ + SET_MC_KERNEL_3(512, 2, MAX_ELEMENTS) \ } else { \ - SET_MC_KERNEL_2(1024, 1, MAX_ELEMENTS) \ + SET_MC_KERNEL_3(1024, 1, MAX_ELEMENTS) \ } #define SET_MC_KERNEL \ @@ -366,6 +362,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ void search_kernel( const DATA_T* const dataset_ptr, \ const size_t dataset_dim, \ const size_t dataset_size, \ + const size_t dataset_ld, \ const DATA_T* const queries_ptr, \ const INDEX_T* const knn_graph, \ const uint32_t graph_degree, \ @@ -431,7 +428,6 @@ struct search : public search_plan_impl { using search_plan_impl::num_parents; using search_plan_impl::min_iterations; using search_plan_impl::max_iterations; - using search_plan_impl::load_bit_length; using search_plan_impl::thread_block_size; using search_plan_impl::hashmap_mode; using search_plan_impl::hashmap_min_bitlen; @@ -453,7 +449,6 @@ struct search : public search_plan_impl { using search_plan_impl::result_buffer_size; using search_plan_impl::smem_size; - using search_plan_impl::load_bit_lenght; using search_plan_impl::hashmap; using search_plan_impl::num_executed_iterations; @@ -533,24 +528,6 @@ struct search : public search_plan_impl { max_block_size); thread_block_size = block_size; - // - // Determine load bit length - // - const uint32_t total_bit_length = dim * sizeof(DATA_T) * 8; - if (load_bit_length == 0) { - load_bit_length = 128; - while (total_bit_length % load_bit_length) { - load_bit_length /= 2; - } - } - RAFT_LOG_DEBUG("# load_bit_length: %u (%u loads per vector)", - load_bit_length, - total_bit_length / load_bit_length); - RAFT_EXPECTS(total_bit_length % load_bit_length == 0, - "load_bit_length must be a divisor of dim*sizeof(data_t)*8=%u", - total_bit_length); - RAFT_EXPECTS(load_bit_length >= 64, "load_bit_lenght cannot be less than 64"); - // // Allocate memory for intermediate buffer and workspace. // @@ -569,7 +546,7 @@ struct search : public search_plan_impl { ~search() {} void operator()(raft::resources const& res, - raft::device_matrix_view dataset, + raft::device_matrix_view dataset, raft::device_matrix_view graph, INDEX_T* const topk_indices_ptr, // [num_queries, topk] DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] @@ -602,6 +579,7 @@ struct search : public search_plan_impl { dataset.data_handle(), dataset.extent(1), dataset.extent(0), + dataset.stride(0), queries_ptr, graph.data_handle(), graph.extent(1), diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index 8fbd5d8f03..033022aea1 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -93,6 +93,7 @@ __global__ void random_pickup_kernel( const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] const std::size_t dataset_dim, const std::size_t dataset_size, + const std::size_t dataset_ld, const DATA_T* const queries_ptr, // [num_queries, dataset_dim] const std::size_t num_pickup, const unsigned num_distilation, @@ -125,7 +126,7 @@ __global__ void random_pickup_kernel( } device::fragment random_data_frag; device::load_vector_sync( - random_data_frag, dataset_ptr + (dataset_dim * seed_index), dataset_dim); + random_data_frag, dataset_ptr + (dataset_ld * seed_index), dataset_dim); // Compute the norm of two data const auto norm2 = device::norm2( @@ -163,6 +164,7 @@ template >>(dataset_ptr, dataset_dim, dataset_size, + dataset_ld, queries_ptr, num_pickup, num_distilation, @@ -310,6 +313,7 @@ __global__ void compute_distance_to_child_nodes_kernel( const DATA_T* const dataset_ptr, // [dataset_size, data_dim] const std::uint32_t data_dim, const std::uint32_t dataset_size, + const std::uint32_t dataset_ld, const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] const std::uint32_t graph_degree, const DATA_T* query_ptr, // [num_queries, data_dim] @@ -338,7 +342,7 @@ __global__ void compute_distance_to_child_nodes_kernel( if (hashmap::insert( visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id)) { device::fragment frag_target; - device::load_vector_sync(frag_target, dataset_ptr + (data_dim * child_id), data_dim); + device::load_vector_sync(frag_target, dataset_ptr + (dataset_ld * child_id), data_dim); device::fragment frag_query; device::load_vector_sync(frag_query, query_ptr + blockIdx.y * data_dim, data_dim); @@ -370,6 +374,7 @@ void compute_distance_to_child_nodes( const DATA_T* const dataset_ptr, // [dataset_size, data_dim] const std::uint32_t data_dim, const std::uint32_t dataset_size, + const std::uint32_t dataset_ld, const INDEX_T* const neighbor_graph_ptr, // [dataset_size, graph_degree] const std::uint32_t graph_degree, const DATA_T* query_ptr, // [num_queries, data_dim] @@ -391,6 +396,7 @@ void compute_distance_to_child_nodes( dataset_ptr, data_dim, dataset_size, + dataset_ld, neighbor_graph_ptr, graph_degree, query_ptr, @@ -511,7 +517,6 @@ struct search : search_plan_impl { using search_plan_impl::num_parents; using search_plan_impl::min_iterations; using search_plan_impl::max_iterations; - using search_plan_impl::load_bit_length; using search_plan_impl::thread_block_size; using search_plan_impl::hashmap_mode; using search_plan_impl::hashmap_min_bitlen; @@ -533,7 +538,6 @@ struct search : search_plan_impl { using search_plan_impl::result_buffer_size; using search_plan_impl::smem_size; - using search_plan_impl::load_bit_lenght; using search_plan_impl::hashmap; using search_plan_impl::num_executed_iterations; @@ -590,7 +594,7 @@ struct search : search_plan_impl { ~search() {} void operator()(raft::resources const& res, - raft::device_matrix_view dataset, + raft::device_matrix_view dataset, raft::device_matrix_view graph, INDEX_T* const topk_indices_ptr, // [num_queries, topk] DISTANCE_T* const topk_distances_ptr, // [num_queries, topk] @@ -613,6 +617,7 @@ struct search : search_plan_impl { dataset.data_handle(), dataset.extent(1), dataset.extent(0), + dataset.stride(0), queries_ptr, num_queries, result_buffer_size, @@ -683,6 +688,7 @@ struct search : search_plan_impl { dataset.data_handle(), dataset.extent(1), dataset.extent(0), + dataset.stride(0), graph.data_handle(), graph.extent(1), queries_ptr, diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index 3bed100a70..3b09902639 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -77,7 +77,6 @@ struct search_plan_impl : public search_plan_impl_base { uint32_t result_buffer_size; uint32_t smem_size; - uint32_t load_bit_lenght; uint32_t topk; uint32_t num_seeds; @@ -107,7 +106,7 @@ struct search_plan_impl : public search_plan_impl_base { virtual ~search_plan_impl() {} virtual void operator()(raft::resources const& res, - raft::device_matrix_view dataset, + raft::device_matrix_view dataset, raft::device_matrix_view graph, INDEX_T* const result_indices_ptr, // [num_queries, topk] DISTANCE_T* const result_distances_ptr, // [num_queries, topk] @@ -286,14 +285,10 @@ struct search_plan_impl : public search_plan_impl_base { error_message += "`team_size` must be 0, 4, 8, 16 or 32. " + std::to_string(team_size) + " has been given."; } - if (load_bit_length != 0 && load_bit_length != 64 && load_bit_length != 128) { - error_message += "`load_bit_length` must be 0, 64 or 128. " + - std::to_string(load_bit_length) + " has been given."; - } if (thread_block_size != 0 && thread_block_size != 64 && thread_block_size != 128 && thread_block_size != 256 && thread_block_size != 512 && thread_block_size != 1024) { error_message += "`thread_block_size` must be 0, 64, 128, 256 or 512. " + - std::to_string(load_bit_length) + " has been given."; + std::to_string(thread_block_size) + " has been given."; } if (hashmap_min_bitlen > 20) { error_message += "`hashmap_min_bitlen` must be equal to or smaller than 20. " + diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index 9400a16c36..219a1dd717 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -540,6 +540,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ const DATA_T* const dataset_ptr, // [dataset_size, dataset_dim] const std::size_t dataset_dim, const std::size_t dataset_size, + const std::size_t dataset_ld, // stride of dataset const DATA_T* const queries_ptr, // [num_queries, dataset_dim] const INDEX_T* const knn_graph, // [dataset_size, graph_degree] const std::uint32_t graph_degree, @@ -634,6 +635,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ dataset_ptr, dataset_dim, dataset_size, + dataset_ld, result_buffer_size, num_distilation, rand_xor_mask, @@ -758,6 +760,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ query_buffer, dataset_ptr, dataset_dim, + dataset_ld, knn_graph, graph_degree, local_visited_hashmap_ptr, @@ -808,60 +811,42 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ #endif } -#define SET_KERNEL_3( \ - BLOCK_SIZE, BLOCK_COUNT, MAX_ITOPK, MAX_CANDIDATES, TOPK_BY_BITONIC_SORT, LOAD_T) \ - kernel = search_kernel; - -#define SET_KERNEL_2(BLOCK_SIZE, BLOCK_COUNT, MAX_ITOPK, MAX_CANDIDATES, TOPK_BY_BITONIC_SORT) \ - if (load_bit_length == 128) { \ - SET_KERNEL_3(BLOCK_SIZE, \ - BLOCK_COUNT, \ - MAX_ITOPK, \ - MAX_CANDIDATES, \ - TOPK_BY_BITONIC_SORT, \ - device::LOAD_128BIT_T) \ - } else if (load_bit_length == 64) { \ - SET_KERNEL_3(BLOCK_SIZE, \ - BLOCK_COUNT, \ - MAX_ITOPK, \ - MAX_CANDIDATES, \ - TOPK_BY_BITONIC_SORT, \ - device::LOAD_64BIT_T) \ - } +#define SET_KERNEL_3(BLOCK_SIZE, BLOCK_COUNT, MAX_ITOPK, MAX_CANDIDATES, TOPK_BY_BITONIC_SORT) \ + kernel = search_kernel; #define SET_KERNEL_1B(MAX_ITOPK, MAX_CANDIDATES) \ /* if ( block_size == 32 ) { \ SET_KERNEL_2( 32, 20, MAX_ITOPK, MAX_CANDIDATES, 1 ) \ } else */ \ if (block_size == 64) { \ - SET_KERNEL_2(64, 16 /*20*/, MAX_ITOPK, MAX_CANDIDATES, 1) \ + SET_KERNEL_3(64, 16 /*20*/, MAX_ITOPK, MAX_CANDIDATES, 1) \ } else if (block_size == 128) { \ - SET_KERNEL_2(128, 8, MAX_ITOPK, MAX_CANDIDATES, 1) \ + SET_KERNEL_3(128, 8, MAX_ITOPK, MAX_CANDIDATES, 1) \ } else if (block_size == 256) { \ - SET_KERNEL_2(256, 4, MAX_ITOPK, MAX_CANDIDATES, 1) \ + SET_KERNEL_3(256, 4, MAX_ITOPK, MAX_CANDIDATES, 1) \ } else if (block_size == 512) { \ - SET_KERNEL_2(512, 2, MAX_ITOPK, MAX_CANDIDATES, 1) \ + SET_KERNEL_3(512, 2, MAX_ITOPK, MAX_CANDIDATES, 1) \ } else { \ - SET_KERNEL_2(1024, 1, MAX_ITOPK, MAX_CANDIDATES, 1) \ + SET_KERNEL_3(1024, 1, MAX_ITOPK, MAX_CANDIDATES, 1) \ } #define SET_KERNEL_1R(MAX_ITOPK, MAX_CANDIDATES) \ if (block_size == 256) { \ - SET_KERNEL_2(256, 4, MAX_ITOPK, MAX_CANDIDATES, 0) \ + SET_KERNEL_3(256, 4, MAX_ITOPK, MAX_CANDIDATES, 0) \ } else if (block_size == 512) { \ - SET_KERNEL_2(512, 2, MAX_ITOPK, MAX_CANDIDATES, 0) \ + SET_KERNEL_3(512, 2, MAX_ITOPK, MAX_CANDIDATES, 0) \ } else { \ - SET_KERNEL_2(1024, 1, MAX_ITOPK, MAX_CANDIDATES, 0) \ + SET_KERNEL_3(1024, 1, MAX_ITOPK, MAX_CANDIDATES, 0) \ } #define SET_KERNEL \ @@ -871,6 +856,7 @@ __launch_bounds__(BLOCK_SIZE, BLOCK_COUNT) __global__ const DATA_T* const dataset_ptr, \ const std::size_t dataset_dim, \ const std::size_t dataset_size, \ + const std::size_t dataset_ld, \ const DATA_T* const queries_ptr, \ const INDEX_T* const knn_graph, \ const std::uint32_t graph_degree, \ @@ -943,7 +929,6 @@ struct search : search_plan_impl { using search_plan_impl::num_parents; using search_plan_impl::min_iterations; using search_plan_impl::max_iterations; - using search_plan_impl::load_bit_length; using search_plan_impl::thread_block_size; using search_plan_impl::hashmap_mode; using search_plan_impl::hashmap_min_bitlen; @@ -965,7 +950,6 @@ struct search : search_plan_impl { using search_plan_impl::result_buffer_size; using search_plan_impl::smem_size; - using search_plan_impl::load_bit_lenght; using search_plan_impl::hashmap; using search_plan_impl::num_executed_iterations; @@ -1066,22 +1050,6 @@ struct search : search_plan_impl { max_block_size); thread_block_size = block_size; - // Determine load bit length - const uint32_t total_bit_length = dim * sizeof(DATA_T) * 8; - if (load_bit_length == 0) { - load_bit_length = 128; - while (total_bit_length % load_bit_length) { - load_bit_length /= 2; - } - } - RAFT_LOG_DEBUG("# load_bit_length: %u (%u loads per vector)", - load_bit_length, - total_bit_length / load_bit_length); - RAFT_EXPECTS(total_bit_length % load_bit_length == 0, - "load_bit_length must be a divisor of dim*sizeof(data_t)*8=%u", - total_bit_length); - RAFT_EXPECTS(load_bit_length >= 64, "load_bit_lenght cannot be less than 64"); - if (num_itopk_candidates <= 256) { RAFT_LOG_DEBUG("# bitonic-sort based topk routine is used"); } else { @@ -1129,7 +1097,7 @@ struct search : search_plan_impl { } void operator()(raft::resources const& res, - raft::device_matrix_view dataset, + raft::device_matrix_view dataset, raft::device_matrix_view graph, INDEX_T* const result_indices_ptr, // [num_queries, topk] DISTANCE_T* const result_distances_ptr, // [num_queries, topk] @@ -1154,6 +1122,7 @@ struct search : search_plan_impl { dataset.data_handle(), dataset.extent(1), dataset.extent(0), + dataset.stride(0), queries_ptr, graph.data_handle(), graph.extent(1), diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 63c8114de6..d3bd5ba31d 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -166,9 +166,8 @@ class AnnCagraTest : public ::testing::TestWithParam { protected: void testCagra() { - if (ps.dim * sizeof(DataT) % 8 != 0) { - GTEST_SKIP() - << "CAGRA requires the input data rows to be aligned at least to 8 bytes for now."; + if (ps.algo == search_algo::MULTI_CTA && ps.max_queries > 1) { + GTEST_SKIP() << "Skipping test due to issue #1575"; } size_t queries_size = ps.n_queries * ps.k; std::vector indices_Cagra(queries_size); @@ -234,12 +233,11 @@ class AnnCagraTest : public ::testing::TestWithParam { cagra::search( handle_, search_params, index, search_queries_view, indices_out_view, dists_out_view); - update_host(distances_Cagra.data(), distances_dev.data(), queries_size, stream_); update_host(indices_Cagra.data(), indices_dev.data(), queries_size, stream_); resource::sync_stream(handle_); } - // for (int i = 0; i < ps.n_queries; i++) { + // for (int i = 0; i < min(ps.n_queries, 10); i++) { // // std::cout << "query " << i << std::end; // print_vector("T", indices_naive.data() + i * ps.k, ps.k, std::cout); // print_vector("C", indices_Cagra.data() + i * ps.k, ps.k, std::cout); @@ -247,7 +245,7 @@ class AnnCagraTest : public ::testing::TestWithParam { // print_vector("C", distances_Cagra.data() + i * ps.k, ps.k, std::cout); // } double min_recall = ps.min_recall; - ASSERT_TRUE(eval_neighbours(indices_naive, + EXPECT_TRUE(eval_neighbours(indices_naive, indices_Cagra, distances_naive, distances_Cagra, @@ -255,7 +253,7 @@ class AnnCagraTest : public ::testing::TestWithParam { ps.k, 0.001, min_recall)); - ASSERT_TRUE(eval_distances(handle_, + EXPECT_TRUE(eval_distances(handle_, database.data(), search_queries.data(), indices_dev.data(), @@ -374,33 +372,34 @@ class AnnCagraSortTest : public ::testing::TestWithParam { inline std::vector generate_inputs() { // Todo(tfeher): MULTI_CTA tests a bug, consider disabling that mode. + // TODO(tfeher): test MULTI_CTA kernel with num_Parents>1 to allow multiple CTA per queries std::vector inputs = raft::util::itertools::product( {100}, {1000}, - {8}, - {1, 16, 33}, // k - {search_algo::SINGLE_CTA, search_algo::MULTI_KERNEL}, + {1, 8, 17}, + {1, 16}, // k + {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, {1, 10, 100}, // query size {0}, - {64}, + {256}, {1}, {raft::distance::DistanceType::L2Expanded}, {false}, {0.995}); - auto inputs2 = - raft::util::itertools::product({100}, - {1000}, - {8, 64, 128, 192, 256, 512, 1024}, // dim - {16}, - {search_algo::AUTO}, - {10}, - {0}, - {64}, - {1}, - {raft::distance::DistanceType::L2Expanded}, - {false}, - {0.995}); + auto inputs2 = raft::util::itertools::product( + {100}, + {1000}, + {1, 3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim + {16}, // k + {search_algo::AUTO}, + {10}, + {0}, + {64}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false}, + {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); inputs2 = raft::util::itertools::product({100},