From 96ca91a218c95fd5da8c4ea8e4c96d7cedb230a0 Mon Sep 17 00:00:00 2001 From: achirkin Date: Tue, 18 Jul 2023 15:42:01 +0200 Subject: [PATCH 1/3] neighbors: Optimize host-side refine --- cpp/CMakeLists.txt | 7 + cpp/include/raft/neighbors/detail/refine.cuh | 230 +----------------- .../raft/neighbors/detail/refine_common.hpp | 57 +++++ .../raft/neighbors/detail/refine_device.cuh | 108 ++++++++ .../raft/neighbors/detail/refine_host-ext.hpp | 55 +++++ .../raft/neighbors/detail/refine_host-inl.hpp | 134 ++++++++++ .../raft/neighbors/detail/refine_host.hpp | 24 ++ .../detail/refine_host_float_float.cpp | 29 +++ .../detail/refine_host_int8_t_float.cpp | 29 +++ .../detail/refine_host_uint8_t_float.cpp | 30 +++ 10 files changed, 476 insertions(+), 227 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/refine_common.hpp create mode 100644 cpp/include/raft/neighbors/detail/refine_device.cuh create mode 100644 cpp/include/raft/neighbors/detail/refine_host-ext.hpp create mode 100644 cpp/include/raft/neighbors/detail/refine_host-inl.hpp create mode 100644 cpp/include/raft/neighbors/detail/refine_host.hpp create mode 100644 cpp/src/neighbors/detail/refine_host_float_float.cpp create mode 100644 cpp/src/neighbors/detail/refine_host_int8_t_float.cpp create mode 100644 cpp/src/neighbors/detail/refine_host_uint8_t_float.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 6fa1b5830e..353e423927 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -318,6 +318,9 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu + src/neighbors/detail/refine_host_float_float.cpp + src/neighbors/detail/refine_host_int8_t_float.cpp + src/neighbors/detail/refine_host_uint8_t_float.cpp src/neighbors/detail/selection_faiss_int32_t_float.cu src/neighbors/detail/selection_faiss_int_double.cu src/neighbors/detail/selection_faiss_long_float.cu @@ -434,6 +437,10 @@ if(RAFT_COMPILE_LIBRARY) # ensure CUDA symbols aren't relocated to the middle of the debug build binaries target_link_options(raft_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") + # Optimization: extra compile flags for individual compilation units + file(GLOB REFINE_HOST_SRC_FILES "src/neighbors/detail/refine_host_*.cpp") + set_property(SOURCE ${REFINE_HOST_SRC_FILES} PROPERTY COMPILE_FLAGS "-ftree-vectorize") + endif() if(TARGET raft_lib AND (NOT TARGET raft::raft_lib)) diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index 251f725361..170f973984 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * Copyright (c) 2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,231 +13,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include - -namespace raft::neighbors::detail { - -/** Checks whether the input data extents are compatible. */ -template -void check_input(extents_t dataset, - extents_t queries, - extents_t candidates, - extents_t indices, - extents_t distances, - distance::DistanceType metric) -{ - auto n_queries = queries.extent(0); - auto k = distances.extent(1); - - RAFT_EXPECTS(k <= raft::matrix::detail::select::warpsort::kMaxCapacity, - "k must be lest than topk::kMaxCapacity (%d).", - raft::matrix::detail::select::warpsort::kMaxCapacity); - - RAFT_EXPECTS(indices.extent(0) == n_queries && distances.extent(0) == n_queries && - candidates.extent(0) == n_queries, - "Number of rows in output indices, distances and candidates matrices must be equal" - " with the number of rows in search matrix. Expected %d, got %d, %d, and %d", - static_cast(n_queries), - static_cast(indices.extent(0)), - static_cast(distances.extent(0)), - static_cast(candidates.extent(0))); - - RAFT_EXPECTS(indices.extent(1) == k, - "Number of columns in output indices and distances matrices must be equal to k"); - - RAFT_EXPECTS(queries.extent(1) == dataset.extent(1), - "Number of columns must be equal for dataset and queries"); - - RAFT_EXPECTS(candidates.extent(1) >= k, - "Number of neighbor candidates must not be smaller than k (%d vs %d)", - static_cast(candidates.extent(1)), - static_cast(k)); -} - -/** - * See raft::neighbors::refine for docs. - */ -template -void refine_device(raft::resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded) -{ - matrix_idx n_candidates = neighbor_candidates.extent(1); - matrix_idx n_queries = queries.extent(0); - matrix_idx dim = dataset.extent(1); - uint32_t k = static_cast(indices.extent(1)); - - common::nvtx::range fun_scope( - "neighbors::refine(%zu, %u)", size_t(n_queries), uint32_t(n_candidates)); - - check_input(dataset.extents(), - queries.extents(), - neighbor_candidates.extents(), - indices.extents(), - distances.extents(), - metric); - - // The refinement search can be mapped to an IVF flat search: - // - We consider that the candidate vectors form a cluster, separately for each query. - // - In other words, the n_queries * n_candidates vectors form n_queries clusters, each with - // n_candidates elements. - // - We consider that the coarse level search is already performed and assigned a single cluster - // to search for each query (the cluster formed from the corresponding candidates). - // - We run IVF flat search with n_probes=1 to select the best k elements of the candidates. - rmm::device_uvector fake_coarse_idx(n_queries, resource::get_cuda_stream(handle)); - - thrust::sequence(resource::get_thrust_policy(handle), - fake_coarse_idx.data(), - fake_coarse_idx.data() + n_queries); - - raft::neighbors::ivf_flat::index refinement_index( - handle, metric, n_queries, false, true, dim); - - raft::neighbors::ivf_flat::detail::fill_refinement_index(handle, - &refinement_index, - dataset.data_handle(), - neighbor_candidates.data_handle(), - n_queries, - n_candidates); - uint32_t grid_dim_x = 1; - raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< - data_t, - typename raft::spatial::knn::detail::utils::config::value_t, - idx_t>(refinement_index, - queries.data_handle(), - fake_coarse_idx.data(), - static_cast(n_queries), - 0, - refinement_index.metric(), - 1, - k, - raft::distance::is_min_close(metric), - raft::neighbors::filtering::none_ivf_sample_filter(), - indices.data_handle(), - distances.data_handle(), - grid_dim_x, - resource::get_cuda_stream(handle)); -} - -/** Helper structure for naive CPU implementation of refine. */ -typedef struct { - uint64_t id; - float distance; -} struct_for_refinement; - -inline int _postprocessing_qsort_compare(const void* v1, const void* v2) -{ - // sort in ascending order - if (((struct_for_refinement*)v1)->distance > ((struct_for_refinement*)v2)->distance) { - return 1; - } else if (((struct_for_refinement*)v1)->distance < ((struct_for_refinement*)v2)->distance) { - return -1; - } else { - return 0; - } -} - -/** - * Naive CPU implementation of refine operation - * - * All pointers are expected to be accessible on the host. - */ -template -void refine_host(raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded) -{ - check_input(dataset.extents(), - queries.extents(), - neighbor_candidates.extents(), - indices.extents(), - distances.extents(), - metric); - - switch (metric) { - case raft::distance::DistanceType::L2Expanded: break; - case raft::distance::DistanceType::InnerProduct: break; - default: throw raft::logic_error("Unsopported metric"); - } - - size_t numDataset = dataset.extent(0); - size_t numQueries = queries.extent(0); - size_t dimDataset = dataset.extent(1); - const data_t* dataset_ptr = dataset.data_handle(); - const data_t* queries_ptr = queries.data_handle(); - const idx_t* neighbors = neighbor_candidates.data_handle(); - idx_t topK = neighbor_candidates.extent(1); - idx_t refinedTopK = indices.extent(1); - idx_t* refinedNeighbors = indices.data_handle(); - distance_t* refinedDistances = distances.data_handle(); - - common::nvtx::range fun_scope( - "neighbors::refine_host(%zu, %u)", size_t(numQueries), uint32_t(topK)); - -#pragma omp parallel - { - struct_for_refinement* sfr = - (struct_for_refinement*)malloc(sizeof(struct_for_refinement) * topK); - for (size_t i = omp_get_thread_num(); i < numQueries; i += omp_get_num_threads()) { - // compute distance with original dataset vectors - const data_t* cur_query = queries_ptr + ((uint64_t)dimDataset * i); - for (size_t j = 0; j < (size_t)topK; j++) { - idx_t id = neighbors[j + (topK * i)]; - const data_t* cur_dataset = dataset_ptr + ((uint64_t)dimDataset * id); - float distance = 0.0; - for (size_t k = 0; k < (size_t)dimDataset; k++) { - float val_q = (float)(cur_query[k]); - float val_d = (float)(cur_dataset[k]); - if (metric == raft::distance::DistanceType::InnerProduct) { - distance += -val_q * val_d; // Negate because we sort in ascending order. - } else { - distance += (val_q - val_d) * (val_q - val_d); - } - } - sfr[j].id = id; - sfr[j].distance = distance; - } - - qsort(sfr, topK, sizeof(struct_for_refinement), _postprocessing_qsort_compare); - - for (size_t j = 0; j < (size_t)refinedTopK; j++) { - refinedNeighbors[j + (refinedTopK * i)] = sfr[j].id; - if (refinedDistances == NULL) continue; - if (metric == raft::distance::DistanceType::InnerProduct) { - refinedDistances[j + (refinedTopK * i)] = -sfr[j].distance; - } else { - refinedDistances[j + (refinedTopK * i)] = sfr[j].distance; - } - } - } - free(sfr); - } -} - -} // namespace raft::neighbors::detail +#include "refine_device.cuh" +#include "refine_host.hpp" diff --git a/cpp/include/raft/neighbors/detail/refine_common.hpp b/cpp/include/raft/neighbors/detail/refine_common.hpp new file mode 100644 index 0000000000..bfd3341ee9 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/refine_common.hpp @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace raft::neighbors::detail { + +/** Checks whether the input data extents are compatible. */ +template +void refine_check_input(ExtentsT dataset, + ExtentsT queries, + ExtentsT candidates, + ExtentsT indices, + ExtentsT distances, + distance::DistanceType metric) +{ + auto n_queries = queries.extent(0); + auto k = distances.extent(1); + + RAFT_EXPECTS(indices.extent(0) == n_queries && distances.extent(0) == n_queries && + candidates.extent(0) == n_queries, + "Number of rows in output indices, distances and candidates matrices must be equal" + " with the number of rows in search matrix. Expected %d, got %d, %d, and %d", + static_cast(n_queries), + static_cast(indices.extent(0)), + static_cast(distances.extent(0)), + static_cast(candidates.extent(0))); + + RAFT_EXPECTS(indices.extent(1) == k, + "Number of columns in output indices and distances matrices must be equal to k"); + + RAFT_EXPECTS(queries.extent(1) == dataset.extent(1), + "Number of columns must be equal for dataset and queries"); + + RAFT_EXPECTS(candidates.extent(1) >= k, + "Number of neighbor candidates must not be smaller than k (%d vs %d)", + static_cast(candidates.extent(1)), + static_cast(k)); +} + +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/refine_device.cuh b/cpp/include/raft/neighbors/detail/refine_device.cuh new file mode 100644 index 0000000000..ac6d0fa2d6 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/refine_device.cuh @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace raft::neighbors::detail { + +/** + * See raft::neighbors::refine for docs. + */ +template +void refine_device(raft::resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ + matrix_idx n_candidates = neighbor_candidates.extent(1); + matrix_idx n_queries = queries.extent(0); + matrix_idx dim = dataset.extent(1); + uint32_t k = static_cast(indices.extent(1)); + + RAFT_EXPECTS(k <= raft::matrix::detail::select::warpsort::kMaxCapacity, + "k must be lest than topk::kMaxCapacity (%d).", + raft::matrix::detail::select::warpsort::kMaxCapacity); + + common::nvtx::range fun_scope( + "neighbors::refine(%zu, %u)", size_t(n_queries), uint32_t(n_candidates)); + + refine_check_input(dataset.extents(), + queries.extents(), + neighbor_candidates.extents(), + indices.extents(), + distances.extents(), + metric); + + // The refinement search can be mapped to an IVF flat search: + // - We consider that the candidate vectors form a cluster, separately for each query. + // - In other words, the n_queries * n_candidates vectors form n_queries clusters, each with + // n_candidates elements. + // - We consider that the coarse level search is already performed and assigned a single cluster + // to search for each query (the cluster formed from the corresponding candidates). + // - We run IVF flat search with n_probes=1 to select the best k elements of the candidates. + rmm::device_uvector fake_coarse_idx(n_queries, resource::get_cuda_stream(handle)); + + thrust::sequence(resource::get_thrust_policy(handle), + fake_coarse_idx.data(), + fake_coarse_idx.data() + n_queries); + + raft::neighbors::ivf_flat::index refinement_index( + handle, metric, n_queries, false, true, dim); + + raft::neighbors::ivf_flat::detail::fill_refinement_index(handle, + &refinement_index, + dataset.data_handle(), + neighbor_candidates.data_handle(), + n_queries, + n_candidates); + uint32_t grid_dim_x = 1; + raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< + data_t, + typename raft::spatial::knn::detail::utils::config::value_t, + idx_t>(refinement_index, + queries.data_handle(), + fake_coarse_idx.data(), + static_cast(n_queries), + 0, + refinement_index.metric(), + 1, + k, + raft::distance::is_min_close(metric), + raft::neighbors::filtering::none_ivf_sample_filter(), + indices.data_handle(), + distances.data_handle(), + grid_dim_x, + resource::get_cuda_stream(handle)); +} + +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/refine_host-ext.hpp b/cpp/include/raft/neighbors/detail/refine_host-ext.hpp new file mode 100644 index 0000000000..3ce2dc3eb5 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/refine_host-ext.hpp @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t + +#include // raft::host_matrix_view +#include // raft::distance::DistanceType +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::detail { + +template +[[gnu::optimize(3), gnu::optimize("tree-vectorize")]] void refine_host( + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) RAFT_EXPLICIT; + +} + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_refine(IdxT, DataT, DistanceT, ExtentsT) \ + extern template void raft::neighbors::detail::refine_host( \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); +instantiate_raft_neighbors_refine(int64_t, int8_t, float, int64_t); +instantiate_raft_neighbors_refine(int64_t, uint8_t, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/include/raft/neighbors/detail/refine_host-inl.hpp b/cpp/include/raft/neighbors/detail/refine_host-inl.hpp new file mode 100644 index 0000000000..cfedaa38d3 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/refine_host-inl.hpp @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include + +#include +#include + +namespace raft::neighbors::detail { + +template +[[gnu::optimize(3), gnu::optimize("tree-vectorize")]] void refine_host_impl( + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances) +{ + size_t n_queries = queries.extent(0); + size_t dim = dataset.extent(1); + size_t orig_k = neighbor_candidates.extent(1); + size_t refined_k = indices.extent(1); + + common::nvtx::range fun_scope( + "neighbors::refine_host(%zu, %zu -> %zu)", n_queries, orig_k, refined_k); + + auto suggested_n_threads = std::max(1, std::min(omp_get_num_procs(), omp_get_max_threads())); + if (size_t(suggested_n_threads) > n_queries) { suggested_n_threads = n_queries; } + +#pragma omp parallel num_threads(suggested_n_threads) + { + std::vector> refined_pairs(orig_k); + for (size_t i = omp_get_thread_num(); i < n_queries; i += omp_get_num_threads()) { + // Compute the refined distance using original dataset vectors + const DataT* query = queries.data_handle() + dim * i; + for (size_t j = 0; j < orig_k; j++) { + IdxT id = neighbor_candidates(i, j); + const DataT* row = dataset.data_handle() + dim * id; + DistanceT distance = 0.0; + for (size_t k = 0; k < dim; k++) { + distance += DC::template eval(query[k], row[k]); + } + refined_pairs[j] = std::make_tuple(distance, id); + } + // Sort the query neighbors by their refined distances + std::sort(refined_pairs.begin(), refined_pairs.end()); + // Store first refined_k neighbors + for (size_t j = 0; j < refined_k; j++) { + indices(i, j) = std::get<1>(refined_pairs[j]); + if (distances.data_handle() != nullptr) { + distances(i, j) = DC::template postprocess(std::get<0>(refined_pairs[j])); + } + } + } + } +} + +struct distance_comp_l2 { + template + static inline auto eval(const DistanceT& a, const DistanceT& b) -> DistanceT + { + auto d = a - b; + return d * d; + } + template + static inline auto postprocess(const DistanceT& a) -> DistanceT + { + return a; + } +}; + +struct distance_comp_inner { + template + static inline auto eval(const DistanceT& a, const DistanceT& b) -> DistanceT + { + return -a * b; + } + template + static inline auto postprocess(const DistanceT& a) -> DistanceT + { + return -a; + } +}; + +/** + * Naive CPU implementation of refine operation + * + * All pointers are expected to be accessible on the host. + */ +template +[[gnu::optimize(3), gnu::optimize("tree-vectorize")]] void refine_host( + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded) +{ + refine_check_input(dataset.extents(), + queries.extents(), + neighbor_candidates.extents(), + indices.extents(), + distances.extents(), + metric); + + switch (metric) { + case raft::distance::DistanceType::L2Expanded: + return refine_host_impl( + dataset, queries, neighbor_candidates, indices, distances); + case raft::distance::DistanceType::InnerProduct: + return refine_host_impl( + dataset, queries, neighbor_candidates, indices, distances); + default: throw raft::logic_error("Unsupported metric"); + } +} + +} // namespace raft::neighbors::detail diff --git a/cpp/include/raft/neighbors/detail/refine_host.hpp b/cpp/include/raft/neighbors/detail/refine_host.hpp new file mode 100644 index 0000000000..ff0de75660 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/refine_host.hpp @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#ifndef RAFT_EXPLICIT_INSTANTIATE_ONLY +#include "refine_host-inl.hpp" +#endif + +#ifdef RAFT_COMPILED +#include "refine_host-ext.hpp" +#endif diff --git a/cpp/src/neighbors/detail/refine_host_float_float.cpp b/cpp/src/neighbors/detail/refine_host_float_float.cpp new file mode 100644 index 0000000000..c596200c0a --- /dev/null +++ b/cpp/src/neighbors/detail/refine_host_float_float.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include + +#define instantiate_raft_neighbors_refine(IdxT, DataT, DistanceT, ExtentsT) \ + template void raft::neighbors::detail::refine_host( \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/src/neighbors/detail/refine_host_int8_t_float.cpp b/cpp/src/neighbors/detail/refine_host_int8_t_float.cpp new file mode 100644 index 0000000000..334a3e8cb6 --- /dev/null +++ b/cpp/src/neighbors/detail/refine_host_int8_t_float.cpp @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_refine(IdxT, DataT, DistanceT, ExtentsT) \ + template void raft::neighbors::detail::refine_host( \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + distance::DistanceType metric); +instantiate_raft_neighbors_refine(int64_t, int8_t, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/src/neighbors/detail/refine_host_uint8_t_float.cpp b/cpp/src/neighbors/detail/refine_host_uint8_t_float.cpp new file mode 100644 index 0000000000..43d93e5f2e --- /dev/null +++ b/cpp/src/neighbors/detail/refine_host_uint8_t_float.cpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_refine(IdxT, DataT, DistanceT, ExtentsT) \ + template void raft::neighbors::detail::refine_host( \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, uint8_t, float, int64_t); + +#undef instantiate_raft_neighbors_refine From dfac331a52a252a82eb936d009bedb80cc42bfab Mon Sep 17 00:00:00 2001 From: achirkin Date: Tue, 18 Jul 2023 18:18:30 +0200 Subject: [PATCH 2/3] Fix missing includes --- cpp/include/raft/neighbors/detail/refine_device.cuh | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/include/raft/neighbors/detail/refine_device.cuh b/cpp/include/raft/neighbors/detail/refine_device.cuh index ac6d0fa2d6..6ee96957fa 100644 --- a/cpp/include/raft/neighbors/detail/refine_device.cuh +++ b/cpp/include/raft/neighbors/detail/refine_device.cuh @@ -26,6 +26,7 @@ #include #include #include +#include #include #include From bd895b188f7a5f59fea40a732622bf1c4622bb1a Mon Sep 17 00:00:00 2001 From: "Artem M. Chirkin" <9253178+achirkin@users.noreply.github.com> Date: Wed, 19 Jul 2023 15:46:19 +0200 Subject: [PATCH 3/3] Update cpp/CMakeLists.txt --- cpp/CMakeLists.txt | 4 ---- 1 file changed, 4 deletions(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 73fc80f02e..81c8192fb2 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -461,10 +461,6 @@ if(RAFT_COMPILE_LIBRARY) # ensure CUDA symbols aren't relocated to the middle of the debug build binaries target_link_options(raft_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") - # Optimization: extra compile flags for individual compilation units - file(GLOB REFINE_HOST_SRC_FILES "src/neighbors/detail/refine_host_*.cpp") - set_property(SOURCE ${REFINE_HOST_SRC_FILES} PROPERTY COMPILE_FLAGS "-ftree-vectorize") - endif() if(TARGET raft_lib AND (NOT TARGET raft::raft_lib))