From d9801e87f266554c2c561490faf0fe5d76e66165 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 15:13:22 +0200 Subject: [PATCH 01/32] MV: add -inl suffix to header paths --- .../detail/pairwise_matrix/{dispatch.cuh => dispatch-inl.cuh} | 0 cpp/include/raft/distance/{distance.cuh => distance-inl.cuh} | 0 .../raft/distance/{fused_l2_nn.cuh => fused_l2_nn-inl.cuh} | 0 .../{coalesced_reduction.cuh => coalesced_reduction-inl.cuh} | 2 +- .../raft/matrix/detail/{select_k.cuh => select_k-inl.cuh} | 0 .../raft/neighbors/{ball_cover.cuh => ball_cover-inl.cuh} | 0 .../raft/neighbors/{brute_force.cuh => brute_force-inl.cuh} | 0 .../detail/{ivf_flat_search.cuh => ivf_flat_search-inl.cuh} | 0 .../detail/{selection_faiss.cuh => selection_faiss-inl.cuh} | 0 cpp/include/raft/neighbors/{ivf_flat.cuh => ivf_flat-inl.cuh} | 0 cpp/include/raft/neighbors/{ivf_pq.cuh => ivf_pq-inl.cuh} | 0 cpp/include/raft/neighbors/{refine.cuh => refine-inl.cuh} | 0 .../knn/detail/ball_cover/{registers.cuh => registers-inl.cuh} | 0 .../knn/detail/{fused_l2_knn.cuh => fused_l2_knn-inl.cuh} | 0 14 files changed, 1 insertion(+), 1 deletion(-) rename cpp/include/raft/distance/detail/pairwise_matrix/{dispatch.cuh => dispatch-inl.cuh} (100%) rename cpp/include/raft/distance/{distance.cuh => distance-inl.cuh} (100%) rename cpp/include/raft/distance/{fused_l2_nn.cuh => fused_l2_nn-inl.cuh} (100%) rename cpp/include/raft/linalg/detail/{coalesced_reduction.cuh => coalesced_reduction-inl.cuh} (99%) rename cpp/include/raft/matrix/detail/{select_k.cuh => select_k-inl.cuh} (100%) rename cpp/include/raft/neighbors/{ball_cover.cuh => ball_cover-inl.cuh} (100%) rename cpp/include/raft/neighbors/{brute_force.cuh => brute_force-inl.cuh} (100%) rename cpp/include/raft/neighbors/detail/{ivf_flat_search.cuh => ivf_flat_search-inl.cuh} (100%) rename cpp/include/raft/neighbors/detail/{selection_faiss.cuh => selection_faiss-inl.cuh} (100%) rename cpp/include/raft/neighbors/{ivf_flat.cuh => ivf_flat-inl.cuh} (100%) rename cpp/include/raft/neighbors/{ivf_pq.cuh => ivf_pq-inl.cuh} (100%) rename cpp/include/raft/neighbors/{refine.cuh => refine-inl.cuh} (100%) rename cpp/include/raft/spatial/knn/detail/ball_cover/{registers.cuh => registers-inl.cuh} (100%) rename cpp/include/raft/spatial/knn/detail/{fused_l2_knn.cuh => fused_l2_knn-inl.cuh} (100%) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh similarity index 100% rename from cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh rename to cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance-inl.cuh similarity index 100% rename from cpp/include/raft/distance/distance.cuh rename to cpp/include/raft/distance/distance-inl.cuh diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn-inl.cuh similarity index 100% rename from cpp/include/raft/distance/fused_l2_nn.cuh rename to cpp/include/raft/distance/fused_l2_nn-inl.cuh diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh similarity index 99% rename from cpp/include/raft/linalg/detail/coalesced_reduction.cuh rename to cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh index 238e17fa56..4cc19d4a17 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/cpp/include/raft/matrix/detail/select_k.cuh b/cpp/include/raft/matrix/detail/select_k-inl.cuh similarity index 100% rename from cpp/include/raft/matrix/detail/select_k.cuh rename to cpp/include/raft/matrix/detail/select_k-inl.cuh diff --git a/cpp/include/raft/neighbors/ball_cover.cuh b/cpp/include/raft/neighbors/ball_cover-inl.cuh similarity index 100% rename from cpp/include/raft/neighbors/ball_cover.cuh rename to cpp/include/raft/neighbors/ball_cover-inl.cuh diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force-inl.cuh similarity index 100% rename from cpp/include/raft/neighbors/brute_force.cuh rename to cpp/include/raft/neighbors/brute_force-inl.cuh diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh similarity index 100% rename from cpp/include/raft/neighbors/detail/ivf_flat_search.cuh rename to cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh diff --git a/cpp/include/raft/neighbors/detail/selection_faiss.cuh b/cpp/include/raft/neighbors/detail/selection_faiss-inl.cuh similarity index 100% rename from cpp/include/raft/neighbors/detail/selection_faiss.cuh rename to cpp/include/raft/neighbors/detail/selection_faiss-inl.cuh diff --git a/cpp/include/raft/neighbors/ivf_flat.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh similarity index 100% rename from cpp/include/raft/neighbors/ivf_flat.cuh rename to cpp/include/raft/neighbors/ivf_flat-inl.cuh diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq-inl.cuh similarity index 100% rename from cpp/include/raft/neighbors/ivf_pq.cuh rename to cpp/include/raft/neighbors/ivf_pq-inl.cuh diff --git a/cpp/include/raft/neighbors/refine.cuh b/cpp/include/raft/neighbors/refine-inl.cuh similarity index 100% rename from cpp/include/raft/neighbors/refine.cuh rename to cpp/include/raft/neighbors/refine-inl.cuh diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh similarity index 100% rename from cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh rename to cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh similarity index 100% rename from cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh rename to cpp/include/raft/spatial/knn/detail/fused_l2_knn-inl.cuh From 50b374da5c579d7d160c15ebc0d2b1db105798eb Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 13 Apr 2023 16:02:40 +0200 Subject: [PATCH 02/32] MV: raft_runtime src files Instead of having the specializations in sub-directories, the raft_runtime source files now mimic the include/ directory hierarchy. --- cpp/CMakeLists.txt | 72 +++++++++---------- .../cluster/cluster_cost.cuh | 0 .../cluster/cluster_cost_double.cu | 0 .../cluster/cluster_cost_float.cu | 0 .../cluster/kmeans_fit_double.cu | 0 .../cluster/kmeans_fit_float.cu | 0 .../cluster/kmeans_init_plus_plus_double.cu | 0 .../cluster/kmeans_init_plus_plus_float.cu | 0 .../cluster/update_centroids.cuh | 0 .../cluster/update_centroids_double.cu | 0 .../cluster/update_centroids_float.cu | 0 .../distance/fused_l2_min_arg.cu | 0 .../distance/pairwise_distance.cu | 0 .../matrix/select_k_float_int64_t.cu | 0 .../brute_force_knn_int64_t_float.cu | 0 .../neighbors/ivf_flat_build.cu | 0 .../neighbors/ivf_flat_search.cu | 0 .../neighbors/ivfpq_build.cu | 0 .../neighbors/ivfpq_deserialize.cu | 0 .../neighbors/ivfpq_search_float_int64_t.cu | 0 .../neighbors/ivfpq_search_int8_t_int64_t.cu | 0 .../neighbors/ivfpq_search_uint8_t_int64_t.cu | 0 .../neighbors/ivfpq_serialize.cu | 0 .../neighbors/refine_d_int64_t_float.cu | 0 .../neighbors/refine_d_int64_t_int8_t.cu | 0 .../neighbors/refine_d_int64_t_uint8_t.cu | 0 .../neighbors/refine_h_int64_t_float.cu | 0 .../neighbors/refine_h_int64_t_int8_t.cu | 0 .../neighbors/refine_h_int64_t_uint8_t.cu | 0 cpp/src/{ => raft_runtime}/random/common.cuh | 0 ...rmat_rectangular_generator_int64_double.cu | 0 .../rmat_rectangular_generator_int64_float.cu | 0 .../rmat_rectangular_generator_int_double.cu | 0 .../rmat_rectangular_generator_int_float.cu | 0 34 files changed, 32 insertions(+), 40 deletions(-) rename cpp/src/{ => raft_runtime}/cluster/cluster_cost.cuh (100%) rename cpp/src/{ => raft_runtime}/cluster/cluster_cost_double.cu (100%) rename cpp/src/{ => raft_runtime}/cluster/cluster_cost_float.cu (100%) rename cpp/src/{ => raft_runtime}/cluster/kmeans_fit_double.cu (100%) rename cpp/src/{ => raft_runtime}/cluster/kmeans_fit_float.cu (100%) rename cpp/src/{ => raft_runtime}/cluster/kmeans_init_plus_plus_double.cu (100%) rename cpp/src/{ => raft_runtime}/cluster/kmeans_init_plus_plus_float.cu (100%) rename cpp/src/{ => raft_runtime}/cluster/update_centroids.cuh (100%) rename cpp/src/{ => raft_runtime}/cluster/update_centroids_double.cu (100%) rename cpp/src/{ => raft_runtime}/cluster/update_centroids_float.cu (100%) rename cpp/src/{ => raft_runtime}/distance/fused_l2_min_arg.cu (100%) rename cpp/src/{ => raft_runtime}/distance/pairwise_distance.cu (100%) rename cpp/src/{ => raft_runtime}/matrix/select_k_float_int64_t.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/brute_force_knn_int64_t_float.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/ivf_flat_build.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/ivf_flat_search.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/ivfpq_build.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/ivfpq_deserialize.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/ivfpq_search_float_int64_t.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/ivfpq_search_int8_t_int64_t.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/ivfpq_search_uint8_t_int64_t.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/ivfpq_serialize.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/refine_d_int64_t_float.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/refine_d_int64_t_int8_t.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/refine_d_int64_t_uint8_t.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/refine_h_int64_t_float.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/refine_h_int64_t_int8_t.cu (100%) rename cpp/src/{ => raft_runtime}/neighbors/refine_h_int64_t_uint8_t.cu (100%) rename cpp/src/{ => raft_runtime}/random/common.cuh (100%) rename cpp/src/{ => raft_runtime}/random/rmat_rectangular_generator_int64_double.cu (100%) rename cpp/src/{ => raft_runtime}/random/rmat_rectangular_generator_int64_float.cu (100%) rename cpp/src/{ => raft_runtime}/random/rmat_rectangular_generator_int_double.cu (100%) rename cpp/src/{ => raft_runtime}/random/rmat_rectangular_generator_int_float.cu (100%) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 62f9ac604e..1620a55053 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -263,28 +263,12 @@ set_target_properties(raft_compiled PROPERTIES EXPORT_NAME compiled) if(RAFT_COMPILE_LIBRARY) add_library( raft_lib - src/distance/pairwise_distance.cu - src/distance/fused_l2_min_arg.cu - src/cluster/update_centroids_float.cu - src/cluster/update_centroids_double.cu - src/cluster/cluster_cost_float.cu - src/cluster/cluster_cost_double.cu - src/neighbors/refine_d_int64_t_float.cu - src/neighbors/refine_d_int64_t_int8_t.cu - src/neighbors/refine_d_int64_t_uint8_t.cu - src/neighbors/refine_h_int64_t_float.cu - src/neighbors/refine_h_int64_t_int8_t.cu - src/neighbors/refine_h_int64_t_uint8_t.cu src/neighbors/specializations/refine_d_int64_t_float.cu src/neighbors/specializations/refine_d_int64_t_int8_t.cu src/neighbors/specializations/refine_d_int64_t_uint8_t.cu src/neighbors/specializations/refine_h_int64_t_float.cu src/neighbors/specializations/refine_h_int64_t_int8_t.cu src/neighbors/specializations/refine_h_int64_t_uint8_t.cu - src/cluster/kmeans_fit_float.cu - src/cluster/kmeans_fit_double.cu - src/cluster/kmeans_init_plus_plus_double.cu - src/cluster/kmeans_init_plus_plus_float.cu src/distance/specializations/detail/canberra_double_double_double_int.cu src/distance/specializations/detail/canberra_float_float_float_int.cu src/distance/specializations/detail/correlation_double_double_double_int.cu @@ -306,7 +290,6 @@ if(RAFT_COMPILE_LIBRARY) # These are somehow missing a kernel definition which is causing a compile error. # src/distance/specializations/detail/kernels/rbf_kernel_double.cu # src/distance/specializations/detail/kernels/rbf_kernel_float.cu - src/neighbors/brute_force_knn_int64_t_float.cu src/distance/specializations/detail/kernels/tanh_kernel_double.cu src/distance/specializations/detail/kernels/tanh_kernel_float.cu src/distance/specializations/detail/kl_divergence_float_float_float_int.cu @@ -327,17 +310,10 @@ if(RAFT_COMPILE_LIBRARY) src/distance/specializations/fused_l2_nn_double_int64.cu src/distance/specializations/fused_l2_nn_float_int.cu src/distance/specializations/fused_l2_nn_float_int64.cu - src/matrix/select_k_float_int64_t.cu src/matrix/specializations/detail/select_k_float_uint32_t.cu src/matrix/specializations/detail/select_k_float_int64_t.cu src/matrix/specializations/detail/select_k_half_uint32_t.cu src/matrix/specializations/detail/select_k_half_int64_t.cu - src/neighbors/ivfpq_build.cu - src/neighbors/ivfpq_deserialize.cu - src/neighbors/ivfpq_serialize.cu - src/neighbors/ivfpq_search_float_int64_t.cu - src/neighbors/ivfpq_search_int8_t_int64_t.cu - src/neighbors/ivfpq_search_uint8_t_int64_t.cu src/neighbors/specializations/ivfpq_build_float_int64_t.cu src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu @@ -372,10 +348,6 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu - src/random/rmat_rectangular_generator_int_double.cu - src/random/rmat_rectangular_generator_int64_double.cu - src/random/rmat_rectangular_generator_int_float.cu - src/random/rmat_rectangular_generator_int64_float.cu src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu @@ -387,8 +359,6 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/specializations/fused_l2_knn_long_float_false.cu src/neighbors/specializations/fused_l2_knn_int_float_true.cu src/neighbors/specializations/fused_l2_knn_int_float_false.cu - src/neighbors/ivf_flat_search.cu - src/neighbors/ivf_flat_build.cu src/neighbors/specializations/ivfflat_build_float_int64_t.cu src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu @@ -398,12 +368,6 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/specializations/ivfflat_search_float_int64_t.cu src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu - src/neighbors/ivfpq_build.cu - src/neighbors/ivfpq_deserialize.cu - src/neighbors/ivfpq_serialize.cu - src/neighbors/ivfpq_search_float_int64_t.cu - src/neighbors/ivfpq_search_int8_t_int64_t.cu - src/neighbors/ivfpq_search_uint8_t_int64_t.cu src/neighbors/specializations/ivfpq_build_float_int64_t.cu src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu @@ -434,10 +398,38 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu - src/random/rmat_rectangular_generator_int_double.cu - src/random/rmat_rectangular_generator_int64_double.cu - src/random/rmat_rectangular_generator_int_float.cu - src/random/rmat_rectangular_generator_int64_float.cu + src/raft_runtime/cluster/cluster_cost.cuh + src/raft_runtime/cluster/cluster_cost_double.cu + src/raft_runtime/cluster/cluster_cost_float.cu + src/raft_runtime/cluster/kmeans_fit_double.cu + src/raft_runtime/cluster/kmeans_fit_float.cu + src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu + src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu + src/raft_runtime/cluster/update_centroids.cuh + src/raft_runtime/cluster/update_centroids_double.cu + src/raft_runtime/cluster/update_centroids_float.cu + src/raft_runtime/distance/fused_l2_min_arg.cu + src/raft_runtime/distance/pairwise_distance.cu + src/raft_runtime/matrix/select_k_float_int64_t.cu + src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu + src/raft_runtime/neighbors/ivf_flat_build.cu + src/raft_runtime/neighbors/ivf_flat_search.cu + src/raft_runtime/neighbors/ivfpq_build.cu + src/raft_runtime/neighbors/ivfpq_deserialize.cu + src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu + src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu + src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu + src/raft_runtime/neighbors/ivfpq_serialize.cu + src/raft_runtime/neighbors/refine_d_int64_t_float.cu + src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu + src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu + src/raft_runtime/neighbors/refine_h_int64_t_float.cu + src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu + src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu + src/raft_runtime/random/rmat_rectangular_generator_int64_double.cu + src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu + src/raft_runtime/random/rmat_rectangular_generator_int_double.cu + src/raft_runtime/random/rmat_rectangular_generator_int_float.cu ) set_target_properties( raft_lib diff --git a/cpp/src/cluster/cluster_cost.cuh b/cpp/src/raft_runtime/cluster/cluster_cost.cuh similarity index 100% rename from cpp/src/cluster/cluster_cost.cuh rename to cpp/src/raft_runtime/cluster/cluster_cost.cuh diff --git a/cpp/src/cluster/cluster_cost_double.cu b/cpp/src/raft_runtime/cluster/cluster_cost_double.cu similarity index 100% rename from cpp/src/cluster/cluster_cost_double.cu rename to cpp/src/raft_runtime/cluster/cluster_cost_double.cu diff --git a/cpp/src/cluster/cluster_cost_float.cu b/cpp/src/raft_runtime/cluster/cluster_cost_float.cu similarity index 100% rename from cpp/src/cluster/cluster_cost_float.cu rename to cpp/src/raft_runtime/cluster/cluster_cost_float.cu diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/raft_runtime/cluster/kmeans_fit_double.cu similarity index 100% rename from cpp/src/cluster/kmeans_fit_double.cu rename to cpp/src/raft_runtime/cluster/kmeans_fit_double.cu diff --git a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/raft_runtime/cluster/kmeans_fit_float.cu similarity index 100% rename from cpp/src/cluster/kmeans_fit_float.cu rename to cpp/src/raft_runtime/cluster/kmeans_fit_float.cu diff --git a/cpp/src/cluster/kmeans_init_plus_plus_double.cu b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu similarity index 100% rename from cpp/src/cluster/kmeans_init_plus_plus_double.cu rename to cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu diff --git a/cpp/src/cluster/kmeans_init_plus_plus_float.cu b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu similarity index 100% rename from cpp/src/cluster/kmeans_init_plus_plus_float.cu rename to cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu diff --git a/cpp/src/cluster/update_centroids.cuh b/cpp/src/raft_runtime/cluster/update_centroids.cuh similarity index 100% rename from cpp/src/cluster/update_centroids.cuh rename to cpp/src/raft_runtime/cluster/update_centroids.cuh diff --git a/cpp/src/cluster/update_centroids_double.cu b/cpp/src/raft_runtime/cluster/update_centroids_double.cu similarity index 100% rename from cpp/src/cluster/update_centroids_double.cu rename to cpp/src/raft_runtime/cluster/update_centroids_double.cu diff --git a/cpp/src/cluster/update_centroids_float.cu b/cpp/src/raft_runtime/cluster/update_centroids_float.cu similarity index 100% rename from cpp/src/cluster/update_centroids_float.cu rename to cpp/src/raft_runtime/cluster/update_centroids_float.cu diff --git a/cpp/src/distance/fused_l2_min_arg.cu b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu similarity index 100% rename from cpp/src/distance/fused_l2_min_arg.cu rename to cpp/src/raft_runtime/distance/fused_l2_min_arg.cu diff --git a/cpp/src/distance/pairwise_distance.cu b/cpp/src/raft_runtime/distance/pairwise_distance.cu similarity index 100% rename from cpp/src/distance/pairwise_distance.cu rename to cpp/src/raft_runtime/distance/pairwise_distance.cu diff --git a/cpp/src/matrix/select_k_float_int64_t.cu b/cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu similarity index 100% rename from cpp/src/matrix/select_k_float_int64_t.cu rename to cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float.cu b/cpp/src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu similarity index 100% rename from cpp/src/neighbors/brute_force_knn_int64_t_float.cu rename to cpp/src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu diff --git a/cpp/src/neighbors/ivf_flat_build.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu similarity index 100% rename from cpp/src/neighbors/ivf_flat_build.cu rename to cpp/src/raft_runtime/neighbors/ivf_flat_build.cu diff --git a/cpp/src/neighbors/ivf_flat_search.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_search.cu similarity index 100% rename from cpp/src/neighbors/ivf_flat_search.cu rename to cpp/src/raft_runtime/neighbors/ivf_flat_search.cu diff --git a/cpp/src/neighbors/ivfpq_build.cu b/cpp/src/raft_runtime/neighbors/ivfpq_build.cu similarity index 100% rename from cpp/src/neighbors/ivfpq_build.cu rename to cpp/src/raft_runtime/neighbors/ivfpq_build.cu diff --git a/cpp/src/neighbors/ivfpq_deserialize.cu b/cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu similarity index 100% rename from cpp/src/neighbors/ivfpq_deserialize.cu rename to cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu diff --git a/cpp/src/neighbors/ivfpq_search_float_int64_t.cu b/cpp/src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu similarity index 100% rename from cpp/src/neighbors/ivfpq_search_float_int64_t.cu rename to cpp/src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu diff --git a/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu b/cpp/src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu similarity index 100% rename from cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu rename to cpp/src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu diff --git a/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu similarity index 100% rename from cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu rename to cpp/src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu diff --git a/cpp/src/neighbors/ivfpq_serialize.cu b/cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu similarity index 100% rename from cpp/src/neighbors/ivfpq_serialize.cu rename to cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu diff --git a/cpp/src/neighbors/refine_d_int64_t_float.cu b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_float.cu similarity index 100% rename from cpp/src/neighbors/refine_d_int64_t_float.cu rename to cpp/src/raft_runtime/neighbors/refine_d_int64_t_float.cu diff --git a/cpp/src/neighbors/refine_d_int64_t_int8_t.cu b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu similarity index 100% rename from cpp/src/neighbors/refine_d_int64_t_int8_t.cu rename to cpp/src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu diff --git a/cpp/src/neighbors/refine_d_int64_t_uint8_t.cu b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu similarity index 100% rename from cpp/src/neighbors/refine_d_int64_t_uint8_t.cu rename to cpp/src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu diff --git a/cpp/src/neighbors/refine_h_int64_t_float.cu b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_float.cu similarity index 100% rename from cpp/src/neighbors/refine_h_int64_t_float.cu rename to cpp/src/raft_runtime/neighbors/refine_h_int64_t_float.cu diff --git a/cpp/src/neighbors/refine_h_int64_t_int8_t.cu b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu similarity index 100% rename from cpp/src/neighbors/refine_h_int64_t_int8_t.cu rename to cpp/src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu diff --git a/cpp/src/neighbors/refine_h_int64_t_uint8_t.cu b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu similarity index 100% rename from cpp/src/neighbors/refine_h_int64_t_uint8_t.cu rename to cpp/src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu diff --git a/cpp/src/random/common.cuh b/cpp/src/raft_runtime/random/common.cuh similarity index 100% rename from cpp/src/random/common.cuh rename to cpp/src/raft_runtime/random/common.cuh diff --git a/cpp/src/random/rmat_rectangular_generator_int64_double.cu b/cpp/src/raft_runtime/random/rmat_rectangular_generator_int64_double.cu similarity index 100% rename from cpp/src/random/rmat_rectangular_generator_int64_double.cu rename to cpp/src/raft_runtime/random/rmat_rectangular_generator_int64_double.cu diff --git a/cpp/src/random/rmat_rectangular_generator_int64_float.cu b/cpp/src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu similarity index 100% rename from cpp/src/random/rmat_rectangular_generator_int64_float.cu rename to cpp/src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu diff --git a/cpp/src/random/rmat_rectangular_generator_int_double.cu b/cpp/src/raft_runtime/random/rmat_rectangular_generator_int_double.cu similarity index 100% rename from cpp/src/random/rmat_rectangular_generator_int_double.cu rename to cpp/src/raft_runtime/random/rmat_rectangular_generator_int_double.cu diff --git a/cpp/src/random/rmat_rectangular_generator_int_float.cu b/cpp/src/raft_runtime/random/rmat_rectangular_generator_int_float.cu similarity index 100% rename from cpp/src/random/rmat_rectangular_generator_int_float.cu rename to cpp/src/raft_runtime/random/rmat_rectangular_generator_int_float.cu From 8974ae32255e702de9d707da34b353b61bedbc23 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 13:41:56 +0200 Subject: [PATCH 03/32] FIX: add missing includes --- cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h | 1 + cpp/bench/prims/distance/fused_l2_nn.cu | 1 + cpp/include/raft/cluster/detail/kmeans_common.cuh | 1 + cpp/include/raft/core/mdarray.hpp | 1 + cpp/include/raft/core/resource/device_memory_resource.hpp | 3 ++- cpp/include/raft/core/resources.hpp | 3 ++- cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh | 4 ++-- cpp/include/raft/matrix/detail/select_warpsort.cuh | 2 +- cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh | 1 + cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh | 1 + cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 1 + cpp/include/raft/neighbors/detail/refine.cuh | 1 + cpp/include/raft/spatial/knn/detail/ann_utils.cuh | 1 - cpp/internal/raft_internal/matrix/select_k.cuh | 1 + cpp/src/raft_runtime/distance/fused_l2_min_arg.cu | 3 ++- cpp/src/raft_runtime/neighbors/ivf_flat_build.cu | 1 + cpp/src/raft_runtime/neighbors/ivf_flat_search.cu | 1 + cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu | 2 +- cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu | 1 + cpp/test/core/handle.cpp | 1 + cpp/test/linalg/eigen_solvers.cu | 3 ++- cpp/test/neighbors/ann_ivf_pq.cuh | 1 + cpp/test/neighbors/ann_utils.cuh | 1 + cpp/test/neighbors/tiled_knn.cu | 3 +++ 24 files changed, 30 insertions(+), 9 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 8b2a7d329b..0a80eef1b5 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/bench/prims/distance/fused_l2_nn.cu b/cpp/bench/prims/distance/fused_l2_nn.cu index 1c45572782..a5115407dd 100644 --- a/cpp/bench/prims/distance/fused_l2_nn.cu +++ b/cpp/bench/prims/distance/fused_l2_nn.cu @@ -16,6 +16,7 @@ #include #include +#include #include #if defined RAFT_COMPILED #include diff --git a/cpp/include/raft/cluster/detail/kmeans_common.cuh b/cpp/include/raft/cluster/detail/kmeans_common.cuh index 76fc22e99e..cca1cbb6e9 100644 --- a/cpp/include/raft/cluster/detail/kmeans_common.cuh +++ b/cpp/include/raft/cluster/detail/kmeans_common.cuh @@ -38,6 +38,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/core/mdarray.hpp b/cpp/include/raft/core/mdarray.hpp index 88f90485dd..467a67f786 100644 --- a/cpp/include/raft/core/mdarray.hpp +++ b/cpp/include/raft/core/mdarray.hpp @@ -25,6 +25,7 @@ #include #include +#include #include #include #include diff --git a/cpp/include/raft/core/resource/device_memory_resource.hpp b/cpp/include/raft/core/resource/device_memory_resource.hpp index 35ae3d715f..ebc41e0f8e 100644 --- a/cpp/include/raft/core/resource/device_memory_resource.hpp +++ b/cpp/include/raft/core/resource/device_memory_resource.hpp @@ -18,6 +18,7 @@ #include #include #include +#include namespace raft::resource { class device_memory_resource : public resource { @@ -72,4 +73,4 @@ inline void set_workspace_resource(resources const& res, rmm::mr::device_memory_ { res.add_resource_factory(std::make_shared(mr)); }; -} // namespace raft::resource \ No newline at end of file +} // namespace raft::resource diff --git a/cpp/include/raft/core/resources.hpp b/cpp/include/raft/core/resources.hpp index 64e281e934..49836ee962 100644 --- a/cpp/include/raft/core/resources.hpp +++ b/cpp/include/raft/core/resources.hpp @@ -18,6 +18,7 @@ #include "resource/resource_types.hpp" #include #include +#include // RAFT_EXPECTS #include #include #include @@ -128,4 +129,4 @@ class resources { mutable std::vector factories_; mutable std::vector resources_; }; -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh index 4cc19d4a17..5b01196cf4 100644 --- a/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh +++ b/cpp/include/raft/linalg/detail/coalesced_reduction-inl.cuh @@ -17,7 +17,7 @@ #pragma once #include -#include +#include #include #include #include @@ -365,4 +365,4 @@ void coalescedReduction(OutType* dots, } // namespace detail } // namespace linalg -} // namespace raft \ No newline at end of file +} // namespace raft diff --git a/cpp/include/raft/matrix/detail/select_warpsort.cuh b/cpp/include/raft/matrix/detail/select_warpsort.cuh index d362b73792..5f3d0e6bc7 100644 --- a/cpp/include/raft/matrix/detail/select_warpsort.cuh +++ b/cpp/include/raft/matrix/detail/select_warpsort.cuh @@ -27,7 +27,7 @@ #include #include -#include +#include #include /* diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index e6533eaf51..dcd7c4f9b7 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh index 1bb7f97123..bec3b890eb 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh @@ -21,6 +21,7 @@ #include #include #include +#include #include diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index a776ce2586..0148a1a887 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -36,6 +36,7 @@ #include #include #include +#include #include #include diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index aedfc42698..722a4b15f9 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -20,6 +20,7 @@ #include #include #include +#include #include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index 395714a161..d8fe216a85 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -17,7 +17,6 @@ #pragma once #include -#include #include #include #include diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index a3535f8ffd..cd19ad9781 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include diff --git a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu index b682446cc2..487c7e3a4a 100644 --- a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu @@ -20,6 +20,7 @@ #include #include #include +#include #include #include @@ -95,4 +96,4 @@ void fused_l2_nn_min_arg(raft::device_resources const& handle, compute_fused_l2_nn_min_arg(handle, min, x, y, m, n, k, sqrt); } -} // end namespace raft::runtime::distance \ No newline at end of file +} // end namespace raft::runtime::distance diff --git a/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu index 0d82fdbb08..79743662bd 100644 --- a/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu +++ b/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include diff --git a/cpp/src/raft_runtime/neighbors/ivf_flat_search.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_search.cu index b843ee7c30..27f9ab0c5b 100644 --- a/cpp/src/raft_runtime/neighbors/ivf_flat_search.cu +++ b/cpp/src/raft_runtime/neighbors/ivf_flat_search.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu b/cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu index 8d54e3cc55..76a1839084 100644 --- a/cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu @@ -15,8 +15,8 @@ */ #include +#include #include - #include namespace raft::runtime::neighbors::ivf_pq { diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu b/cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu index e251f1442f..7a10fc6caa 100644 --- a/cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu @@ -15,6 +15,7 @@ */ #include +#include #include #include diff --git a/cpp/test/core/handle.cpp b/cpp/test/core/handle.cpp index 9f416d3ae8..fddfd58bb8 100644 --- a/cpp/test/core/handle.cpp +++ b/cpp/test/core/handle.cpp @@ -22,6 +22,7 @@ #include #include #include +#include #include namespace raft { diff --git a/cpp/test/linalg/eigen_solvers.cu b/cpp/test/linalg/eigen_solvers.cu index 1f29d7e275..ca34b0c3a4 100644 --- a/cpp/test/linalg/eigen_solvers.cu +++ b/cpp/test/linalg/eigen_solvers.cu @@ -14,8 +14,8 @@ * limitations under the License. */ -#include #include +#include #include #include @@ -24,6 +24,7 @@ #include #include #include +#include namespace raft { namespace spectral { diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 07efcb099e..95f08ad3d1 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -27,6 +27,7 @@ #include #include #include +#include #include #ifdef RAFT_COMPILED #include diff --git a/cpp/test/neighbors/ann_utils.cuh b/cpp/test/neighbors/ann_utils.cuh index fc448f014f..438c56da21 100644 --- a/cpp/test/neighbors/ann_utils.cuh +++ b/cpp/test/neighbors/ann_utils.cuh @@ -16,6 +16,7 @@ #pragma once +#include // raft::make_device_matrix #include #include #include diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index ccc3a64edd..03ddba1c64 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -20,10 +20,13 @@ #include #include +#include // raft::distance::pairwise_distance #include #include #include #include +#include // raft::neighbors::detail::brute_force_knn_impl +#include // raft::neighbors::detail::select_k #if defined RAFT_COMPILED #include From 95ef31bc974b6b2b4739cedd6d4d287e9e0266be Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 15:40:23 +0200 Subject: [PATCH 04/32] FIX: getWorkspaceSize Used .data() instead of .data_handle() --- cpp/include/raft/distance/distance-inl.cuh | 251 ++++++++++----------- 1 file changed, 123 insertions(+), 128 deletions(-) diff --git a/cpp/include/raft/distance/distance-inl.cuh b/cpp/include/raft/distance/distance-inl.cuh index 5216902635..3399443765 100644 --- a/cpp/include/raft/distance/distance-inl.cuh +++ b/cpp/include/raft/distance/distance-inl.cuh @@ -13,9 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __DISTANCE_H -#define __DISTANCE_H - #pragma once #include @@ -38,11 +35,11 @@ namespace distance { /** * @brief Evaluate pairwise distances with the user epilogue lamba allowed * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type * @tparam FinalLambda user-defined epilogue lamba - * @tparam Index_ Index type + * @tparam IdxT Index type * @param handle raft handle for managing expensive resources * @param x first set of points * @param y second set of points @@ -57,40 +54,40 @@ namespace distance { * @param metric_arg metric argument (used for Minkowski distance) * * @note fin_op: This is a device lambda which is supposed to operate upon the - * input which is AccType and returns the output in OutType. It's signature is - * as follows:
OutType fin_op(AccType in, int g_idx);
. If one needs + * input which is AccT and returns the output in OutT. It's signature is + * as follows:
OutT fin_op(AccT in, int g_idx);
. If one needs * any other parameters, feel free to pass them via closure. */ -template + typename IdxT = int> void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, void* workspace, size_t worksize, FinalLambda fin_op, - bool isRowMajor = true, - InType metric_arg = 2.0f) + bool isRowMajor = true, + DataT metric_arg = 2.0f) { - detail::distance( + detail::distance( handle, x, y, dist, m, n, k, workspace, worksize, fin_op, isRowMajor, metric_arg); } /** * @brief Evaluate pairwise distances for the simple use case * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type * @param handle raft handle for managing expensive resources * @param x first set of points * @param y second set of points @@ -103,89 +100,89 @@ void distance(raft::resources const& handle, * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) */ -template +template void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, void* workspace, size_t worksize, - bool isRowMajor = true, - InType metric_arg = 2.0f) + bool isRowMajor = true, + DataT metric_arg = 2.0f) { - detail::distance( + detail::distance( handle, x, y, dist, m, n, k, workspace, worksize, isRowMajor, metric_arg); } /** * @brief Return the exact workspace size to compute the distance * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type * @param x first set of points * @param y second set of points * @param m number of points in x * @param n number of points in y * @param k dimensionality * - * @note If the specified distanceType doesn't need the workspace at all, it + * @note If the specified DistT doesn't need the workspace at all, it * returns 0. */ -template -size_t getWorkspaceSize(const InType* x, const InType* y, Index_ m, Index_ n, Index_ k) +template +size_t getWorkspaceSize(const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) { - return detail::getWorkspaceSize(x, y, m, n, k); + return detail::getWorkspaceSize(x, y, m, n, k); } /** * @brief Return the exact workspace size to compute the distance * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type * @param x first set of points (size m*k) * @param y second set of points (size n*k) * @return number of bytes needed in workspace * - * @note If the specified distanceType doesn't need the workspace at all, it + * @note If the specified DistT doesn't need the workspace at all, it * returns 0. */ -template -size_t getWorkspaceSize(const raft::device_matrix_view x, - const raft::device_matrix_view y) +size_t getWorkspaceSize(raft::device_matrix_view const& x, + raft::device_matrix_view const& y) { RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); - return getWorkspaceSize( - x.data(), y.data(), x.extent(0), y.extent(0), x.extent(1)); + return getWorkspaceSize( + x.data_handle(), y.data_handle(), x.extent(0), y.extent(0), x.extent(1)); } /** * @brief Evaluate pairwise distances for the simple use case * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type * @param handle raft handle for managing expensive resources * @param x first set of points * @param y second set of points @@ -196,26 +193,26 @@ size_t getWorkspaceSize(const raft::device_matrix_view x, * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) */ -template +template void distance(raft::resources const& handle, - const InType* x, - const InType* y, - OutType* dist, - Index_ m, - Index_ n, - Index_ k, - bool isRowMajor = true, - InType metric_arg = 2.0f) + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + bool isRowMajor = true, + DataT metric_arg = 2.0f) { auto stream = raft::resource::get_cuda_stream(handle); rmm::device_uvector workspace(0, stream); - auto worksize = getWorkspaceSize(x, y, m, n, k); + auto worksize = getWorkspaceSize(x, y, m, n, k); workspace.resize(worksize, stream); - detail::distance( + detail::distance( handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); } @@ -223,7 +220,7 @@ void distance(raft::resources const& handle, * @brief Convenience wrapper around 'distance' prim to convert runtime metric * into compile time for the purpose of dispatch * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type + * @tparam IdxT indexing type * @param handle raft handle for managing expensive resources * @param x first set of points * @param y second set of points @@ -237,14 +234,14 @@ void distance(raft::resources const& handle, * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) */ -template +template void pairwise_distance(raft::resources const& handle, const Type* x, const Type* y, Type* dist, - Index_ m, - Index_ n, - Index_ k, + IdxT m, + IdxT n, + IdxT k, rmm::device_uvector& workspace, raft::distance::DistanceType metric, bool isRowMajor = true, @@ -253,9 +250,9 @@ void pairwise_distance(raft::resources const& handle, cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto dispatch = [&](auto distance_type) { - auto worksize = getWorkspaceSize(x, y, m, n, k); + auto worksize = getWorkspaceSize(x, y, m, n, k); workspace.resize(worksize, stream); - detail::distance( + detail::distance( handle, x, y, dist, m, n, k, workspace.data(), worksize, isRowMajor, metric_arg); }; @@ -316,7 +313,7 @@ void pairwise_distance(raft::resources const& handle, * @brief Convenience wrapper around 'distance' prim to convert runtime metric * into compile time for the purpose of dispatch * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type + * @tparam IdxT indexing type * @param handle raft handle for managing expensive resources * @param x first set of points * @param y second set of points @@ -328,21 +325,21 @@ void pairwise_distance(raft::resources const& handle, * @param isRowMajor whether the matrices are row-major or col-major * @param metric_arg metric argument (used for Minkowski distance) */ -template +template void pairwise_distance(raft::resources const& handle, const Type* x, const Type* y, Type* dist, - Index_ m, - Index_ n, - Index_ k, + IdxT m, + IdxT n, + IdxT k, raft::distance::DistanceType metric, bool isRowMajor = true, Type metric_arg = 2.0f) { auto stream = raft::resource::get_cuda_stream(handle); rmm::device_uvector workspace(0, stream); - pairwise_distance( + pairwise_distance( handle, x, y, dist, m, n, k, workspace, metric, isRowMajor, metric_arg); } @@ -379,27 +376,27 @@ void pairwise_distance(raft::resources const& handle, * @endcode * * @tparam DistanceType which distance to evaluate - * @tparam InType input argument type - * @tparam AccType accumulation type - * @tparam OutType output type - * @tparam Index_ Index type + * @tparam DataT input argument type + * @tparam AccT accumulation type + * @tparam OutT output type + * @tparam IdxT Index type * @param handle raft handle for managing expensive resources * @param x first set of points (size n*k) * @param y second set of points (size m*k) * @param dist output distance matrix (size n*m) * @param metric_arg metric argument (used for Minkowski distance) */ -template + typename IdxT = int> void distance(raft::resources const& handle, - raft::device_matrix_view const x, - raft::device_matrix_view const y, - raft::device_matrix_view dist, - InType metric_arg = 2.0f) + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + DataT metric_arg = 2.0f) { RAFT_EXPECTS(x.extent(1) == y.extent(1), "Number of columns must be equal."); RAFT_EXPECTS(dist.extent(0) == x.extent(0), @@ -414,22 +411,22 @@ void distance(raft::resources const& handle, constexpr auto is_rowmajor = std::is_same_v; - distance(handle, - x.data_handle(), - y.data_handle(), - dist.data_handle(), - x.extent(0), - y.extent(0), - x.extent(1), - is_rowmajor, - metric_arg); + distance(handle, + x.data_handle(), + y.data_handle(), + dist.data_handle(), + x.extent(0), + y.extent(0), + x.extent(1), + is_rowmajor, + metric_arg); } /** * @brief Convenience wrapper around 'distance' prim to convert runtime metric * into compile time for the purpose of dispatch * @tparam Type input/accumulation/output data-type - * @tparam Index_ indexing type + * @tparam IdxT indexing type * @param handle raft handle for managing expensive resources * @param x first matrix of points (size mxk) * @param y second matrix of points (size nxk) @@ -437,11 +434,11 @@ void distance(raft::resources const& handle, * @param metric distance metric * @param metric_arg metric argument (used for Minkowski distance) */ -template +template void pairwise_distance(raft::resources const& handle, - device_matrix_view const x, - device_matrix_view const y, - device_matrix_view dist, + device_matrix_view const x, + device_matrix_view const y, + device_matrix_view dist, raft::distance::DistanceType metric, Type metric_arg = 2.0f) { @@ -478,5 +475,3 @@ void pairwise_distance(raft::resources const& handle, }; // namespace distance }; // namespace raft - -#endif From 7edcb6a59da62216d9521e15c3848ba676588f31 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 15:29:39 +0200 Subject: [PATCH 05/32] PREP: Separate rbf_fin_op --- .../detail/kernels/kernel_matrices.cuh | 5 +- .../distance/detail/kernels/rbf_fin_op.cuh | 51 +++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) create mode 100644 cpp/include/raft/distance/detail/kernels/rbf_fin_op.cuh diff --git a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh index d1465efdb0..1b111e77f1 100644 --- a/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh +++ b/cpp/include/raft/distance/detail/kernels/kernel_matrices.cuh @@ -17,10 +17,11 @@ #pragma once #include "gram_matrix.cuh" -#include +#include #include #include +#include namespace raft::distance::kernels::detail { @@ -353,7 +354,7 @@ class RBFKernel : public GramMatrixBase { math_t gain = this->gain; using index_t = int64_t; - auto fin_op = [gain] __device__(math_t d_val, index_t idx) { return exp(-gain * d_val); }; + rbf_fin_op fin_op{gain}; raft::distance::distance // raft::exp +#include // HD + +namespace raft::distance::kernels::detail { + +/** @brief: Final op for Gram matrix with RBF kernel. + * + * Calculates output = e^(-gain * in) + * + */ +template +struct rbf_fin_op { + OutT gain; + + explicit HD rbf_fin_op(OutT gain_) noexcept : gain(gain_) {} + + template + HDI OutT operator()(OutT d_val, Args... unused_args) + { + return raft::exp(-gain * d_val); + } +}; // struct rbf_fin_op + +} // namespace raft::distance::kernels::detail From 71de7bd4e33e788c17bf7c27893d764e3c410b06 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 15:37:37 +0200 Subject: [PATCH 06/32] PREP: registers: Add _types header --- .../spatial/knn/detail/ball_cover/common.cuh | 39 +---------- .../knn/detail/ball_cover/registers-inl.cuh | 1 + .../knn/detail/ball_cover/registers_types.cuh | 66 +++++++++++++++++++ 3 files changed, 69 insertions(+), 37 deletions(-) create mode 100644 cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh index 0a6718f5a5..ce72b2648f 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/common.cuh @@ -17,6 +17,7 @@ #pragma once #include "../haversine_distance.cuh" +#include "registers_types.cuh" #include #include #include @@ -39,42 +40,6 @@ struct NNComp { } }; -template -struct DistFunc { - virtual __device__ __host__ __forceinline__ value_t operator()(const value_t* a, - const value_t* b, - const value_int n_dims) - { - return -1; - }; -}; - -template -struct HaversineFunc : public DistFunc { - __device__ __host__ __forceinline__ value_t operator()(const value_t* a, - const value_t* b, - const value_int n_dims) override - { - return raft::spatial::knn::detail::compute_haversine(a[0], b[0], a[1], b[1]); - } -}; - -template -struct EuclideanFunc : public DistFunc { - __device__ __host__ __forceinline__ value_t operator()(const value_t* a, - const value_t* b, - const value_int n_dims) override - { - value_t sum_sq = 0; - for (value_int i = 0; i < n_dims; ++i) { - value_t diff = a[i] - b[i]; - sum_sq += diff * diff; - } - - return raft::sqrt(sum_sq); - } -}; - /** * Zeros the bit at location h in a one-hot encoded 32-bit int array */ @@ -105,4 +70,4 @@ __device__ inline bool _get_val(std::uint32_t* arr, std::uint32_t h) }; // namespace detail }; // namespace knn }; // namespace spatial -}; // namespace raft \ No newline at end of file +}; // namespace raft diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh index f665368c41..e0e7d716ee 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-inl.cuh @@ -20,6 +20,7 @@ #include "../../ball_cover_types.hpp" #include "../haversine_distance.cuh" +#include "registers_types.cuh" // DistFunc #include #include diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh new file mode 100644 index 0000000000..7f4268d2dc --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers_types.cuh @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../haversine_distance.cuh" // compute_haversine +#include // uint32_t + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + +template +struct DistFunc { + virtual __device__ __host__ __forceinline__ value_t operator()(const value_t* a, + const value_t* b, + const value_int n_dims) + { + return -1; + }; +}; + +template +struct HaversineFunc : public DistFunc { + __device__ __host__ __forceinline__ value_t operator()(const value_t* a, + const value_t* b, + const value_int n_dims) override + { + return raft::spatial::knn::detail::compute_haversine(a[0], b[0], a[1], b[1]); + } +}; + +template +struct EuclideanFunc : public DistFunc { + __device__ __host__ __forceinline__ value_t operator()(const value_t* a, + const value_t* b, + const value_int n_dims) override + { + value_t sum_sq = 0; + for (value_int i = 0; i < n_dims; ++i) { + value_t diff = a[i] - b[i]; + sum_sq += diff * diff; + } + + return raft::sqrt(sum_sq); + } +}; + +}; // namespace detail +}; // namespace knn +}; // namespace spatial +}; // namespace raft From 541cabc385d193c908bd2faf39e6a9ce3c58dfc1 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 14:50:01 +0200 Subject: [PATCH 07/32] Change RAFT_COMPILED from INTERFACE to PUBLIC --- cpp/CMakeLists.txt | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 1620a55053..4f261a26ec 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -455,7 +455,10 @@ if(RAFT_COMPILE_LIBRARY) raft_lib PRIVATE "$<$:${RAFT_CXX_FLAGS}>" "$<$:${RAFT_CUDA_FLAGS}>" ) - target_compile_definitions(raft_lib INTERFACE "RAFT_COMPILED") + + # RAFT_COMPILED is set during compilation of libraft.so as well as downstream libraries (due to + # "PUBLIC") + target_compile_definitions(raft_lib PUBLIC "RAFT_COMPILED") # ensure CUDA symbols aren't relocated to the middle of the debug build binaries target_link_options(raft_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") From d81b14ee5e25ac78e042671323765659f5f693af Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 14:52:41 +0200 Subject: [PATCH 08/32] Define RAFT_EXPLICIT and RAFT_EXPLICIT_INSTANTIATE_ONLY --- cpp/CMakeLists.txt | 3 + cpp/include/raft/util/raft_explicit.hpp | 89 +++++++++++++++++++++++++ 2 files changed, 92 insertions(+) create mode 100644 cpp/include/raft/util/raft_explicit.hpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 4f261a26ec..2c2e5370cf 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -460,6 +460,9 @@ if(RAFT_COMPILE_LIBRARY) # "PUBLIC") target_compile_definitions(raft_lib PUBLIC "RAFT_COMPILED") + # RAFT_EXPLICIT_INSTANTIATE_ONLY is set during compilation of libraft.so (due to "PRIVATE") + target_compile_definitions(raft_lib PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") + # ensure CUDA symbols aren't relocated to the middle of the debug build binaries target_link_options(raft_lib PRIVATE "${CMAKE_CURRENT_BINARY_DIR}/fatbin.ld") diff --git a/cpp/include/raft/util/raft_explicit.hpp b/cpp/include/raft/util/raft_explicit.hpp new file mode 100644 index 0000000000..7edb2f0b42 --- /dev/null +++ b/cpp/include/raft/util/raft_explicit.hpp @@ -0,0 +1,89 @@ +/* Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +/** + * @brief Prevents a function template from being implicitly instantiated + * + * This macro defines a function body that can be used for function template + * definitions of functions that should not be implicitly instantiated. + * + * When the template is erroneously implicitly instantiated, it provides a + * useful error message that tells the user how to avoid the implicit + * instantiation. + * + * The error message is generated using a static assert. It is generally tricky + * to have a static assert fire only when you want it, as documented in + * P2593: https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html + * + * We use the strategy from paragraph 1.3 here. We define a struct + * `not_allowed`, whose type is dependent on the template parameters of the + * enclosing function instance. We use this struct type to instantiate the + * `implicit_instantiation` template class, whose value is always false. We pass + * this value to static_assert. This way, the static assert only fires when the + * template is instantiated, since `implicit_instantiation` cannot be + * instantiated without all the types in the enclosing function template. + */ +#define RAFT_EXPLICIT \ + { \ + /* Type of `not_allowed` depends on template parameters of enclosing function. */ \ + struct not_allowed { \ + }; \ + static_assert( \ + raft::util::raft_explicit::implicit_instantiation::value, \ + "ACCIDENTAL_IMPLICIT_INSTANTIATION\n\n" \ + \ + "If you see this error, then you have implicitly instantiated a function\n" \ + "template. To keep compile times in check, libraft has the policy of\n" \ + "explicitly instantiating templates. To fix the compilation error, follow\n" \ + "these steps.\n\n" \ + \ + "If you scroll up or down a bit, you probably saw a line like the following:\n\n" \ + \ + "detected during instantiation of \"void raft::foo(T) [with T=float]\" at line [..]\n\n" \ + \ + "Simplest temporary solution:\n\n" \ + \ + " Add '#undef RAFT_EXPLICIT_INSTANTIATE_ONLY' at the top of your .cpp/.cu file.\n\n" \ + \ + "Best solution:\n\n" \ + \ + " 1. Add the following line to the file include/raft/foo.hpp:\n\n" \ + \ + " extern template void raft::foo(double);\n\n" \ + \ + " 2. Add the following line to the file src/raft/foo.cpp:\n\n" \ + \ + " template void raft::foo(double)\n"); \ + \ + /* Function may have non-void return type. */ \ + /* To prevent warnings/errors about missing returns, throw an exception. */ \ + throw "raft_explicit_error"; \ + } + +namespace raft::util::raft_explicit { +/** + * @brief Template that is always false + * + * This template is from paragraph 1.3 of P2593: + * https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html + * + * The value of `value` is always false, but it depends on a template parameter. + */ +template +struct implicit_instantiation { + static constexpr bool value = false; +}; +} // namespace raft::util::raft_explicit From 48ea76957413b9f3238fbac2ce456fde893c60e4 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 15:07:46 +0200 Subject: [PATCH 09/32] Update docs --- README.md | 14 +++--- docs/source/build.md | 12 +---- docs/source/developer_guide.md | 91 ++++++++++++++++++++++++++++++++++ docs/source/using_libraft.md | 91 ++++++++++++++++++---------------- 4 files changed, 148 insertions(+), 60 deletions(-) diff --git a/README.md b/README.md index b77e906262..3bbb5de5fa 100755 --- a/README.md +++ b/README.md @@ -198,7 +198,7 @@ RAFT itself can be installed through conda, [CMake Package Manager (CPM)](https: The easiest way to install RAFT is through conda and several packages are provided. - `libraft-headers` RAFT headers -- `libraft` (optional) shared library of pre-compiled template specializations and runtime APIs. +- `libraft` (optional) shared library of pre-compiled template instantiations and runtime APIs. - `pylibraft` (optional) Python wrappers around RAFT algorithms and primitives. - `raft-dask` (optional) enables deployment of multi-node multi-GPU algorithms that use RAFT `raft::comms` in Dask clusters. @@ -231,11 +231,11 @@ You can find an [example RAFT](cpp/template/README.md) project template in the ` Additional CMake targets can be made available by adding components in the table below to the `RAFT_COMPONENTS` list above, separated by spaces. The `raft::raft` target will always be available. RAFT headers require, at a minimum, the CUDA toolkit libraries and RMM dependencies. -| Component | Target | Description | Base Dependencies | -|-------------|---------------------|-----------------------------------------------------------|---------------------------------------| -| n/a | `raft::raft` | Full RAFT header library | CUDA toolkit, RMM, NVTX, CCCL, CUTLASS | -| compiled | `raft::compiled` | Pre-compiled template specializations and runtime library | raft::raft | -| distributed | `raft::distributed` | Dependencies for `raft::comms` APIs | raft::raft, UCX, NCCL | +| Component | Target | Description | Base Dependencies | +|-------------|---------------------|----------------------------------------------------------|----------------------------------------| +| n/a | `raft::raft` | Full RAFT header library | CUDA toolkit, RMM, NVTX, CCCL, CUTLASS | +| compiled | `raft::compiled` | Pre-compiled template instantiations and runtime library | raft::raft | +| distributed | `raft::distributed` | Dependencies for `raft::comms` APIs | raft::raft, UCX, NCCL | ### Source @@ -282,7 +282,7 @@ The folder structure mirrors other RAPIDS repos, with the following folders: - `util`: Various reusable tools and utilities for accelerated algorithm development - `internal`: A private header-only component that hosts the code shared between benchmarks and tests. - `scripts`: Helpful scripts for development - - `src`: Compiled APIs and template specializations for the shared libraries + - `src`: Compiled APIs and template instantiations for the shared libraries - `template`: A skeleton template containing the bare-bones file structure and cmake configuration for writing applications with RAFT. - `test`: Googletests source code - `docs`: Source code and scripts for building library documentation (Uses breath, doxygen, & pydocs) diff --git a/docs/source/build.md b/docs/source/build.md index 262c5703bc..bd2afe6638 100644 --- a/docs/source/build.md +++ b/docs/source/build.md @@ -4,7 +4,7 @@ The easiest way to install RAFT is through conda and several packages are provided. - `libraft-headers` RAFT headers -- `libraft` (optional) shared library containing pre-compiled template specializations and runtime API. +- `libraft` (optional) shared library containing pre-compiled template instantiations and runtime API. - `pylibraft` (optional) Python wrappers around RAFT algorithms and primitives. - `raft-dask` (optional) enables deployment of multi-node multi-GPU algorithms that use RAFT `raft::comms` in Dask clusters. @@ -276,15 +276,7 @@ If the RAFT headers have already been installed into your environment with cmake Use `find_package(raft COMPONENTS compiled distributed)` to enable the shared library and transitively pass dependencies through separate targets for each component. In this example, the `raft::compiled` and `raft::distributed` targets will be available for configuring linking paths in addition to `raft::raft`. These targets will also pass through any transitive dependencies (such as NCCL for the `distributed` component). -The pre-compiled libraries contain template specializations for commonly used types, such as single- and double-precision floating-point. In order to use the symbols in the pre-compiled libraries, the compiler needs to be told not to instantiate templates that are already contained in the shared libraries. By convention, these header files are named `specializations.cuh` and located in the base directory for the packages that contain specializations. - -The following example tells the compiler to ignore the pre-compiled templates for the `raft::distance` API so any symbols already compiled into the `libraft` shared library will be used instead. RAFT's cmake creates a variable `RAFT_COMPILED` which can be used to ignore the pre-compiled template specializations only when the shared library has been enabled through cmake (such as by specifying the `compiled` component in `find_package`): -```c++ -#ifdef RAFT_COMPILED -#include -#include -#endif -``` +The pre-compiled libraries contain template instantiations for commonly used types, such as single- and double-precision floating-point. By default, these are used automatically when the `RAFT_COMPILED` macro is defined during compilation. This definition is automatically added by CMake. ### Building RAFT C++ from source in cmake diff --git a/docs/source/developer_guide.md b/docs/source/developer_guide.md index 6f57453e28..3f95cf0a01 100644 --- a/docs/source/developer_guide.md +++ b/docs/source/developer_guide.md @@ -260,6 +260,97 @@ Sometimes, we need to temporarily change the log pattern (eg: for reporting deci 4. Before creating a new primitive, check to see if one exists already. If one exists but the API isn't flexible enough to include your use-case, consider first refactoring the existing primitive. If that is not possible without an extreme number of changes, consider how the public API could be made more flexible. If the new primitive is different enough from all existing primitives, consider whether an existing public API could invoke the new primitive as an option or argument. If the new primitive is different enough from what exists already, add a header for the new public API function to the appropriate subdirectory and namespace. +## Header organization of expensive function templates + +RAFT is a heavily templated library. Several core functions are expensive to compile and we want to prevent duplicate compilation of this functionality. To limit build time, RAFT provides a precompiled library (libraft.so) where expensive function templates are instantiated for the most commonly used template parameters. To prevent (1) accidental instantiation of these templates and (2) unnecessary dependency on the internals of these templates, we use a split header structure and define macros to control template instantiation. This section describes the macros and header structure. + +**Macros.** We define the macros `RAFT_COMPILED` and `RAFT_EXPLICIT_INSTANTIATE_ONLY`. The `RAFT_COMPILED` macro is defined by `CMake` when compiling code that (1) is part of `libraft.so` or (2) is linked with `libraft.so`. It indicates that a precompiled `libraft.so` is present at runtime. + +The `RAFT_EXPLICIT_INSTANTIATE_ONLY` macro is defined by `CMake` during compilation of `libraft.so` itself. When defined, it indicates that implicit instantiations of expensive function templates are forbidden (they result in a compiler error). In the RAFT project, we additionally define this macro during compilation of the tests and benchmarks. + +Below, we summarize which combinations of `RAFT_COMPILED` and `RAFT_EXPLICIT_INSTANTIATE_ONLY` are used in practice and what the effect of the combination is. + +| RAFT_COMPILED | RAFT_EXPLICIT_INSTANTIATE_ONLY | Which targets | +|---------------|--------------------------------|------------------------------------------------------------------------------------------------------| +| defined | defined | `raft::compiled`, RAFT tests, RAFT benchmarks | +| defined | | Downstream libraries depending of `libraft` like cuML, cuGraph. | +| | | Downstream libraries depending on `libraft-headers` like cugraph-ops. | + + +| RAFT_COMPILED | RAFT_EXPLICIT_INSTANTIATE_ONLY | Effect | +|---------------|--------------------------------|-------------------------------------------------------------------------------------------------------| +| defined | defined | Templates are precompiled. Compiler error on accidental instantiation of expensive function template. | +| defined | | Templates are precompiled. Implicit instantiation allowed. | +| | | Nothing precompiled. Implicit instantiation allowed. | +| | defined | Avoid this: nothing precompiled. Compiler error on any instantiation of expensive function template. | + + + +**Header organization.** Any header file that defines an expensive function template (say `expensive.cuh`) should be split in three parts: `expensive.cuh`, `expensive-inl.cuh`, and `expensive-ext.cuh`. The file `expensive-inl.cuh` ("inl" for "inline") contains the template definitions, i.e., the actual code. The file `expensive.cuh` includes one or both of the other two files, depending on the values of the `RAFT_COMPILED` and `RAFT_EXPLICIT_INSTANTIATE_ONLY` macros. The file `expensive-ext.cuh` contains `extern template` instantiations. In addition, if `RAFT_EXPLICIT_INSTANTIATE_ONLY` is set, it contains template definitions to ensure that a compiler error is raised in case of accidental instantiation. + +The dispatching by `expensive.cuh` is performed as follows: +``` c++ +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +// If implicit instantiation is allowed, include template definitions. +#include "expensive-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +// Include extern template instantiations when RAFT is compiled. +#include "expensive-ext.cuh" +#endif +``` + +The file `expensive-inl.cuh` is unchanged: +``` c++ +namespace raft { +template +void expensive(T arg) { + // .. function body +} +} // namespace raft +``` + +The file `expensive-ext.cuh` contains the following: +``` c++ +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY +namespace raft { +// (1) define templates to raise an error in case of accidental instantiation +template void expensive(T arg) RAFT_EXPLICIT; +} // namespace raft +#endif //RAFT_EXPLICIT_INSTANTIATE_ONLY + +// (2) Provide extern template instantiations. +extern template void raft::expensive(int); +extern template void raft::expensive(float); +``` + +This header has two responsibilities: (1) define templates to raise an error in case of accidental instantiation and (2) provide `extern template` instantiations. +First, if `RAFT_EXPLICIT_INSTANTIATE_ONLY` is set, `expensive` is defined. This is done for two reasons: (1) to give a definition, because the definition in `expensive-inl.cuh` was skipped and (2) to indicate that the template should be explicitly instantiated by taging it with the `RAFT_EXPLICIT` macro. This macro defines the function body, and it ensures that an informative error message is generated when an implicit instantiation erroneously occurs. Finally, the `extern template` instantiations are listed. + +To actually generate the code for the template instances, the file `src/expensive.cu` contains the following. Note that the only difference between the extern template instantiations in `expensive-ext.cuh` and these lines are the removal of the word `extern`: + +``` c++ +#include + +template void raft::expensive(int); +template void raft::expensive(float); +``` + +**Design considerations**: + +1. In the `-ext.cuh` header, do not include implementation headers. Only include function parameter types and types that are used to instantiate the templates. If a primitive takes custom parameter types, define them in a separate header called `_types.hpp`. + +2. Keep docstrings in the `-inl.cuh` header, as it is closer to the code. Remove docstrings from template definitions in the `-ext.cuh` header. + +3. The order of inclusion in `expensive.cuh` is extremely important. If `RAFT_EXPLICIT_INSTANTIATE_ONLY` is not defined, but `RAFT_COMPILED` is defined, then we must include the template definitions before the `extern template` instantiations. + +4. If a header file defines multiple expensive templates, it can be that one of them is not instantiated. In this case, **do define** the template with `RAFT_EXPLICIT` in the `-ext` header. This way, when the template is instantiated, the developer gets a helpful error message instead of a confusing "function not found". + +This header structure was proposed in [issue #1416](https://github.com/rapidsai/raft/issues/1416), which contains more background on the motivation of this structure and the mechanics of C++ template instantiation. + ## Testing It's important for RAFT to maintain a high test coverage of the public APIs in order to minimize the potential for downstream projects to encounter unexpected build or runtime behavior as a result of changes. diff --git a/docs/source/using_libraft.md b/docs/source/using_libraft.md index f4f966f2c8..c28fadab46 100644 --- a/docs/source/using_libraft.md +++ b/docs/source/using_libraft.md @@ -1,59 +1,64 @@ # Using The Pre-Compiled Binary -At its core, RAFT is a header-only template library, which makes it very powerful in that APIs can be called with various different combinations of data types and only the templates which are actually used will be compiled into your binaries. This increased flexibility comes with a drawback that all the APIs need to be declared inline and thus calls which are made frequently in your code could be compiled again each source file for which they are invoked. +At its core, RAFT is a header-only template library, which makes it very powerful in that APIs can be called with various different combinations of data types and only the templates which are actually used will be compiled into your binaries. This increased flexibility comes with a drawback that all the APIs need to be declared inline and thus calls which are made frequently in your code could be compiled again in each source file for which they are invoked. -For most functions, this overhead is pretty minimal and not noticeable but some of RAFT's APIs consist of very complex hierarchies of function calls that ultimately end up dispatching to device code that's executed on the GPU. The compile times for these APIs may still be bearable when compiling for only a single compute architecture but could end up becoming extremely slow to compile for all of the supported architectures at once. +For most functions, compile-time overhead is minimal but some of RAFT's APIs take a substantial time to compile. As a rule of thumb, most functionality in `raft::distance`, `raft::neighbors`, and `raft::spatial` is expensive to compile and most functionality in other namespaces has little compile-time overhead. -There are three ways to solve this problem and speed up compile times: -1. Continue to use RAFT as a header-only library and create a CUDA source file in your project to explicitly instantiate the templates which are slow to compile. This can be tedious and will still require compiling the slow code at least once, but it's the most flexible option if you are using types that aren't already compiled into `libraft` -2. If you are able to use one of the template types that are already being compiled into `libraft`, you can use the pre-compiled template specializations, which I will describe in more detail in the following section. -3. If you would like to use RAFT but either cannot or would prefer not to compile any CUDA code yourself, you can simply add `libraft` to your link libraries and use the growing set of runtime APIs. +There are three ways to speed up compile times: -## Using Template Specializations +1. Continue to use RAFT as a header-only library and create a CUDA source file + in your project to explicitly instantiate the templates which are slow to + compile. This can be tedious and will still require compiling the slow code + at least once, but it's the most flexible option if you are using types that + aren't already compiled into `libraft` -As mentioned above, the pre-compiled template instantiations can save a lot of time if you are able to use the type combinations for the templates which are already specialized in the `libraft` binary. This will, of course, mean that you will need to add `libraft` to your link libraries. +2. If you are able to use one of the template types that are already being + compiled into `libraft`, you can use the pre-compiled template + instantiations, which are described in more detail in the following section. -At the top level of each namespace containing pre-compiled template specializations is a header file called `specializations.cuh`. This header file includes `extern template` directives for all the specializations which are compiled into libraft. As an example, including `raft/neighbors/specializations.cuh` in one of your source files will effectively tell the compiler to skip over any of the template specializations that are already compiled into the `libraft` binary. +3. If you would like to use RAFT but either cannot or would prefer not to + compile any CUDA code yourself, you can simply add `libraft` to your link + libraries and use the growing set of runtime APIs. -### How do I verify template specializations didn't compile into my binary? +### How do I verify template instantiations didn't compile into my binary? -Which specializations were chosen to instantiations were based on compile time analysis and reuse. This means you can't assume that all specializations are for the public API itself. Take the following example in `raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh`: +To verify that you are not accidentally instantiating templates that have not been pre-compiled in RAFT, set the `RAFT_EXPLICIT_INSTANTIATE_ONLY` macro. This only works if you are linking with the pre-compiled libraft (i.e., when `RAFT_COMPILED` has been defined). To check if, for instance, `raft::distance::distance` has been precompiled with specific template arguments, you can set `RAFT_EXPLICIT_INSTANTIATE_ONLY` at the top of the file you are compiling, as in the following example: ```c++ -namespace raft::neighbors::ivf_pq::detail { - -namespace { -using fp8s_t = fp_8bit<5, true>; -using fp8u_t = fp_8bit<5, false>; -} // namespace - -#define RAFT_INST(OutT, LutT) \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; - -#define RAFT_INST_ALL_OUT_T(LutT) \ - RAFT_INST(float, LutT) \ - RAFT_INST(half, LutT) - -RAFT_INST_ALL_OUT_T(float) -RAFT_INST_ALL_OUT_T(half) -RAFT_INST_ALL_OUT_T(fp8s_t) -RAFT_INST_ALL_OUT_T(fp8u_t) - -#undef RAFT_INST -#undef RAFT_INST_ALL_OUT_T - -} // namespace raft::neighbors::ivf_pq::detail -``` -We can see here that the function `raft::neighbors::ivf_pq::detail::get_compute_similarity_kernel` is being instantiated for the cartesian product of `OutT={float, half, fp8s_t, fp8u_t}` and `LutT={float, half}`. After linking against the `libraft` binary and including `raft/neighbors/specializations.cuh` in your source file, you can invoke the `raft::neighbors::ivf_pq` functions and compile your code. If the specializations are working, you should be able to use `nm -g -C --defined-only /path/to/your/binary | grep raft::neighbors::ivf_pq::detail::get_compute_similarity::kernel` and you shouldn't see any results, because those symbols should be coming from the `libraft` binary and skipped from compiling into your binary. +#ifdef RAFT_COMPILED +#define RAFT_EXPLICIT_INSTANTIATE_ONLY +#endif + +#include +#include +#include + +int main() +{ + raft::resources handle{}; + + // Change IdxT to uint64_t and you will get an error because you are + // instantiating a template that has not been pre-compiled. + using IdxT = int; + + const float* x = nullptr; + const float* y = nullptr; + float* out = nullptr; + int m = 1024; + int n = 1024; + int k = 1024; + bool row_major = true; + raft::distance::distance( + handle, x, y, out, m, n, k, row_major, 2.0f); +} +``` ## Runtime APIs -RAFT contains a growing list of runtime APIs that, unlike the pre-compiled template specializations, allow you to link against `libraft` and invoke RAFT directly from `cpp` files. The benefit to RAFT's runtime APIs are two-fold- unlike the template specializations, which still require your code be compiled with the CUDA compiler (`nvcc`), the `runtime` APIs are the lightweight wrappers which enable `pylibraft`. +RAFT contains a growing list of runtime APIs that, unlike the pre-compiled +template instantiations, allow you to link against `libraft` and invoke RAFT +directly from `cpp` files. The benefit to RAFT's runtime APIs is that they can +be used from code that is compiled with a `c++` compiler (rather than the CUDA +compiler `nvcc`). This enables the `runtime` APIs to power `pylibraft`. -Similar to the pre-compiled template specializations, RAFT's runtime APIs \ No newline at end of file From e6bb5d5b957df2cf2b053be52b8359567374cd94 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 15:47:47 +0200 Subject: [PATCH 10/32] Replace specializations by split headers --- cpp/CMakeLists.txt | 223 ++-- .../detail/pairwise_matrix/dispatch-ext.cuh | 194 +++ .../detail/pairwise_matrix/dispatch.cuh | 24 + cpp/include/raft/distance/distance-ext.cuh | 1065 +++++++++++++++++ cpp/include/raft/distance/distance.cuh | 24 + cpp/include/raft/distance/fused_l2_nn-ext.cuh | 88 ++ cpp/include/raft/distance/fused_l2_nn.cuh | 25 + .../linalg/detail/coalesced_reduction-ext.cuh | 73 ++ .../linalg/detail/coalesced_reduction.cuh} | 12 +- .../raft/matrix/detail/select_k-ext.cuh | 65 + .../raft/matrix/detail/select_k.cuh} | 11 +- cpp/include/raft/neighbors/ball_cover-ext.cuh | 124 ++ cpp/include/raft/neighbors/ball_cover.cuh | 24 + .../raft/neighbors/brute_force-ext.cuh | 109 ++ cpp/include/raft/neighbors/brute_force.cuh | 25 + .../neighbors/detail/ivf_flat_search-ext.cuh | 58 + .../neighbors/detail/ivf_flat_search.cuh} | 11 +- .../neighbors/detail/selection_faiss-ext.cuh | 60 + .../raft/neighbors/detail/selection_faiss.cuh | 25 + cpp/include/raft/neighbors/ivf_flat-ext.cuh | 185 +++ .../raft/neighbors/ivf_flat.cuh} | 11 +- cpp/include/raft/neighbors/ivf_pq-ext.cuh | 170 +++ .../raft/neighbors/ivf_pq.cuh} | 11 +- cpp/include/raft/neighbors/refine-ext.cuh | 78 ++ cpp/include/raft/neighbors/refine.cuh | 25 + .../knn/detail/ball_cover/registers-ext.cuh | 129 ++ .../knn/detail/ball_cover/registers.cuh | 25 + .../spatial/knn/detail/fused_l2_knn-ext.cuh | 70 ++ .../raft/spatial/knn/detail/fused_l2_knn.cuh | 24 + .../pairwise_matrix/dispatch_00_generate.py | 194 +++ ...patch_canberra_double_double_double_int.cu | 55 + ...dispatch_canberra_float_float_float_int.cu | 50 + ...ch_correlation_double_double_double_int.cu | 55 + ...patch_correlation_float_float_float_int.cu | 55 + ...ispatch_cosine_double_double_double_int.cu | 51 + .../dispatch_cosine_float_float_float_int.cu | 51 + ...ing_unexpanded_double_double_double_int.cu | 50 + ...amming_unexpanded_float_float_float_int.cu | 50 + ...inger_expanded_double_double_double_int.cu | 55 + ...ellinger_expanded_float_float_float_int.cu | 50 + ...jensen_shannon_double_double_double_int.cu | 55 + ...ch_jensen_shannon_float_float_float_int.cu | 55 + ..._kl_divergence_double_double_double_int.cu | 50 + ...tch_kl_divergence_float_float_float_int.cu | 50 + .../dispatch_l1_double_double_double_int.cu | 50 + .../dispatch_l1_float_float_float_int.cu | 50 + ...ch_l2_expanded_double_double_double_int.cu | 51 + ...patch_l2_expanded_float_float_float_int.cu | 51 + ..._l2_unexpanded_double_double_double_int.cu | 55 + ...tch_l2_unexpanded_float_float_float_int.cu | 50 + ...dispatch_l_inf_double_double_double_int.cu | 50 + .../dispatch_l_inf_float_float_float_int.cu | 50 + ..._lp_unexpanded_double_double_double_int.cu | 55 + ...tch_lp_unexpanded_float_float_float_int.cu | 50 + .../detail/pairwise_matrix/dispatch_rbf.cu | 64 + ...tch_russel_rao_double_double_double_int.cu | 55 + ...spatch_russel_rao_float_float_float_int.cu | 50 + cpp/src/distance/distance.cu | 934 +++++++++++++++ cpp/src/distance/fused_l2_nn.cu | 54 + .../detail/00_write_template.py | 159 --- .../canberra_double_double_double_int.cu | 33 - .../detail/canberra_float_float_float_int.cu | 33 - .../correlation_double_double_double_int.cu | 33 - .../correlation_float_float_float_int.cu | 33 - .../detail/cosine_double_double_double_int.cu | 34 - .../detail/cosine_float_float_float_int.cu | 34 - ...ing_unexpanded_double_double_double_int.cu | 33 - ...amming_unexpanded_float_float_float_int.cu | 33 - ...inger_expanded_double_double_double_int.cu | 33 - ...ellinger_expanded_float_float_float_int.cu | 33 - .../inner_product_double_double_double_int.cu | 38 - .../inner_product_float_float_float_int.cu | 37 - ...jensen_shannon_double_double_double_int.cu | 34 - .../jensen_shannon_float_float_float_int.cu | 34 - .../detail/kernels/gram_matrix_base_double.cu | 20 - .../kernels/polynomial_kernel_double_int.cu | 20 - .../detail/kernels/tanh_kernel_double.cu | 20 - .../kl_divergence_double_double_double_int.cu | 33 - .../kl_divergence_float_float_float_int.cu | 33 - .../detail/l1_double_double_double_int.cu | 33 - .../detail/l1_float_float_float_int.cu | 33 - .../l2_expanded_double_double_double_int.cu | 34 - .../l2_expanded_float_float_float_int.cu | 34 - .../l2_unexpanded_double_double_double_int.cu | 33 - .../l2_unexpanded_float_float_float_int.cu | 33 - .../detail/l_inf_double_double_double_int.cu | 33 - .../detail/l_inf_float_float_float_int.cu | 33 - .../lp_unexpanded_double_double_double_int.cu | 33 - .../lp_unexpanded_float_float_float_int.cu | 33 - .../russel_rao_double_double_double_int.cu | 33 - .../russel_rao_float_float_float_int.cu | 33 - .../specializations/fused_l2_nn_double_int.cu | 51 - .../fused_l2_nn_double_int64.cu | 51 - .../specializations/fused_l2_nn_float_int.cu | 51 - .../fused_l2_nn_float_int64.cu | 51 - cpp/src/linalg/detail/coalesced_reduction.cu | 69 ++ .../matrix/detail/select_k_double_int64_t.cu | 33 + .../matrix/detail/select_k_double_uint32_t.cu | 34 + .../matrix/detail/select_k_float_int64_t.cu | 33 + .../matrix/detail/select_k_float_uint32_t.cu | 33 + .../matrix/detail/select_k_half_int64_t.cu | 33 + .../matrix/detail/select_k_half_uint32_t.cu | 33 + .../detail/select_k_float_int64_t.cu | 36 - .../detail/select_k_float_uint32_t.cu | 36 - .../detail/select_k_half_int64_t.cu | 36 - .../detail/select_k_half_uint32_t.cu | 36 - cpp/src/neighbors/ball_cover.cu | 66 + cpp/src/neighbors/brute_force_00_generate.py | 106 ++ .../brute_force_fused_l2_knn_float_int64_t.cu | 45 + .../brute_force_knn_int64_t_float_int64_t.cu | 47 + .../brute_force_knn_int64_t_float_uint32_t.cu | 47 + .../brute_force_knn_int_float_int.cu | 47 + ...brute_force_knn_uint32_t_float_uint32_t.cu | 47 + cpp/src/neighbors/detail/ivf_flat_search.cu | 35 + .../detail/selection_faiss_00_generate.py | 75 ++ .../detail/selection_faiss_int32_t_float.cu | 44 + .../detail/selection_faiss_int_double.cu | 44 + .../detail/selection_faiss_long_float.cu | 44 + .../detail/selection_faiss_size_t_double.cu | 44 + .../detail/selection_faiss_size_t_float.cu | 44 + .../detail/selection_faiss_uint32_t_float.cu | 44 + cpp/src/neighbors/ivf_flat_00_generate.py | 148 +++ .../neighbors/ivf_flat_build_float_int64_t.cu | 50 + .../ivf_flat_build_int8_t_int64_t.cu | 50 + .../ivf_flat_build_uint8_t_int64_t.cu | 50 + .../ivf_flat_extend_float_int64_t.cu | 58 + .../ivf_flat_extend_int8_t_int64_t.cu | 58 + .../ivf_flat_extend_uint8_t_int64_t.cu | 58 + .../ivf_flat_search_float_int64_t.cu | 49 + .../ivf_flat_search_int8_t_int64_t.cu | 49 + .../ivf_flat_search_uint8_t_int64_t.cu | 49 + .../neighbors/ivfpq_build_float_int64_t.cu | 36 + .../neighbors/ivfpq_build_int8_t_int64_t.cu | 36 + .../neighbors/ivfpq_build_uint8_t_int64_t.cu | 36 + .../neighbors/ivfpq_extend_float_int64_t.cu | 50 + .../neighbors/ivfpq_extend_int8_t_int64_t.cu | 50 + .../neighbors/ivfpq_extend_uint8_t_int64_t.cu | 50 + .../neighbors/ivfpq_search_float_int64_t.cu | 42 + .../neighbors/ivfpq_search_int8_t_int64_t.cu | 42 + .../neighbors/ivfpq_search_uint8_t_int64_t.cu | 42 + cpp/src/neighbors/refine_00_generate.py | 78 ++ cpp/src/neighbors/refine_float_float.cu | 50 + cpp/src/neighbors/refine_int8_t_float.cu | 50 + cpp/src/neighbors/refine_uint8_t_float.cu | 50 + .../ball_cover_all_knn_query.cu | 33 - .../specializations/ball_cover_build_index.cu | 31 - .../specializations/ball_cover_knn_query.cu | 34 - .../detail/ball_cover_lowdim_pass_one_2d.cu | 43 - .../detail/ball_cover_lowdim_pass_one_3d.cu | 43 - .../detail/ball_cover_lowdim_pass_two_2d.cu | 41 - .../detail/ball_cover_lowdim_pass_two_3d.cu | 42 - .../brute_force_knn_impl_long_float_int.cu | 39 - .../brute_force_knn_impl_long_float_uint.cu | 39 - .../brute_force_knn_impl_uint_float_int.cu | 39 - .../brute_force_knn_impl_uint_float_uint.cu | 39 - .../compute_similarity_float_float_fast.cu | 26 - ...pute_similarity_float_float_no_basediff.cu | 27 - ...pute_similarity_float_float_no_smem_lut.cu | 27 - .../compute_similarity_float_fp8s_fast.cu | 27 - ...mpute_similarity_float_fp8s_no_basediff.cu | 28 - ...mpute_similarity_float_fp8s_no_smem_lut.cu | 28 - .../compute_similarity_float_fp8u_fast.cu | 28 - ...mpute_similarity_float_fp8u_no_basediff.cu | 28 - ...mpute_similarity_float_fp8u_no_smem_lut.cu | 28 - .../compute_similarity_float_half_fast.cu | 27 - ...mpute_similarity_float_half_no_basediff.cu | 27 - ...mpute_similarity_float_half_no_smem_lut.cu | 27 - .../compute_similarity_half_fp8s_fast.cu | 27 - ...ompute_similarity_half_fp8s_no_basediff.cu | 27 - ...ompute_similarity_half_fp8s_no_smem_lut.cu | 27 - .../compute_similarity_half_fp8u_fast.cu | 27 - ...ompute_similarity_half_fp8u_no_basediff.cu | 28 - ...ompute_similarity_half_fp8u_no_smem_lut.cu | 28 - .../compute_similarity_half_half_fast.cu | 27 - ...ompute_similarity_half_half_no_basediff.cu | 27 - ...ompute_similarity_half_half_no_smem_lut.cu | 27 - ...mpute_similarity_float_half_no_smem_lut.cu | 27 - .../fused_l2_knn_int_float_false.cu | 42 - .../fused_l2_knn_int_float_true.cu | 41 - .../fused_l2_knn_long_float_false.cu | 41 - .../fused_l2_knn_long_float_true.cu | 41 - .../ivfflat_build_float_int64_t.cu | 31 - .../ivfflat_build_int8_t_int64_t.cu | 31 - .../ivfflat_build_uint8_t_int64_t.cu | 31 - .../ivfflat_extend_float_int64_t.cu | 37 - .../ivfflat_extend_int8_t_int64_t.cu | 37 - .../ivfflat_extend_uint8_t_int64_t.cu | 37 - .../ivfflat_search_float_int64_t.cu | 58 - .../ivfflat_search_int8_t_int64_t.cu | 49 - .../ivfflat_search_uint8_t_int64_t.cu | 49 - .../ivfpq_build_float_int64_t.cu | 32 - .../ivfpq_build_int8_t_int64_t.cu | 32 - .../ivfpq_build_uint8_t_int64_t.cu | 32 - .../ivfpq_extend_float_int64_t.cu | 39 - .../ivfpq_extend_int8_t_int64_t.cu | 39 - .../ivfpq_extend_uint8_t_int64_t.cu | 39 - .../ivfpq_search_float_int64_t.cu | 34 - .../ivfpq_search_int8_t_int64_t.cu | 34 - .../ivfpq_search_uint8_t_int64_t.cu | 34 - .../specializations/refine_d_int64_t_float.cu | 31 - .../refine_d_int64_t_int8_t.cu | 31 - .../refine_d_int64_t_uint8_t.cu | 31 - .../specializations/refine_h_int64_t_float.cu | 31 - .../refine_h_int64_t_int8_t.cu | 30 - .../refine_h_int64_t_uint8_t.cu | 31 - .../brute_force_knn_long_float_int.cu | 42 - .../brute_force_knn_long_float_uint.cu | 42 - .../brute_force_knn_uint32_t_float_int.cu | 41 - .../brute_force_knn_uint32_t_float_uint.cu | 42 - .../knn/detail/ball_cover/registers.cu | 60 + .../ball_cover/registers_00_generate.py | 112 ++ .../ball_cover/registers_pass_one_2d_dist.cu | 48 + .../registers_pass_one_2d_euclidean.cu | 48 + .../registers_pass_one_2d_haversine.cu | 48 + .../ball_cover/registers_pass_one_3d_dist.cu | 48 + .../registers_pass_one_3d_euclidean.cu | 48 + .../registers_pass_one_3d_haversine.cu | 48 + .../ball_cover/registers_pass_two_2d_dist.cu | 48 + .../registers_pass_two_2d_euclidean.cu | 48 + .../registers_pass_two_2d_haversine.cu | 48 + .../ball_cover/registers_pass_two_3d_dist.cu | 48 + .../registers_pass_two_3d_euclidean.cu | 48 + .../registers_pass_two_3d_haversine.cu | 48 + .../knn/detail/fused_l2_knn_int32_t_float.cu | 40 + .../knn/detail/fused_l2_knn_int64_t_float.cu | 40 + .../knn/detail/fused_l2_knn_uint32_t_float.cu | 41 + 226 files changed, 8560 insertions(+), 3871 deletions(-) create mode 100644 cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh create mode 100644 cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh create mode 100644 cpp/include/raft/distance/distance-ext.cuh create mode 100644 cpp/include/raft/distance/distance.cuh create mode 100644 cpp/include/raft/distance/fused_l2_nn-ext.cuh create mode 100644 cpp/include/raft/distance/fused_l2_nn.cuh create mode 100644 cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh rename cpp/{src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu => include/raft/linalg/detail/coalesced_reduction.cuh} (67%) create mode 100644 cpp/include/raft/matrix/detail/select_k-ext.cuh rename cpp/{src/distance/specializations/detail/kernels/rbf_kernel_float.cu => include/raft/matrix/detail/select_k.cuh} (78%) create mode 100644 cpp/include/raft/neighbors/ball_cover-ext.cuh create mode 100644 cpp/include/raft/neighbors/ball_cover.cuh create mode 100644 cpp/include/raft/neighbors/brute_force-ext.cuh create mode 100644 cpp/include/raft/neighbors/brute_force.cuh create mode 100644 cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh rename cpp/{src/distance/specializations/detail/kernels/gram_matrix_base_float.cu => include/raft/neighbors/detail/ivf_flat_search.cuh} (78%) create mode 100644 cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh create mode 100644 cpp/include/raft/neighbors/detail/selection_faiss.cuh create mode 100644 cpp/include/raft/neighbors/ivf_flat-ext.cuh rename cpp/{src/distance/specializations/detail/kernels/tanh_kernel_float.cu => include/raft/neighbors/ivf_flat.cuh} (78%) create mode 100644 cpp/include/raft/neighbors/ivf_pq-ext.cuh rename cpp/{src/distance/specializations/detail/kernels/rbf_kernel_double.cu => include/raft/neighbors/ivf_pq.cuh} (78%) create mode 100644 cpp/include/raft/neighbors/refine-ext.cuh create mode 100644 cpp/include/raft/neighbors/refine.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh create mode 100644 cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu create mode 100644 cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu create mode 100644 cpp/src/distance/distance.cu create mode 100644 cpp/src/distance/fused_l2_nn.cu delete mode 100644 cpp/src/distance/specializations/detail/00_write_template.py delete mode 100644 cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/inner_product_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/inner_product_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu delete mode 100644 cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu delete mode 100644 cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/l1_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/l1_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu delete mode 100644 cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu delete mode 100644 cpp/src/distance/specializations/fused_l2_nn_double_int.cu delete mode 100644 cpp/src/distance/specializations/fused_l2_nn_double_int64.cu delete mode 100644 cpp/src/distance/specializations/fused_l2_nn_float_int.cu delete mode 100644 cpp/src/distance/specializations/fused_l2_nn_float_int64.cu create mode 100644 cpp/src/linalg/detail/coalesced_reduction.cu create mode 100644 cpp/src/matrix/detail/select_k_double_int64_t.cu create mode 100644 cpp/src/matrix/detail/select_k_double_uint32_t.cu create mode 100644 cpp/src/matrix/detail/select_k_float_int64_t.cu create mode 100644 cpp/src/matrix/detail/select_k_float_uint32_t.cu create mode 100644 cpp/src/matrix/detail/select_k_half_int64_t.cu create mode 100644 cpp/src/matrix/detail/select_k_half_uint32_t.cu delete mode 100644 cpp/src/matrix/specializations/detail/select_k_float_int64_t.cu delete mode 100644 cpp/src/matrix/specializations/detail/select_k_float_uint32_t.cu delete mode 100644 cpp/src/matrix/specializations/detail/select_k_half_int64_t.cu delete mode 100644 cpp/src/matrix/specializations/detail/select_k_half_uint32_t.cu create mode 100644 cpp/src/neighbors/ball_cover.cu create mode 100644 cpp/src/neighbors/brute_force_00_generate.py create mode 100644 cpp/src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu create mode 100644 cpp/src/neighbors/brute_force_knn_int64_t_float_int64_t.cu create mode 100644 cpp/src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu create mode 100644 cpp/src/neighbors/brute_force_knn_int_float_int.cu create mode 100644 cpp/src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu create mode 100644 cpp/src/neighbors/detail/ivf_flat_search.cu create mode 100644 cpp/src/neighbors/detail/selection_faiss_00_generate.py create mode 100644 cpp/src/neighbors/detail/selection_faiss_int32_t_float.cu create mode 100644 cpp/src/neighbors/detail/selection_faiss_int_double.cu create mode 100644 cpp/src/neighbors/detail/selection_faiss_long_float.cu create mode 100644 cpp/src/neighbors/detail/selection_faiss_size_t_double.cu create mode 100644 cpp/src/neighbors/detail/selection_faiss_size_t_float.cu create mode 100644 cpp/src/neighbors/detail/selection_faiss_uint32_t_float.cu create mode 100644 cpp/src/neighbors/ivf_flat_00_generate.py create mode 100644 cpp/src/neighbors/ivf_flat_build_float_int64_t.cu create mode 100644 cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu create mode 100644 cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivf_flat_search_float_int64_t.cu create mode 100644 cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivfpq_build_float_int64_t.cu create mode 100644 cpp/src/neighbors/ivfpq_build_int8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivfpq_build_uint8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivfpq_extend_float_int64_t.cu create mode 100644 cpp/src/neighbors/ivfpq_extend_int8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivfpq_extend_uint8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivfpq_search_float_int64_t.cu create mode 100644 cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu create mode 100644 cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu create mode 100644 cpp/src/neighbors/refine_00_generate.py create mode 100644 cpp/src/neighbors/refine_float_float.cu create mode 100644 cpp/src/neighbors/refine_int8_t_float.cu create mode 100644 cpp/src/neighbors/refine_uint8_t_float.cu delete mode 100644 cpp/src/neighbors/specializations/ball_cover_all_knn_query.cu delete mode 100644 cpp/src/neighbors/specializations/ball_cover_build_index.cu delete mode 100644 cpp/src/neighbors/specializations/ball_cover_knn_query.cu delete mode 100644 cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu delete mode 100644 cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu delete mode 100644 cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu delete mode 100644 cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu delete mode 100644 cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu delete mode 100644 cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu delete mode 100644 cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu delete mode 100644 cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu delete mode 100644 cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu delete mode 100644 cpp/src/neighbors/specializations/detail/ivfpq_compute_similarity_float_half_no_smem_lut.cu delete mode 100644 cpp/src/neighbors/specializations/fused_l2_knn_int_float_false.cu delete mode 100644 cpp/src/neighbors/specializations/fused_l2_knn_int_float_true.cu delete mode 100644 cpp/src/neighbors/specializations/fused_l2_knn_long_float_false.cu delete mode 100644 cpp/src/neighbors/specializations/fused_l2_knn_long_float_true.cu delete mode 100644 cpp/src/neighbors/specializations/ivfflat_build_float_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfflat_extend_float_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfpq_build_float_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfpq_extend_float_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfpq_search_float_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu delete mode 100644 cpp/src/neighbors/specializations/refine_d_int64_t_float.cu delete mode 100644 cpp/src/neighbors/specializations/refine_d_int64_t_int8_t.cu delete mode 100644 cpp/src/neighbors/specializations/refine_d_int64_t_uint8_t.cu delete mode 100644 cpp/src/neighbors/specializations/refine_h_int64_t_float.cu delete mode 100644 cpp/src/neighbors/specializations/refine_h_int64_t_int8_t.cu delete mode 100644 cpp/src/neighbors/specializations/refine_h_int64_t_uint8_t.cu delete mode 100644 cpp/src/nn/specializations/brute_force_knn_long_float_int.cu delete mode 100644 cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu delete mode 100644 cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu delete mode 100644 cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers.cu create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu create mode 100644 cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu create mode 100644 cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu create mode 100644 cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu create mode 100644 cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 2c2e5370cf..55ad024a56 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -263,141 +263,79 @@ set_target_properties(raft_compiled PROPERTIES EXPORT_NAME compiled) if(RAFT_COMPILE_LIBRARY) add_library( raft_lib - src/neighbors/specializations/refine_d_int64_t_float.cu - src/neighbors/specializations/refine_d_int64_t_int8_t.cu - src/neighbors/specializations/refine_d_int64_t_uint8_t.cu - src/neighbors/specializations/refine_h_int64_t_float.cu - src/neighbors/specializations/refine_h_int64_t_int8_t.cu - src/neighbors/specializations/refine_h_int64_t_uint8_t.cu - src/distance/specializations/detail/canberra_double_double_double_int.cu - src/distance/specializations/detail/canberra_float_float_float_int.cu - src/distance/specializations/detail/correlation_double_double_double_int.cu - src/distance/specializations/detail/correlation_float_float_float_int.cu - src/distance/specializations/detail/cosine_double_double_double_int.cu - src/distance/specializations/detail/cosine_float_float_float_int.cu - src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu - src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu - src/distance/specializations/detail/inner_product_float_float_float_int.cu - src/distance/specializations/detail/inner_product_double_double_double_int.cu - src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu - src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu - src/distance/specializations/detail/kernels/gram_matrix_base_double.cu - src/distance/specializations/detail/kernels/gram_matrix_base_float.cu - src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu - src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu - # These are somehow missing a kernel definition which is causing a compile error. - # src/distance/specializations/detail/kernels/rbf_kernel_double.cu - # src/distance/specializations/detail/kernels/rbf_kernel_float.cu - src/distance/specializations/detail/kernels/tanh_kernel_double.cu - src/distance/specializations/detail/kernels/tanh_kernel_float.cu - src/distance/specializations/detail/kl_divergence_float_float_float_int.cu - src/distance/specializations/detail/kl_divergence_double_double_double_int.cu - src/distance/specializations/detail/l1_float_float_float_int.cu - src/distance/specializations/detail/l1_double_double_double_int.cu - src/distance/specializations/detail/l2_expanded_float_float_float_int.cu - src/distance/specializations/detail/l2_expanded_double_double_double_int.cu - src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/l_inf_double_double_double_int.cu - src/distance/specializations/detail/l_inf_float_float_float_int.cu - src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu - src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu - src/distance/specializations/detail/russel_rao_double_double_double_int.cu - src/distance/specializations/detail/russel_rao_float_float_float_int.cu - src/distance/specializations/fused_l2_nn_double_int.cu - src/distance/specializations/fused_l2_nn_double_int64.cu - src/distance/specializations/fused_l2_nn_float_int.cu - src/distance/specializations/fused_l2_nn_float_int64.cu - src/matrix/specializations/detail/select_k_float_uint32_t.cu - src/matrix/specializations/detail/select_k_float_int64_t.cu - src/matrix/specializations/detail/select_k_half_uint32_t.cu - src/matrix/specializations/detail/select_k_half_int64_t.cu - src/neighbors/specializations/ivfpq_build_float_int64_t.cu - src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_float_int64_t.cu - src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_float_int64_t.cu - src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu - src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu - src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu - src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu - src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu - src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu - src/neighbors/specializations/ball_cover_all_knn_query.cu - src/neighbors/specializations/ball_cover_build_index.cu - src/neighbors/specializations/ball_cover_knn_query.cu - src/neighbors/specializations/fused_l2_knn_long_float_true.cu - src/neighbors/specializations/fused_l2_knn_long_float_false.cu - src/neighbors/specializations/fused_l2_knn_int_float_true.cu - src/neighbors/specializations/fused_l2_knn_int_float_false.cu - src/neighbors/specializations/ivfflat_build_float_int64_t.cu - src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu - src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu - src/neighbors/specializations/ivfflat_extend_float_int64_t.cu - src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu - src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu - src/neighbors/specializations/ivfflat_search_float_int64_t.cu - src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu - src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_float_int64_t.cu - src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_float_int64_t.cu - src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_float_int64_t.cu - src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu - src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu - src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu - src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu - src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu + src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu + src/distance/detail/pairwise_matrix/dispatch_rbf.cu + src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu + src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu + src/distance/distance.cu + src/distance/fused_l2_nn.cu + src/linalg/detail/coalesced_reduction.cu + src/matrix/detail/select_k_double_int64_t.cu + src/matrix/detail/select_k_double_uint32_t.cu + src/matrix/detail/select_k_float_int64_t.cu + src/matrix/detail/select_k_float_uint32_t.cu + src/matrix/detail/select_k_half_int64_t.cu + src/matrix/detail/select_k_half_uint32_t.cu + src/neighbors/ball_cover.cu + src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu + src/neighbors/brute_force_knn_int64_t_float_int64_t.cu + src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu + src/neighbors/brute_force_knn_int_float_int.cu + src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu + src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu + src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu + src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu + src/neighbors/detail/ivf_flat_search.cu + src/neighbors/detail/selection_faiss_int32_t_float.cu + src/neighbors/detail/selection_faiss_int_double.cu + src/neighbors/detail/selection_faiss_long_float.cu + src/neighbors/detail/selection_faiss_size_t_double.cu + src/neighbors/detail/selection_faiss_size_t_float.cu + src/neighbors/detail/selection_faiss_uint32_t_float.cu + src/neighbors/ivf_flat_build_float_int64_t.cu + src/neighbors/ivf_flat_build_int8_t_int64_t.cu + src/neighbors/ivf_flat_build_uint8_t_int64_t.cu + src/neighbors/ivf_flat_extend_float_int64_t.cu + src/neighbors/ivf_flat_extend_int8_t_int64_t.cu + src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu + src/neighbors/ivf_flat_search_float_int64_t.cu + src/neighbors/ivf_flat_search_int8_t_int64_t.cu + src/neighbors/ivf_flat_search_uint8_t_int64_t.cu + src/neighbors/ivfpq_build_float_int64_t.cu + src/neighbors/ivfpq_build_int8_t_int64_t.cu + src/neighbors/ivfpq_build_uint8_t_int64_t.cu + src/neighbors/ivfpq_extend_float_int64_t.cu + src/neighbors/ivfpq_extend_int8_t_int64_t.cu + src/neighbors/ivfpq_extend_uint8_t_int64_t.cu + src/neighbors/ivfpq_search_float_int64_t.cu + src/neighbors/ivfpq_search_int8_t_int64_t.cu + src/neighbors/ivfpq_search_uint8_t_int64_t.cu + src/neighbors/refine_float_float.cu + src/neighbors/refine_int8_t_float.cu + src/neighbors/refine_uint8_t_float.cu src/raft_runtime/cluster/cluster_cost.cuh src/raft_runtime/cluster/cluster_cost_double.cu src/raft_runtime/cluster/cluster_cost_float.cu @@ -430,6 +368,21 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/random/rmat_rectangular_generator_int64_float.cu src/raft_runtime/random/rmat_rectangular_generator_int_double.cu src/raft_runtime/random/rmat_rectangular_generator_int_float.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu + src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu + src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu + src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu + src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu ) set_target_properties( raft_lib diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh new file mode 100644 index 0000000000..e1dc6f9b37 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-ext.cuh @@ -0,0 +1,194 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // raft::identity_op +#include // ops::* +#include // ops::has_cutlass_op +#include // rbf_fin_op +#include // pairwise_matrix_params +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::distance::detail { + +template +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) RAFT_EXPLICIT; + +}; // namespace raft::distance::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + extern template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +/* + * Hierarchy of instantiations: + * + * This file defines extern template instantiations of the distance kernels. The + * instantiation of the public API is handled in raft/distance/distance-ext.cuh. + * + * After adding an instance here, make sure to also add the instance there. + */ + +// The following two instances are used in the RBF kernel object. Note the use of int64_t for the +// index type. +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +// Rest of instances +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + float, + float, + float, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + float, + float, + float, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, double, double, double, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, float, float, float, raft::identity_op, int); +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh new file mode 100644 index 0000000000..31aebed3d0 --- /dev/null +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "dispatch-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "dispatch-ext.cuh" +#endif diff --git a/cpp/include/raft/distance/distance-ext.cuh b/cpp/include/raft/distance/distance-ext.cuh new file mode 100644 index 0000000000..7171ba605f --- /dev/null +++ b/cpp/include/raft/distance/distance-ext.cuh @@ -0,0 +1,1065 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // raft::device_matrix_view +#include // raft::identity_op +#include // raft::resources +#include // rbf_fin_op +#include // raft::distance::DistanceType +#include // RAFT_EXPLICIT +#include // rmm::device_uvector + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft { +namespace distance { + +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + FinalLambda fin_op, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + size_t worksize, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +size_t getWorkspaceSize(const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) RAFT_EXPLICIT; + +template +size_t getWorkspaceSize(raft::device_matrix_view const& x, + raft::device_matrix_view const& y) RAFT_EXPLICIT; + +template +void distance(raft::resources const& handle, + const DataT* x, + const DataT* y, + OutT* dist, + IdxT m, + IdxT n, + IdxT k, + bool isRowMajor = true, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + rmm::device_uvector& workspace, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void pairwise_distance(raft::resources const& handle, + const Type* x, + const Type* y, + Type* dist, + IdxT m, + IdxT n, + IdxT k, + raft::distance::DistanceType metric, + bool isRowMajor = true, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void distance(raft::resources const& handle, + raft::device_matrix_view const x, + raft::device_matrix_view const y, + raft::device_matrix_view dist, + DataT metric_arg = 2.0f) RAFT_EXPLICIT; + +template +void pairwise_distance(raft::resources const& handle, + device_matrix_view const x, + device_matrix_view const y, + device_matrix_view dist, + raft::distance::DistanceType metric, + Type metric_arg = 2.0f) RAFT_EXPLICIT; + +}; // namespace distance +}; // namespace raft + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +/* + * Hierarchy of instantiations: + * + * This file defines the extern template instantiations for the public API of + * raft::distance. To improve compile times, the extern template instantiation + * of the distance kernels is handled in + * distance/detail/pairwise_matrix/dispatch-ext.cuh. + * + * After adding an instance here, make sure to also add the instance to + * dispatch-ext.cuh and the corresponding .cu files. + */ + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +// The following two instances are used in test/distance/gram.cu. Note the use +// of int64_t for the index type. +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_distance + +// Same, but without raft::identity_op +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +// Same, but without workspace +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + extern template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT, layout) \ + extern template size_t raft::distance::getWorkspaceSize( \ + raft::device_matrix_view const& x, \ + raft::device_matrix_view const& y) + +// We could consider not taking template parameters for this function. The +// number of instantiations seems a bit excessive.. +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_f_contiguous); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + extern template void raft::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + rmm::device_uvector& workspace, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Same, but without workspace +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + extern template void raft::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Version with mdspan +#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + DataT metric_arg) + +// Again, we might want to consider reigning in the number of instantiations... +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \ + extern template void raft::distance::pairwise_distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + raft::distance::DistanceType metric, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(float, raft::layout_f_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_f_contiguous, int); + +#undef instantiate_raft_distance_pairwise_distance diff --git a/cpp/include/raft/distance/distance.cuh b/cpp/include/raft/distance/distance.cuh new file mode 100644 index 0000000000..7d5cc5d486 --- /dev/null +++ b/cpp/include/raft/distance/distance.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "distance-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "distance-ext.cuh" +#endif diff --git a/cpp/include/raft/distance/fused_l2_nn-ext.cuh b/cpp/include/raft/distance/fused_l2_nn-ext.cuh new file mode 100644 index 0000000000..a0af04c4e8 --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn-ext.cuh @@ -0,0 +1,88 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t +#include // raft::device_resources +#include // raft::KeyValuePair +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft { +namespace distance { + +template +void initialize(raft::device_resources const& handle, + OutT* min, + IdxT m, + DataT maxVal, + ReduceOpT redOp) RAFT_EXPLICIT; + +template +void fusedL2NNMinReduce(OutT* min, + const DataT* x, + const DataT* y, + const DataT* xn, + const DataT* yn, + IdxT m, + IdxT n, + IdxT k, + void* workspace, + bool sqrt, + bool initOutBuffer, + cudaStream_t stream) RAFT_EXPLICIT; + +} // namespace distance +} // namespace raft + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_distance_fusedL2NNMinReduce(DataT, OutT, IdxT) \ + extern template void raft::distance::fusedL2NNMinReduce(OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedL2NNMinReduce(double, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, + raft::KeyValuePair, + int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedL2NNMinReduce diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh new file mode 100644 index 0000000000..737d3fcb08 --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "fused_l2_nn-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "fused_l2_nn-ext.cuh" +#endif diff --git a/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh b/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh new file mode 100644 index 0000000000..4800f2e3cf --- /dev/null +++ b/cpp/include/raft/linalg/detail/coalesced_reduction-ext.cuh @@ -0,0 +1,73 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +// The explicit instantiation of raft::linalg::detail::coalescedReduction is not +// forced because there would be too many instances. Instead, we cover the most +// common instantiations with extern template instantiations below. + +#define instantiate_raft_linalg_detail_coalescedReduction( \ + InType, OutType, IdxType, MainLambda, ReduceLambda, FinalLambda) \ + extern template void raft::linalg::detail::coalescedReduction(OutType* dots, \ + const InType* data, \ + IdxType D, \ + IdxType N, \ + OutType init, \ + cudaStream_t stream, \ + bool inplace, \ + MainLambda main_op, \ + ReduceLambda reduce_op, \ + FinalLambda final_op) + +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::identity_op, raft::min_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::abs_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::abs_op, raft::max_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::abs_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::abs_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::identity_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::identity_op, raft::min_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, long, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::identity_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::abs_op, raft::max_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, unsigned int, raft::sq_op, raft::add_op, raft::identity_op); + +#undef instantiate_raft_linalg_detail_coalescedReduction diff --git a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh similarity index 67% rename from cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu rename to cpp/include/raft/linalg/detail/coalesced_reduction.cuh index 6609de69ac..3e6b17978b 100644 --- a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_float_int.cu +++ b/cpp/include/raft/linalg/detail/coalesced_reduction.cuh @@ -14,7 +14,13 @@ * limitations under the License. */ -#include -#include +#pragma once -template class raft::distance::kernels::detail::PolynomialKernel; \ No newline at end of file +// Always include inline definitions of coalesced reduction, because we do not +// force explicit instantion. +#include "coalesced_reduction-inl.cuh" + +// Do include the extern template instantiations when possible. +#ifdef RAFT_COMPILED +#include "coalesced_reduction-ext.cuh" +#endif diff --git a/cpp/include/raft/matrix/detail/select_k-ext.cuh b/cpp/include/raft/matrix/detail/select_k-ext.cuh new file mode 100644 index 0000000000..2b233c156d --- /dev/null +++ b/cpp/include/raft/matrix/detail/select_k-ext.cuh @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // uint32_t +#include // __half +#include // RAFT_EXPLICIT +#include // rmm:cuda_stream_view +#include // rmm::mr::device_memory_resource + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::matrix::detail { + +template +void select_k(const T* in_val, + const IdxT* in_idx, + size_t batch_size, + size_t len, + int k, + T* out_val, + IdxT* out_idx, + bool select_min, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; +} // namespace raft::matrix::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + extern template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(__half, uint32_t); +instantiate_raft_matrix_detail_select_k(__half, int64_t); +instantiate_raft_matrix_detail_select_k(float, int64_t); +instantiate_raft_matrix_detail_select_k(float, uint32_t); +// We did not have these two for double before, but there are tests for them. We +// therefore include them here. +instantiate_raft_matrix_detail_select_k(double, int64_t); +instantiate_raft_matrix_detail_select_k(double, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_float.cu b/cpp/include/raft/matrix/detail/select_k.cuh similarity index 78% rename from cpp/src/distance/specializations/detail/kernels/rbf_kernel_float.cu rename to cpp/include/raft/matrix/detail/select_k.cuh index 423613dcd1..d011f23534 100644 --- a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_float.cu +++ b/cpp/include/raft/matrix/detail/select_k.cuh @@ -14,7 +14,12 @@ * limitations under the License. */ -#include -#include +#pragma once -template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "select_k-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "select_k-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/ball_cover-ext.cuh b/cpp/include/raft/neighbors/ball_cover-ext.cuh new file mode 100644 index 0000000000..b6ab12d8e1 --- /dev/null +++ b/cpp/include/raft/neighbors/ball_cover-ext.cuh @@ -0,0 +1,124 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // uint32_t +#include // raft::distance::DistanceType +#include // BallCoverIndex +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ball_cover { + +template +void build_index(raft::device_resources const& handle, + BallCoverIndex& index) RAFT_EXPLICIT; + +template +void all_knn_query(raft::device_resources const& handle, + BallCoverIndex& index, + int_t k, + idx_t* inds, + value_t* dists, + bool perform_post_filtering = true, + float weight = 1.0) RAFT_EXPLICIT; + +template +void all_knn_query(raft::device_resources const& handle, + BallCoverIndex& index, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k, + bool perform_post_filtering = true, + float weight = 1.0) RAFT_EXPLICIT; + +template +void knn_query(raft::device_resources const& handle, + const BallCoverIndex& index, + int_t k, + const value_t* query, + int_t n_query_pts, + idx_t* inds, + value_t* dists, + bool perform_post_filtering = true, + float weight = 1.0) RAFT_EXPLICIT; + +template +void knn_query(raft::device_resources const& handle, + const BallCoverIndex& index, + raft::device_matrix_view query, + raft::device_matrix_view inds, + raft::device_matrix_view dists, + int_t k, + bool perform_post_filtering = true, + float weight = 1.0) RAFT_EXPLICIT; + +} // namespace raft::neighbors::ball_cover + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ball_cover(idx_t, value_t, int_t, matrix_idx_t) \ + extern template void \ + raft::neighbors::ball_cover::build_index( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index); \ + \ + extern template void \ + raft::neighbors::ball_cover::all_knn_query( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index, \ + int_t k, \ + idx_t* inds, \ + value_t* dists, \ + bool perform_post_filtering, \ + float weight); \ + \ + extern template void \ + raft::neighbors::ball_cover::all_knn_query( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view inds, \ + raft::device_matrix_view dists, \ + int_t k, \ + bool perform_post_filtering, \ + float weight); \ + \ + extern template void raft::neighbors::ball_cover::knn_query( \ + raft::device_resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + int_t k, \ + const value_t* query, \ + int_t n_query_pts, \ + idx_t* inds, \ + value_t* dists, \ + bool perform_post_filtering, \ + float weight); \ + \ + extern template void \ + raft::neighbors::ball_cover::knn_query( \ + raft::device_resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view query, \ + raft::device_matrix_view inds, \ + raft::device_matrix_view dists, \ + int_t k, \ + bool perform_post_filtering, \ + float weight); + +instantiate_raft_neighbors_ball_cover(int64_t, float, uint32_t, uint32_t); + +#undef instantiate_raft_neighbors_ball_cover diff --git a/cpp/include/raft/neighbors/ball_cover.cuh b/cpp/include/raft/neighbors/ball_cover.cuh new file mode 100644 index 0000000000..82c56b64dd --- /dev/null +++ b/cpp/include/raft/neighbors/ball_cover.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "ball_cover-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "ball_cover-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/brute_force-ext.cuh b/cpp/include/raft/neighbors/brute_force-ext.cuh new file mode 100644 index 0000000000..98a186db86 --- /dev/null +++ b/cpp/include/raft/neighbors/brute_force-ext.cuh @@ -0,0 +1,109 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::device_matrix_view +#include // raft::device_resources +#include // raft::identity_op +#include // raft::distance::DistanceType +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::brute_force { + +template +inline void knn_merge_parts( + raft::device_resources const& handle, + raft::device_matrix_view in_keys, + raft::device_matrix_view in_values, + raft::device_matrix_view out_keys, + raft::device_matrix_view out_values, + size_t n_samples, + std::optional> translations = std::nullopt) RAFT_EXPLICIT; + +template +void knn(raft::device_resources const& handle, + std::vector> index, + raft::device_matrix_view search, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + distance::DistanceType metric = distance::DistanceType::L2Unexpanded, + std::optional metric_arg = std::make_optional(2.0f), + std::optional global_id_offset = std::nullopt, + epilogue_op distance_epilogue = raft::identity_op()) RAFT_EXPLICIT; + +template +void fused_l2_knn(raft::device_resources const& handle, + raft::device_matrix_view index, + raft::device_matrix_view query, + raft::device_matrix_view out_inds, + raft::device_matrix_view out_dists, + raft::distance::DistanceType metric) RAFT_EXPLICIT; + +} // namespace raft::neighbors::brute_force + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +// No extern template for raft::neighbors::brute_force::knn_merge_parts + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + extern template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + int64_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); +instantiate_raft_neighbors_brute_force_knn( + int64_t, float, int64_t, raft::row_major, raft::row_major, raft::identity_op); +instantiate_raft_neighbors_brute_force_knn( + int, float, int, raft::row_major, raft::row_major, raft::identity_op); +instantiate_raft_neighbors_brute_force_knn( + uint32_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn + +#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ + value_t, idx_t, idx_layout, query_layout) \ + extern template void raft::neighbors::brute_force::fused_l2_knn( \ + raft::device_resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view query, \ + raft::device_matrix_view out_inds, \ + raft::device_matrix_view out_dists, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_brute_force_fused_l2_knn(float, + int64_t, + raft::row_major, + raft::row_major) + +#undef instantiate_raft_neighbors_brute_force_fused_l2_knn diff --git a/cpp/include/raft/neighbors/brute_force.cuh b/cpp/include/raft/neighbors/brute_force.cuh new file mode 100644 index 0000000000..8453a83df4 --- /dev/null +++ b/cpp/include/raft/neighbors/brute_force.cuh @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "brute_force-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "brute_force-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh new file mode 100644 index 0000000000..3bb3a4308d --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // uintX_t +#include // raft::neighbors::ivf_flat::index +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ivf_flat::detail { + +template +void search(raft::device_resources const& handle, + const search_params& params, + const raft::neighbors::ivf_flat::index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr); + +} // namespace raft::neighbors::ivf_flat::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT) \ + extern template void raft::neighbors::ivf_flat::detail::search( \ + raft::device_resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_float.cu b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh similarity index 78% rename from cpp/src/distance/specializations/detail/kernels/gram_matrix_base_float.cu rename to cpp/include/raft/neighbors/detail/ivf_flat_search.cuh index d777e73dc9..acf9d2c99d 100644 --- a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_float.cu +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search.cuh @@ -14,7 +14,12 @@ * limitations under the License. */ -#include -#include +#pragma once -template class raft::distance::kernels::detail::GramMatrixBase; \ No newline at end of file +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "ivf_flat_search-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "ivf_flat_search-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh b/cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh new file mode 100644 index 0000000000..62b2b25261 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // size_t +#include // uint32_t +#include // RAFT_EXPLICIT + +#if defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) + +namespace raft::neighbors::detail { + +template +void select_k(const key_t* inK, + const payload_t* inV, + size_t n_rows, + size_t n_cols, + key_t* outK, + payload_t* outV, + bool select_min, + int k, + cudaStream_t stream) RAFT_EXPLICIT; +}; // namespace raft::neighbors::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + extern template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(uint32_t, float); +instantiate_raft_neighbors_detail_select_k(int32_t, float); +instantiate_raft_neighbors_detail_select_k(long, float); +instantiate_raft_neighbors_detail_select_k(size_t, double); +// test/neighbors/selection.cu +instantiate_raft_neighbors_detail_select_k(int, double); +instantiate_raft_neighbors_detail_select_k(size_t, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/include/raft/neighbors/detail/selection_faiss.cuh b/cpp/include/raft/neighbors/detail/selection_faiss.cuh new file mode 100644 index 0000000000..06b4478010 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/selection_faiss.cuh @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "selection_faiss-inl.cuh" +#endif + +#if defined(RAFT_COMPILED) +#include "selection_faiss-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/ivf_flat-ext.cuh b/cpp/include/raft/neighbors/ivf_flat-ext.cuh new file mode 100644 index 0000000000..60edf8a068 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_flat-ext.cuh @@ -0,0 +1,185 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t + +#include // raft::device_matrix_view +#include // raft::device_resources +#include +#include // raft::neighbors::ivf_flat::index +#include // RAFT_EXPLICIT +#include // rmm::mr::device_memory_resource + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ivf_flat { + +template +auto build(raft::device_resources const& handle, + const index_params& params, + const T* dataset, + IdxT n_rows, + uint32_t dim) -> index RAFT_EXPLICIT; + +template +auto build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) + -> index RAFT_EXPLICIT; + +template +void build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset, + raft::neighbors::ivf_flat::index& idx) RAFT_EXPLICIT; + +template +auto extend(raft::device_resources const& handle, + const index& orig_index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index RAFT_EXPLICIT; + +template +auto extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& orig_index) -> index RAFT_EXPLICIT; + +template +void extend(raft::device_resources const& handle, + index* index, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) RAFT_EXPLICIT; + +template +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* index) RAFT_EXPLICIT; + +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; + +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) RAFT_EXPLICIT; + +} // namespace raft::neighbors::ivf_flat + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ + extern template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_flat::index; \ + \ + extern template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + extern template void raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); + +instantiate_raft_neighbors_ivf_flat_build(float, int64_t); +instantiate_raft_neighbors_ivf_flat_build(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t); +#undef instantiate_raft_neighbors_ivf_flat_build + +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ + extern template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_flat::index; \ + \ + extern template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + extern template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_flat::index* index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ + \ + extern template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); + +instantiate_raft_neighbors_ivf_flat_extend(float, int64_t); +instantiate_raft_neighbors_ivf_flat_extend(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_extend(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_extend + +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ + extern template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + extern template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); + +instantiate_raft_neighbors_ivf_flat_search(float, int64_t); +instantiate_raft_neighbors_ivf_flat_search(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_float.cu b/cpp/include/raft/neighbors/ivf_flat.cuh similarity index 78% rename from cpp/src/distance/specializations/detail/kernels/tanh_kernel_float.cu rename to cpp/include/raft/neighbors/ivf_flat.cuh index f7825e577a..4906ddab60 100644 --- a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_float.cu +++ b/cpp/include/raft/neighbors/ivf_flat.cuh @@ -14,7 +14,12 @@ * limitations under the License. */ -#include -#include +#pragma once -template class raft::distance::kernels::detail::TanhKernel; \ No newline at end of file +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "ivf_flat-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "ivf_flat-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/ivf_pq-ext.cuh b/cpp/include/raft/neighbors/ivf_pq-ext.cuh new file mode 100644 index 0000000000..60588966d8 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_pq-ext.cuh @@ -0,0 +1,170 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t + +#include // raft::device_matrix_view +#include // raft::device_resources +#include // raft::neighbors::ivf_pq::index +#include // RAFT_EXPLICIT +#include // rmm::mr::device_memory_resource + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ivf_pq { + +template +index build(raft::device_resources const& handle, + const index_params& params, + raft::device_matrix_view dataset) RAFT_EXPLICIT; + +template +index extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& idx) RAFT_EXPLICIT; + +template +void extend(raft::device_resources const& handle, + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* idx) RAFT_EXPLICIT; + +template +void search(raft::device_resources const& handle, + const search_params& params, + const index& idx, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) RAFT_EXPLICIT; + +template +auto build(raft::device_resources const& handle, + const index_params& params, + const T* dataset, + IdxT n_rows, + uint32_t dim) -> index RAFT_EXPLICIT; + +template +auto extend(raft::device_resources const& handle, + const index& idx, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) -> index RAFT_EXPLICIT; + +template +void extend(raft::device_resources const& handle, + index* idx, + const T* new_vectors, + const IdxT* new_indices, + IdxT n_rows) RAFT_EXPLICIT; + +template +void search(raft::device_resources const& handle, + const raft::neighbors::ivf_pq::search_params& params, + const index& idx, + const T* queries, + uint32_t n_queries, + uint32_t k, + IdxT* neighbors, + float* distances, + rmm::mr::device_memory_resource* mr = nullptr) RAFT_EXPLICIT; + +} // namespace raft::neighbors::ivf_pq + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + extern template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + extern template auto raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(float, int64_t); +instantiate_raft_neighbors_ivf_pq_build(int8_t, int64_t); +instantiate_raft_neighbors_ivf_pq_build(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_build + +#define instantiate_raft_neighbors_ivf_pq_extend(T, IdxT) \ + extern template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + extern template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + extern template auto raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + extern template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); + +instantiate_raft_neighbors_ivf_pq_extend(float, int64_t); +instantiate_raft_neighbors_ivf_pq_extend(int8_t, int64_t); +instantiate_raft_neighbors_ivf_pq_extend(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_extend + +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + extern template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_pq_search(float, int64_t); +instantiate_raft_neighbors_ivf_pq_search(int8_t, int64_t); +instantiate_raft_neighbors_ivf_pq_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_search diff --git a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_double.cu b/cpp/include/raft/neighbors/ivf_pq.cuh similarity index 78% rename from cpp/src/distance/specializations/detail/kernels/rbf_kernel_double.cu rename to cpp/include/raft/neighbors/ivf_pq.cuh index 7ea4b60e09..055d159b94 100644 --- a/cpp/src/distance/specializations/detail/kernels/rbf_kernel_double.cu +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -14,7 +14,12 @@ * limitations under the License. */ -#include -#include +#pragma once -template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "ivf_pq-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "ivf_pq-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/refine-ext.cuh b/cpp/include/raft/neighbors/refine-ext.cuh new file mode 100644 index 0000000000..edd14f1770 --- /dev/null +++ b/cpp/include/raft/neighbors/refine-ext.cuh @@ -0,0 +1,78 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // int64_t + +#include // raft::device_matrix_view +#include // raft::device_resources +#include // // raft::host_matrix_view +#include // raft::distance::DistanceType +#include // RAFT_EXPLICIT + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors { + +template +void refine(raft::device_resources const& handle, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, + raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded) + RAFT_EXPLICIT; + +template +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + raft::distance::DistanceType metric = distance::DistanceType::L2Unexpanded) + RAFT_EXPLICIT; + +} // namespace raft::neighbors + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ + extern template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric); \ + \ + extern template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); +instantiate_raft_neighbors_refine(int64_t, int8_t, float, int64_t); +instantiate_raft_neighbors_refine(int64_t, uint8_t, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/include/raft/neighbors/refine.cuh b/cpp/include/raft/neighbors/refine.cuh new file mode 100644 index 0000000000..7fe190493f --- /dev/null +++ b/cpp/include/raft/neighbors/refine.cuh @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "refine-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "refine-ext.cuh" +#endif diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh new file mode 100644 index 0000000000..199da01ddb --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers-ext.cuh @@ -0,0 +1,129 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "../../ball_cover_types.hpp" // BallCoverIndex +#include "registers_types.cuh" // DistFunc +#include // uint32_t +#include //RAFT_EXPLICIT + +#if defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) + +namespace raft::spatial::knn::detail { + +template +void rbc_low_dim_pass_one(raft::device_resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_int k, + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + dist_func& dfunc, + value_idx* inds, + value_t* dists, + float weight, + value_int* dists_counter) RAFT_EXPLICIT; + +template +void rbc_low_dim_pass_two(raft::device_resources const& handle, + const BallCoverIndex& index, + const value_t* query, + const value_int n_query_rows, + value_int k, + const value_idx* R_knn_inds, + const value_t* R_knn_dists, + dist_func& dfunc, + value_idx* inds, + value_t* dists, + float weight, + value_int* post_dists_counter) RAFT_EXPLICIT; + +}; // namespace raft::spatial::knn::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + extern template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + extern template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); + +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh new file mode 100644 index 0000000000..b60cd645b4 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ball_cover/registers.cuh @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "registers-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "registers-ext.cuh" +#endif diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh new file mode 100644 index 0000000000..390436939f --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn-ext.cuh @@ -0,0 +1,70 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // size_t +#include // uint32_t +#include // DistanceType +#include // RAFT_EXPLICIT + +#if defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) + +namespace raft::spatial::knn::detail { + +template +void fusedL2Knn(size_t D, + value_idx* out_inds, + value_t* out_dists, + const value_t* index, + const value_t* query, + size_t n_index_rows, + size_t n_query_rows, + int k, + bool rowMajorIndex, + bool rowMajorQuery, + cudaStream_t stream, + raft::distance::DistanceType metric) RAFT_EXPLICIT; + +} // namespace raft::spatial::knn::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_spatial_knn_detail_fusedL2Knn(Mvalue_idx, Mvalue_t, MusePrevTopKs) \ + extern template void \ + raft::spatial::knn::detail::fusedL2Knn( \ + size_t D, \ + Mvalue_idx * out_inds, \ + Mvalue_t * out_dists, \ + const Mvalue_t* index, \ + const Mvalue_t* query, \ + size_t n_index_rows, \ + size_t n_query_rows, \ + int k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + cudaStream_t stream, \ + raft::distance::DistanceType metric) + +instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, false); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, false); + +// These are used by brute_force_knn: +instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, false); + +#undef instantiate_raft_spatial_knn_detail_fusedL2Knn diff --git a/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh new file mode 100644 index 0000000000..38dd2f332f --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "fused_l2_knn-inl.cuh" +#endif + +#if defined(RAFT_COMPILED) +#include "fused_l2_knn-ext.cuh" +#endif diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py new file mode 100644 index 0000000000..97fe120458 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_00_generate.py @@ -0,0 +1,194 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# NOTE: this template is not perfectly formatted. Use pre-commit to get +# everything in shape again. +header = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +""" + + +macro = """ +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \\ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \\ + template void raft::distance::detail:: \\ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \\ + OpT distance_op, \\ + IdxT m, \\ + IdxT n, \\ + IdxT k, \\ + const DataT* x, \\ + const DataT* y, \\ + const DataT* x_norm, \\ + const DataT* y_norm, \\ + OutT* out, \\ + FinOpT fin_op, \\ + cudaStream_t stream, \\ + bool is_row_major) +""" + +data_type_instances = [ + dict( + DataT="float", + AccT="float", + OutT="float", + IdxT="int", + ), + dict( + DataT="double", + AccT="double", + OutT="double", + IdxT="int", + ), +] + +op_instances = [ + dict( + path_prefix="canberra", + OpT="raft::distance::detail::ops::canberra_distance_op", + archs = [60], + ), + dict( + path_prefix="correlation", + OpT="raft::distance::detail::ops::correlation_distance_op", + archs = [60], + ), + dict( + path_prefix="cosine", + OpT="raft::distance::detail::ops::cosine_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="hamming_unexpanded", + OpT="raft::distance::detail::ops::hamming_distance_op", + archs = [60], + ), + dict( + path_prefix="hellinger_expanded", + OpT="raft::distance::detail::ops::hellinger_distance_op", + archs = [60], + ), + # inner product is handled by cublas. + dict( + path_prefix="jensen_shannon", + OpT="raft::distance::detail::ops::jensen_shannon_distance_op", + archs = [60], + ), + dict( + path_prefix="kl_divergence", + OpT="raft::distance::detail::ops::kl_divergence_op", + archs = [60], + ), + dict( + path_prefix="l1", + OpT="raft::distance::detail::ops::l1_distance_op", + archs = [60], + ), + dict( + path_prefix="l2_expanded", + OpT="raft::distance::detail::ops::l2_exp_distance_op", + archs = [60, 80], + ), + dict( + path_prefix="l2_unexpanded", + OpT="raft::distance::detail::ops::l2_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="l_inf", + OpT="raft::distance::detail::ops::l_inf_distance_op", + archs = [60], + ), + dict( + path_prefix="lp_unexpanded", + OpT="raft::distance::detail::ops::lp_unexp_distance_op", + archs = [60], + ), + dict( + path_prefix="russel_rao", + OpT="raft::distance::detail::ops::russel_rao_distance_op", + archs = [60], + ), +] + +def arch_headers(archs): + include_headers ="\n".join([ + f"#include " + for arch in archs + ]) + return include_headers + + + +for op in op_instances: + for dt in data_type_instances: + DataT, AccT, OutT, IdxT = (dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"]); + path = f"dispatch_{op['path_prefix']}_{DataT}_{AccT}_{OutT}_{IdxT}.cu" + with open(path, "w") as f: + f.write(header) + f.write(arch_headers(op["archs"])) + f.write(macro) + + OpT = op['OpT'] + FinOpT = "raft::identity_op" + f.write(f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n") + f.write("\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n") + print(f"src/distance/detail/pairwise_matrix/{path}") + +# Dispatch kernels for with the RBF fin op. +with open("dispatch_rbf.cu", "w") as f: + OpT="raft::distance::detail::ops::l2_unexp_distance_op" + archs = [60] + + f.write(header) + f.write("#include // rbf_fin_op\n") + f.write(arch_headers(archs)) + f.write(macro) + + for dt in data_type_instances: + DataT, AccT, OutT, IdxT = (dt[k] for k in ["DataT", "AccT", "OutT", "IdxT"]); + IdxT = "int64_t" # overwrite IdxT + + FinOpT = f"raft::distance::kernels::detail::rbf_fin_op<{DataT}>" + f.write(f"\ninstantiate_raft_distance_detail_pairwise_matrix_dispatch({OpT}, {DataT}, {AccT}, {OutT}, {FinOpT}, {IdxT});\n") + + f.write("\n#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch\n") + +print("src/distance/detail/pairwise_matrix/dispatch_rbf.cu") diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu new file mode 100644 index 0000000000..41db12e9ae --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu new file mode 100644 index 0000000000..f038e53381 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::canberra_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu new file mode 100644 index 0000000000..52e4cc02d8 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu new file mode 100644 index 0000000000..c9481d6c22 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_correlation_float_float_float_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::correlation_distance_op, + float, + float, + float, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu new file mode 100644 index 0000000000..517858125b --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_double_double_double_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu new file mode 100644 index 0000000000..62f1d9874b --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_cosine_float_float_float_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::cosine_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu new file mode 100644 index 0000000000..500f7b4a9c --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu new file mode 100644 index 0000000000..3be7586b43 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hamming_unexpanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hamming_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu new file mode 100644 index 0000000000..023134ddff --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu new file mode 100644 index 0000000000..e438f121f2 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_hellinger_expanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::hellinger_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu new file mode 100644 index 0000000000..31c5003ad6 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu new file mode 100644 index 0000000000..e78c1c320a --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_jensen_shannon_float_float_float_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::jensen_shannon_distance_op, + float, + float, + float, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu new file mode 100644 index 0000000000..5b95df9614 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu new file mode 100644 index 0000000000..fb72c91b73 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_kl_divergence_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::kl_divergence_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu new file mode 100644 index 0000000000..cac5acad92 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu new file mode 100644 index 0000000000..78aa097961 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l1_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l1_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu new file mode 100644 index 0000000000..c8d922f6fa --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_double_double_double_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu new file mode 100644 index 0000000000..20cf57f898 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_expanded_float_float_float_int.cu @@ -0,0 +1,51 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_exp_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu new file mode 100644 index 0000000000..eadd0d2c2b --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu new file mode 100644 index 0000000000..e4b5dd3a86 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l2_unexpanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu new file mode 100644 index 0000000000..45d021bce9 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_double_double_double_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu new file mode 100644 index 0000000000..ba48e52a18 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_l_inf_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l_inf_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu new file mode 100644 index 0000000000..ffa58793d9 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu new file mode 100644 index 0000000000..915c68f05f --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_lp_unexpanded_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::lp_unexp_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu new file mode 100644 index 0000000000..15855cea0a --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_rbf.cu @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // rbf_fin_op +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::l2_unexp_distance_op, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu new file mode 100644 index 0000000000..db45dc8b94 --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_double_double_double_int.cu @@ -0,0 +1,55 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, + double, + double, + double, + raft::identity_op, + int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu new file mode 100644 index 0000000000..a2a5a9fafe --- /dev/null +++ b/cpp/src/distance/detail/pairwise_matrix/dispatch_russel_rao_float_float_float_int.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by dispatch_00_generate.py + * + * Make changes there and run in this directory: + * + * > python dispatch_00_generate.py + * + */ + +#include // raft::identity_op +#include // ops::* +#include // dispatch +#include +#define instantiate_raft_distance_detail_pairwise_matrix_dispatch( \ + OpT, DataT, AccT, OutT, FinOpT, IdxT) \ + template void raft::distance::detail:: \ + pairwise_matrix_dispatch, DataT, AccT, OutT, FinOpT, IdxT>( \ + OpT distance_op, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + const DataT* x, \ + const DataT* y, \ + const DataT* x_norm, \ + const DataT* y_norm, \ + OutT* out, \ + FinOpT fin_op, \ + cudaStream_t stream, \ + bool is_row_major) + +instantiate_raft_distance_detail_pairwise_matrix_dispatch( + raft::distance::detail::ops::russel_rao_distance_op, float, float, float, raft::identity_op, int); + +#undef instantiate_raft_distance_detail_pairwise_matrix_dispatch diff --git a/cpp/src/distance/distance.cu b/cpp/src/distance/distance.cu new file mode 100644 index 0000000000..8c94608311 --- /dev/null +++ b/cpp/src/distance/distance.cu @@ -0,0 +1,934 @@ +/* + * Copyright (c) 2018-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // rbf_fin_op +#include + +/* + * Hierarchy of instantiations: + * + * This file defines the template instantiations for the public API of + * raft::distance. To improve compile times, the compilation of the distance + * kernels is handled in distance/detail/pairwise_matrix/dispatch_*.cu. + * + */ + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +// The following two instances are used in test/distance/gram.cu. Note the use +// of int64_t for the index type. +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + float, + float, + float, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::distance::kernels::detail::rbf_fin_op, + int64_t); + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::identity_op, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, raft::identity_op, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, raft::identity_op, int); + +#undef instantiate_raft_distance_distance + +// Same, but without raft::identity_op +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +// Same, but without workspace +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CorrelationExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::CosineExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HammingUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::HellingerExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtExpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2SqrtUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Linf, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::LpUnexpanded, double, double, double, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, float, float, float, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::RusselRaoExpanded, double, double, double, int); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT, layout) \ + template size_t raft::distance::getWorkspaceSize( \ + raft::device_matrix_view const& x, \ + raft::device_matrix_view const& y) + +// We could consider not taking template parameters for this function. The +// number of instantiations seems a bit excessive.. +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::Canberra, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::InnerProduct, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::JensenShannon, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::KLDivergence, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L1, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, float, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, double, int, raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + int, + raft::layout_f_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + int, + raft::layout_c_contiguous); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Unexpanded, float, float, float, int, raft::layout_f_contiguous); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + template void raft::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + rmm::device_uvector& workspace, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Same, but without workspace +#define instantiate_raft_distance_pairwise_distance(DataT, IdxT) \ + template void raft::distance::pairwise_distance(raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + DataT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + raft::distance::DistanceType metric, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, int); +instantiate_raft_distance_pairwise_distance(double, int); + +#undef instantiate_raft_distance_pairwise_distance + +// Version with mdspan +#define instantiate_raft_distance_distance(DistT, DataT, AccT, OutT, layout, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + DataT metric_arg) + +// Again, we might want to consider reigning in the number of instantiations... +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Canberra, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CorrelationExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::CosineExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HammingUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::HellingerExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::InnerProduct, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::InnerProduct, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::JensenShannon, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::JensenShannon, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::KLDivergence, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::KLDivergence, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L1, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Expanded, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2SqrtUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::L2Unexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Unexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::Linf, double, double, double, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_c_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance( + raft::distance::DistanceType::LpUnexpanded, float, float, float, raft::layout_f_contiguous, int); +instantiate_raft_distance_distance(raft::distance::DistanceType::LpUnexpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_c_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + float, + float, + float, + raft::layout_f_contiguous, + int); +instantiate_raft_distance_distance(raft::distance::DistanceType::RusselRaoExpanded, + double, + double, + double, + raft::layout_f_contiguous, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_pairwise_distance(DataT, layout, IdxT) \ + template void raft::distance::pairwise_distance( \ + raft::resources const& handle, \ + raft::device_matrix_view const x, \ + raft::device_matrix_view const y, \ + raft::device_matrix_view dist, \ + raft::distance::DistanceType metric, \ + DataT metric_arg) + +instantiate_raft_distance_pairwise_distance(float, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(float, raft::layout_f_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_c_contiguous, int); +instantiate_raft_distance_pairwise_distance(double, raft::layout_f_contiguous, int); + +#undef instantiate_raft_distance_pairwise_distance diff --git a/cpp/src/distance/fused_l2_nn.cu b/cpp/src/distance/fused_l2_nn.cu new file mode 100644 index 0000000000..6011aaec29 --- /dev/null +++ b/cpp/src/distance/fused_l2_nn.cu @@ -0,0 +1,54 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // int64_t +#include // raft::KeyValuePair +#include + +#define instantiate_raft_distance_fusedL2NNMinReduce(DataT, OutT, IdxT) \ + template void raft::distance::fusedL2NNMinReduce(OutT * min, \ + const DataT* x, \ + const DataT* y, \ + const DataT* xn, \ + const DataT* yn, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + bool sqrt, \ + bool initOutBuffer, \ + cudaStream_t stream) + +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, double, int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, float, int64_t); + +// We can't have comma's in the macro expansion, so we use the COMMA macro: +#define COMMA , + +instantiate_raft_distance_fusedL2NNMinReduce(double, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(double, + raft::KeyValuePair, + int64_t); +instantiate_raft_distance_fusedL2NNMinReduce(float, raft::KeyValuePair, int); +instantiate_raft_distance_fusedL2NNMinReduce(float, + raft::KeyValuePair, + int64_t); + +#undef COMMA + +#undef instantiate_raft_distance_fusedL2NNMinReduce diff --git a/cpp/src/distance/specializations/detail/00_write_template.py b/cpp/src/distance/specializations/detail/00_write_template.py deleted file mode 100644 index 3f2f853569..0000000000 --- a/cpp/src/distance/specializations/detail/00_write_template.py +++ /dev/null @@ -1,159 +0,0 @@ -#!/usr/bin/env python3 - -# NOTE: this template is not perfectly formatted. Use pre-commit to get -# everything in shape again. -template = """/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -INCLUDE_SM_HEADERS - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point( - OpT, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail -""" - -data_type_instances = [ - dict( - DataT="float", - AccT="float", - OutT="float", - IdxT="int", - ), - dict( - DataT="double", - AccT="double", - OutT="double", - IdxT="int", - ), -] - -op_instances = [ - dict( - path_prefix="canberra", - OpT="ops::canberra_distance_op", - archs = [60], - ), - dict( - path_prefix="correlation", - OpT="ops::correlation_distance_op", - archs = [60], - ), - dict( - path_prefix="cosine", - OpT="ops::cosine_distance_op", - archs = [60, 80], - ), - dict( - path_prefix="hamming_unexpanded", - OpT="ops::hamming_distance_op", - archs = [60], - ), - dict( - path_prefix="hellinger_expanded", - OpT="ops::hellinger_distance_op", - archs = [60], - ), - # inner product is handled by cublas. - dict( - path_prefix="jensen_shannon", - OpT="ops::jensen_shannon_distance_op", - archs = [60], - ), - dict( - path_prefix="kl_divergence", - OpT="ops::kl_divergence_op", - archs = [60], - ), - dict( - path_prefix="l1", - OpT="ops::l1_distance_op", - archs = [60], - ), - dict( - path_prefix="l2_expanded", - OpT="ops::l2_exp_distance_op", - archs = [60, 80], - ), - dict( - path_prefix="l2_unexpanded", - OpT="ops::l2_unexp_distance_op", - archs = [60], - ), - dict( - path_prefix="l_inf", - OpT="ops::l_inf_distance_op", - archs = [60], - ), - dict( - path_prefix="lp_unexpanded", - OpT="ops::lp_unexp_distance_op", - archs = [60], - ), - dict( - path_prefix="russel_rao", - OpT="ops::russel_rao_distance_op", - archs = [60], - ), -] - -def fill_in(s, template): - for k, v in template.items(): - s = s.replace(k, v) - return s - -def fill_include_sm_headers(op_instance): - include_headers ="\n".join([ - f"#include " - for arch in op_instance["archs"] - ]) - - return { - "path_prefix": op_instance["path_prefix"], - "OpT": op_instance["OpT"], - "INCLUDE_SM_HEADERS": include_headers - } - -for op_instance in op_instances: - op_instance = fill_include_sm_headers(op_instance) - - for data_type_instance in data_type_instances: - op_data_instance = { - k : fill_in(v, data_type_instance) - for k, v in op_instance.items() - } - instance = { - **op_data_instance, - **data_type_instance, - "FinopT": "decltype(raft::identity_op())", - } - - text = fill_in(template, instance) - - path = fill_in("path_prefix_DataT_AccT_OutT_IdxT.cu", instance) - with open(path, "w") as f: - f.write(text) diff --git a/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu b/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu deleted file mode 100644 index 037d218178..0000000000 --- a/cpp/src/distance/specializations/detail/canberra_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu b/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu deleted file mode 100644 index 0ed8ea7bb0..0000000000 --- a/cpp/src/distance/specializations/detail/canberra_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu b/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu deleted file mode 100644 index 0c11f0621e..0000000000 --- a/cpp/src/distance/specializations/detail/correlation_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu b/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu deleted file mode 100644 index 396e158554..0000000000 --- a/cpp/src/distance/specializations/detail/correlation_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu b/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu deleted file mode 100644 index e9afb6f563..0000000000 --- a/cpp/src/distance/specializations/detail/cosine_double_double_double_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu b/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu deleted file mode 100644 index 1033c491d6..0000000000 --- a/cpp/src/distance/specializations/detail/cosine_float_float_float_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu deleted file mode 100644 index 195115914d..0000000000 --- a/cpp/src/distance/specializations/detail/hamming_unexpanded_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu deleted file mode 100644 index a74c6c404e..0000000000 --- a/cpp/src/distance/specializations/detail/hamming_unexpanded_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu deleted file mode 100644 index bac1dd7bd0..0000000000 --- a/cpp/src/distance/specializations/detail/hellinger_expanded_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu deleted file mode 100644 index 77c113b1a9..0000000000 --- a/cpp/src/distance/specializations/detail/hellinger_expanded_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/inner_product_double_double_double_int.cu b/cpp/src/distance/specializations/detail/inner_product_double_double_double_int.cu deleted file mode 100644 index 3db0a3572e..0000000000 --- a/cpp/src/distance/specializations/detail/inner_product_double_double_double_int.cu +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - double metric_arg); - -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/inner_product_float_float_float_int.cu b/cpp/src/distance/specializations/detail/inner_product_float_float_float_int.cu deleted file mode 100644 index 2b06ca4dc2..0000000000 --- a/cpp/src/distance/specializations/detail/inner_product_float_float_float_int.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft { -namespace distance { -namespace detail { -template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - std::size_t worksize, - bool isRowMajor, - float metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu b/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu deleted file mode 100644 index 188e52c152..0000000000 --- a/cpp/src/distance/specializations/detail/jensen_shannon_double_double_double_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void - pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu b/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu deleted file mode 100644 index b0afbf7bb2..0000000000 --- a/cpp/src/distance/specializations/detail/jensen_shannon_float_float_float_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void - pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu b/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu deleted file mode 100644 index 7c80eb29d0..0000000000 --- a/cpp/src/distance/specializations/detail/kernels/gram_matrix_base_double.cu +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -template class raft::distance::kernels::detail::GramMatrixBase; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu b/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu deleted file mode 100644 index 28306d0c21..0000000000 --- a/cpp/src/distance/specializations/detail/kernels/polynomial_kernel_double_int.cu +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -template class raft::distance::kernels::detail::PolynomialKernel; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu b/cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu deleted file mode 100644 index ab818db73b..0000000000 --- a/cpp/src/distance/specializations/detail/kernels/tanh_kernel_double.cu +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -template class raft::distance::kernels::detail::TanhKernel; \ No newline at end of file diff --git a/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu b/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu deleted file mode 100644 index f06ae85414..0000000000 --- a/cpp/src/distance/specializations/detail/kl_divergence_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu b/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu deleted file mode 100644 index 00d5a5ee5b..0000000000 --- a/cpp/src/distance/specializations/detail/kl_divergence_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu deleted file mode 100644 index 5c235316da..0000000000 --- a/cpp/src/distance/specializations/detail/l1_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu deleted file mode 100644 index fb293ca83d..0000000000 --- a/cpp/src/distance/specializations/detail/l1_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu deleted file mode 100644 index 2c02f0224f..0000000000 --- a/cpp/src/distance/specializations/detail/l2_expanded_double_double_double_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu deleted file mode 100644 index 85e25a25ca..0000000000 --- a/cpp/src/distance/specializations/detail/l2_expanded_float_float_float_int.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu deleted file mode 100644 index 5b4d995d14..0000000000 --- a/cpp/src/distance/specializations/detail/l2_unexpanded_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu deleted file mode 100644 index a63c3f0bb8..0000000000 --- a/cpp/src/distance/specializations/detail/l2_unexpanded_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu b/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu deleted file mode 100644 index 831167523f..0000000000 --- a/cpp/src/distance/specializations/detail/l_inf_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu b/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu deleted file mode 100644 index 02e667cbe3..0000000000 --- a/cpp/src/distance/specializations/detail/l_inf_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu b/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu deleted file mode 100644 index ebd71065ec..0000000000 --- a/cpp/src/distance/specializations/detail/lp_unexpanded_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu b/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu deleted file mode 100644 index b94a81fdce..0000000000 --- a/cpp/src/distance/specializations/detail/lp_unexpanded_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu b/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu deleted file mode 100644 index 6f952fcc37..0000000000 --- a/cpp/src/distance/specializations/detail/russel_rao_double_double_double_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - double, - double, - decltype(raft::identity_op())>( - ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu b/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu deleted file mode 100644 index 3223ce33a7..0000000000 --- a/cpp/src/distance/specializations/detail/russel_rao_float_float_float_int.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include // raft::identity_op -#include // ops::* -#include // pairwise_matrix_instantiation_point -#include - -namespace raft::distance::detail { - -template void pairwise_matrix_instantiation_point, - int, - float, - float, - decltype(raft::identity_op())>( - ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); - -} // namespace raft::distance::detail diff --git a/cpp/src/distance/specializations/fused_l2_nn_double_int.cu b/cpp/src/distance/specializations/fused_l2_nn_double_int.cu deleted file mode 100644 index b49132b042..0000000000 --- a/cpp/src/distance/specializations/fused_l2_nn_double_int.cu +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace distance { - -template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -template void fusedL2NNMinReduce(double* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu b/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu deleted file mode 100644 index b1e3a900a9..0000000000 --- a/cpp/src/distance/specializations/fused_l2_nn_double_int64.cu +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace distance { - -template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -template void fusedL2NNMinReduce(double* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/fused_l2_nn_float_int.cu b/cpp/src/distance/specializations/fused_l2_nn_float_int.cu deleted file mode 100644 index 44b4953d8c..0000000000 --- a/cpp/src/distance/specializations/fused_l2_nn_float_int.cu +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace distance { - -template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -template void fusedL2NNMinReduce(float* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft diff --git a/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu b/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu deleted file mode 100644 index 9ca2b639a9..0000000000 --- a/cpp/src/distance/specializations/fused_l2_nn_float_int64.cu +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace distance { - -template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -template void fusedL2NNMinReduce(float* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft diff --git a/cpp/src/linalg/detail/coalesced_reduction.cu b/cpp/src/linalg/detail/coalesced_reduction.cu new file mode 100644 index 0000000000..00d025df46 --- /dev/null +++ b/cpp/src/linalg/detail/coalesced_reduction.cu @@ -0,0 +1,69 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// #include + +#include + +#define instantiate_raft_linalg_detail_coalescedReduction( \ + InType, OutType, IdxType, MainLambda, ReduceLambda, FinalLambda) \ + template void raft::linalg::detail::coalescedReduction(OutType* dots, \ + const InType* data, \ + IdxType D, \ + IdxType N, \ + OutType init, \ + cudaStream_t stream, \ + bool inplace, \ + MainLambda main_op, \ + ReduceLambda reduce_op, \ + FinalLambda final_op) + +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::identity_op, raft::min_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::abs_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + double, double, int, raft::abs_op, raft::max_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::abs_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::abs_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::identity_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::identity_op, raft::min_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, int, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, long, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::identity_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::sq_op, raft::add_op, raft::identity_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::abs_op, raft::max_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, size_t, raft::sq_op, raft::add_op, raft::sqrt_op); +instantiate_raft_linalg_detail_coalescedReduction( + float, float, unsigned int, raft::sq_op, raft::add_op, raft::identity_op); + +#undef instantiate_raft_linalg_detail_coalescedReduction diff --git a/cpp/src/matrix/detail/select_k_double_int64_t.cu b/cpp/src/matrix/detail/select_k_double_int64_t.cu new file mode 100644 index 0000000000..022627283a --- /dev/null +++ b/cpp/src/matrix/detail/select_k_double_int64_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(double, int64_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_double_uint32_t.cu b/cpp/src/matrix/detail/select_k_double_uint32_t.cu new file mode 100644 index 0000000000..22c6989337 --- /dev/null +++ b/cpp/src/matrix/detail/select_k_double_uint32_t.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // uint32_t +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(double, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_float_int64_t.cu b/cpp/src/matrix/detail/select_k_float_int64_t.cu new file mode 100644 index 0000000000..1f1d686048 --- /dev/null +++ b/cpp/src/matrix/detail/select_k_float_int64_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(float, int64_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/detail/select_k_float_uint32_t.cu new file mode 100644 index 0000000000..3bb47acbf2 --- /dev/null +++ b/cpp/src/matrix/detail/select_k_float_uint32_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(float, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_half_int64_t.cu b/cpp/src/matrix/detail/select_k_half_int64_t.cu new file mode 100644 index 0000000000..cf4e15959d --- /dev/null +++ b/cpp/src/matrix/detail/select_k_half_int64_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(__half, int64_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/detail/select_k_half_uint32_t.cu new file mode 100644 index 0000000000..b18887bfc0 --- /dev/null +++ b/cpp/src/matrix/detail/select_k_half_uint32_t.cu @@ -0,0 +1,33 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_matrix_detail_select_k(T, IdxT) \ + template void raft::matrix::detail::select_k(const T* in_val, \ + const IdxT* in_idx, \ + size_t batch_size, \ + size_t len, \ + int k, \ + T* out_val, \ + IdxT* out_idx, \ + bool select_min, \ + rmm::cuda_stream_view stream, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_matrix_detail_select_k(__half, uint32_t); + +#undef instantiate_raft_matrix_detail_select_k diff --git a/cpp/src/matrix/specializations/detail/select_k_float_int64_t.cu b/cpp/src/matrix/specializations/detail/select_k_float_int64_t.cu deleted file mode 100644 index 370ab1ba50..0000000000 --- a/cpp/src/matrix/specializations/detail/select_k_float_int64_t.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -RAFT_INST(float, int64_t); - -} // namespace raft::matrix::detail diff --git a/cpp/src/matrix/specializations/detail/select_k_float_uint32_t.cu b/cpp/src/matrix/specializations/detail/select_k_float_uint32_t.cu deleted file mode 100644 index c6733c2a46..0000000000 --- a/cpp/src/matrix/specializations/detail/select_k_float_uint32_t.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -RAFT_INST(float, uint32_t); - -} // namespace raft::matrix::detail diff --git a/cpp/src/matrix/specializations/detail/select_k_half_int64_t.cu b/cpp/src/matrix/specializations/detail/select_k_half_int64_t.cu deleted file mode 100644 index 38e28ac54d..0000000000 --- a/cpp/src/matrix/specializations/detail/select_k_half_int64_t.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -RAFT_INST(half, int64_t); - -} // namespace raft::matrix::detail diff --git a/cpp/src/matrix/specializations/detail/select_k_half_uint32_t.cu b/cpp/src/matrix/specializations/detail/select_k_half_uint32_t.cu deleted file mode 100644 index 108bd30b49..0000000000 --- a/cpp/src/matrix/specializations/detail/select_k_half_uint32_t.cu +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -RAFT_INST(half, uint32_t); - -} // namespace raft::matrix::detail diff --git a/cpp/src/neighbors/ball_cover.cu b/cpp/src/neighbors/ball_cover.cu new file mode 100644 index 0000000000..4c49c1847b --- /dev/null +++ b/cpp/src/neighbors/ball_cover.cu @@ -0,0 +1,66 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#define instantiate_raft_neighbors_ball_cover(idx_t, value_t, int_t, matrix_idx_t) \ + template void raft::neighbors::ball_cover::build_index( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index); \ + \ + template void raft::neighbors::ball_cover::all_knn_query( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index, \ + int_t k, \ + idx_t* inds, \ + value_t* dists, \ + bool perform_post_filtering, \ + float weight); \ + \ + template void raft::neighbors::ball_cover::all_knn_query( \ + raft::device_resources const& handle, \ + raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view inds, \ + raft::device_matrix_view dists, \ + int_t k, \ + bool perform_post_filtering, \ + float weight); \ + \ + template void raft::neighbors::ball_cover::knn_query( \ + raft::device_resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + int_t k, \ + const value_t* query, \ + int_t n_query_pts, \ + idx_t* inds, \ + value_t* dists, \ + bool perform_post_filtering, \ + float weight); \ + \ + template void raft::neighbors::ball_cover::knn_query( \ + raft::device_resources const& handle, \ + const raft::neighbors::ball_cover::BallCoverIndex& index, \ + raft::device_matrix_view query, \ + raft::device_matrix_view inds, \ + raft::device_matrix_view dists, \ + int_t k, \ + bool perform_post_filtering, \ + float weight); + +instantiate_raft_neighbors_ball_cover(int64_t, float, uint32_t, uint32_t); + +#undef instantiate_raft_neighbors_ball_cover diff --git a/cpp/src/neighbors/brute_force_00_generate.py b/cpp/src/neighbors/brute_force_00_generate.py new file mode 100644 index 0000000000..251dd53b1c --- /dev/null +++ b/cpp/src/neighbors/brute_force_00_generate.py @@ -0,0 +1,106 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +""" + +knn_macro = """ +#define instantiate_raft_neighbors_brute_force_knn(idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \\ + template void raft::neighbors::brute_force::knn( \\ + raft::device_resources const& handle, \\ + std::vector> index, \\ + raft::device_matrix_view search, \\ + raft::device_matrix_view indices, \\ + raft::device_matrix_view distances, \\ + raft::distance::DistanceType metric, \\ + std::optional metric_arg, \\ + std::optional global_id_offset, \\ + epilogue_op distance_epilogue); + +""" + +fused_l2_knn_macro = """ +#define instantiate_raft_neighbors_brute_force_fused_l2_knn(value_t, idx_t, idx_layout, query_layout) \\ + template void raft::neighbors::brute_force::fused_l2_knn( \\ + raft::device_resources const& handle, \\ + raft::device_matrix_view index, \\ + raft::device_matrix_view query, \\ + raft::device_matrix_view out_inds, \\ + raft::device_matrix_view out_dists, \\ + raft::distance::DistanceType metric); + +""" + +knn_types = dict( + int64_t_float_uint32_t=("int64_t","float","uint32_t"), + int64_t_float_int64_t=("int64_t","float","int64_t"), + int_float_int=("int","float","int"), + uint32_t_float_uint32_t=("uint32_t","float","uint32_t"), +) + +fused_l2_knn_types = dict( + float_int64_t=("float", "int64_t"), +) + +# knn +for type_path, (idx_t, value_t, matrix_idx) in knn_types.items(): + path = f"brute_force_knn_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(knn_macro) + f.write(f"instantiate_raft_neighbors_brute_force_knn({idx_t},{value_t},{matrix_idx},raft::row_major,raft::row_major,raft::identity_op);\n\n") + f.write("#undef instantiate_raft_neighbors_brute_force_knn\n") + + # For pasting into CMakeLists.txt + print(f"src/neighbors/{path}") + +#fused_l2_knn +for type_path, (value_t, idx_t) in fused_l2_knn_types.items(): + path = f"brute_force_fused_l2_knn_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(fused_l2_knn_macro) + f.write(f"instantiate_raft_neighbors_brute_force_fused_l2_knn({value_t},{idx_t},raft::row_major,raft::row_major);\n\n") + f.write("#undef instantiate_raft_neighbors_brute_force_fused_l2_knn\n") + + # For pasting into CMakeLists.txt + print(f"src/neighbors/{path}") diff --git a/cpp/src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu b/cpp/src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu new file mode 100644 index 0000000000..4e1805f9a8 --- /dev/null +++ b/cpp/src/neighbors/brute_force_fused_l2_knn_float_int64_t.cu @@ -0,0 +1,45 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_fused_l2_knn( \ + value_t, idx_t, idx_layout, query_layout) \ + template void raft::neighbors::brute_force::fused_l2_knn( \ + raft::device_resources const& handle, \ + raft::device_matrix_view index, \ + raft::device_matrix_view query, \ + raft::device_matrix_view out_inds, \ + raft::device_matrix_view out_dists, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_brute_force_fused_l2_knn(float, + int64_t, + raft::row_major, + raft::row_major); + +#undef instantiate_raft_neighbors_brute_force_fused_l2_knn diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float_int64_t.cu b/cpp/src/neighbors/brute_force_knn_int64_t_float_int64_t.cu new file mode 100644 index 0000000000..a668b076d6 --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_int64_t_float_int64_t.cu @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + int64_t, float, int64_t, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn diff --git a/cpp/src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu b/cpp/src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu new file mode 100644 index 0000000000..21cac5034a --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_int64_t_float_uint32_t.cu @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + int64_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn diff --git a/cpp/src/neighbors/brute_force_knn_int_float_int.cu b/cpp/src/neighbors/brute_force_knn_int_float_int.cu new file mode 100644 index 0000000000..b76fe09c2a --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_int_float_int.cu @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + int, float, int, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn diff --git a/cpp/src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu b/cpp/src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu new file mode 100644 index 0000000000..4d3f627182 --- /dev/null +++ b/cpp/src/neighbors/brute_force_knn_uint32_t_float_uint32_t.cu @@ -0,0 +1,47 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by brute_force_00_generate.py + * + * Make changes there and run in this directory: + * + * > python brute_force_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_brute_force_knn( \ + idx_t, value_t, matrix_idx, index_layout, search_layout, epilogue_op) \ + template void raft::neighbors::brute_force:: \ + knn( \ + raft::device_resources const& handle, \ + std::vector> index, \ + raft::device_matrix_view search, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric, \ + std::optional metric_arg, \ + std::optional global_id_offset, \ + epilogue_op distance_epilogue); + +instantiate_raft_neighbors_brute_force_knn( + uint32_t, float, uint32_t, raft::row_major, raft::row_major, raft::identity_op); + +#undef instantiate_raft_neighbors_brute_force_knn diff --git a/cpp/src/neighbors/detail/ivf_flat_search.cu b/cpp/src/neighbors/detail/ivf_flat_search.cu new file mode 100644 index 0000000000..345a8f499d --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_flat_search.cu @@ -0,0 +1,35 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_detail_search(T, IdxT) \ + template void raft::neighbors::ivf_flat::detail::search( \ + raft::device_resources const& handle, \ + const search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_flat_detail_search(float, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_search diff --git a/cpp/src/neighbors/detail/selection_faiss_00_generate.py b/cpp/src/neighbors/detail/selection_faiss_00_generate.py new file mode 100644 index 0000000000..36ba56c9b3 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_00_generate.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \\ + template void raft::neighbors::detail::select_k(const key_t* inK, \\ + const payload_t* inV, \\ + size_t n_rows, \\ + size_t n_cols, \\ + key_t* outK, \\ + payload_t* outV, \\ + bool select_min, \\ + int k, \\ + cudaStream_t stream) + +""" + +types = dict( + uint32_t_float=("uint32_t", "float"), + int32_t_float=("int32_t", "float"), + long_float=("long", "float"), + size_t_double=("size_t", "double"), + int_double=("int", "double"), + size_t_float=("size_t", "float"), +) + +for type_path, (payload_t, key_t) in types.items(): + path = f"selection_faiss_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(f"instantiate_raft_neighbors_detail_select_k({payload_t}, {key_t});\n\n") + f.write(f"#undef instantiate_raft_neighbors_detail_select_k\n") + + # for pasting into CMakeLists.txt + print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/selection_faiss_int32_t_float.cu b/cpp/src/neighbors/detail/selection_faiss_int32_t_float.cu new file mode 100644 index 0000000000..1f1ece05ae --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_int32_t_float.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(int32_t, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_int_double.cu b/cpp/src/neighbors/detail/selection_faiss_int_double.cu new file mode 100644 index 0000000000..7e832410c4 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_int_double.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(int, double); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_long_float.cu b/cpp/src/neighbors/detail/selection_faiss_long_float.cu new file mode 100644 index 0000000000..441d54fa30 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_long_float.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(long, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_size_t_double.cu b/cpp/src/neighbors/detail/selection_faiss_size_t_double.cu new file mode 100644 index 0000000000..ca310e7697 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_size_t_double.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(size_t, double); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_size_t_float.cu b/cpp/src/neighbors/detail/selection_faiss_size_t_float.cu new file mode 100644 index 0000000000..a830e6ecac --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_size_t_float.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(size_t, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/detail/selection_faiss_uint32_t_float.cu b/cpp/src/neighbors/detail/selection_faiss_uint32_t_float.cu new file mode 100644 index 0000000000..2fecaa5cf1 --- /dev/null +++ b/cpp/src/neighbors/detail/selection_faiss_uint32_t_float.cu @@ -0,0 +1,44 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by selection_faiss_00_generate.py + * + * Make changes there and run in this directory: + * + * > python selection_faiss_00_generate.py + * + */ + +#include // size_t +#include // uint32_t +#include + +#define instantiate_raft_neighbors_detail_select_k(payload_t, key_t) \ + template void raft::neighbors::detail::select_k(const key_t* inK, \ + const payload_t* inV, \ + size_t n_rows, \ + size_t n_cols, \ + key_t* outK, \ + payload_t* outV, \ + bool select_min, \ + int k, \ + cudaStream_t stream) + +instantiate_raft_neighbors_detail_select_k(uint32_t, float); + +#undef instantiate_raft_neighbors_detail_select_k diff --git a/cpp/src/neighbors/ivf_flat_00_generate.py b/cpp/src/neighbors/ivf_flat_00_generate.py new file mode 100644 index 0000000000..44ea9709c2 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_00_generate.py @@ -0,0 +1,148 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include +""" + +types = dict( + float_int64_t= ("float", "int64_t"), + int8_t_int64_t=("int8_t", "int64_t"), + uint8_t_int64_t=("uint8_t", "int64_t"), +) + +build_macro = """ +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \\ + template auto raft::neighbors::ivf_flat::build( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::index_params& params, \\ + const T* dataset, \\ + IdxT n_rows, \\ + uint32_t dim) \\ + ->raft::neighbors::ivf_flat::index; \\ + \\ + template auto raft::neighbors::ivf_flat::build( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::index_params& params, \\ + raft::device_matrix_view dataset) \\ + ->raft::neighbors::ivf_flat::index; \\ + \\ + template void raft::neighbors::ivf_flat::build( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::index_params& params, \\ + raft::device_matrix_view dataset, \\ + raft::neighbors::ivf_flat::index& idx); +""" + +extend_macro = """ +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \\ + template auto raft::neighbors::ivf_flat::extend( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::index& orig_index, \\ + const T* new_vectors, \\ + const IdxT* new_indices, \\ + IdxT n_rows) \\ + ->raft::neighbors::ivf_flat::index; \\ + \\ + template auto raft::neighbors::ivf_flat::extend( \\ + raft::device_resources const& handle, \\ + raft::device_matrix_view new_vectors, \\ + std::optional> new_indices, \\ + const raft::neighbors::ivf_flat::index& orig_index) \\ + ->raft::neighbors::ivf_flat::index; \\ + \\ + template void raft::neighbors::ivf_flat::extend( \\ + raft::device_resources const& handle, \\ + raft::neighbors::ivf_flat::index* index, \\ + const T* new_vectors, \\ + const IdxT* new_indices, \\ + IdxT n_rows); \\ + \\ + template void raft::neighbors::ivf_flat::extend( \\ + raft::device_resources const& handle, \\ + raft::device_matrix_view new_vectors, \\ + std::optional> new_indices, \\ + raft::neighbors::ivf_flat::index* index); +""" + +search_macro = """ +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \\ + template void raft::neighbors::ivf_flat::search( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::search_params& params, \\ + const raft::neighbors::ivf_flat::index& index, \\ + const T* queries, \\ + uint32_t n_queries, \\ + uint32_t k, \\ + IdxT* neighbors, \\ + float* distances, \\ + rmm::mr::device_memory_resource* mr ); \\ + \\ + template void raft::neighbors::ivf_flat::search( \\ + raft::device_resources const& handle, \\ + const raft::neighbors::ivf_flat::search_params& params, \\ + const raft::neighbors::ivf_flat::index& index, \\ + raft::device_matrix_view queries, \\ + raft::device_matrix_view neighbors, \\ + raft::device_matrix_view distances); +""" + +macros = dict( + build=dict( + definition=build_macro, + name="instantiate_raft_neighbors_ivf_flat_build"), + extend=dict( + definition=extend_macro, + name="instantiate_raft_neighbors_ivf_flat_extend"), + search=dict( + definition=search_macro, + name="instantiate_raft_neighbors_ivf_flat_search"), +) + +for type_path, (T, IdxT) in types.items(): + for macro_path, macro in macros.items(): + path = f"ivf_flat_{macro_path}_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(macro['definition']) + + + f.write(f"{macro['name']}({T}, {IdxT});\n\n") + f.write(f"#undef {macro['name']}\n") + + print(f"src/neighbors/{path}") diff --git a/cpp/src/neighbors/ivf_flat_build_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_build_float_int64_t.cu new file mode 100644 index 0000000000..622f7c7d90 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_build_float_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); +instantiate_raft_neighbors_ivf_flat_build(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_build diff --git a/cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu new file mode 100644 index 0000000000..7b1eeae32d --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_build_int8_t_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); +instantiate_raft_neighbors_ivf_flat_build(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_build diff --git a/cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu new file mode 100644 index 0000000000..40cf28151f --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_build_uint8_t_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_build(T, IdxT) \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index_params& params, \ + raft::device_matrix_view dataset, \ + raft::neighbors::ivf_flat::index& idx); +instantiate_raft_neighbors_ivf_flat_build(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_build diff --git a/cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu new file mode 100644 index 0000000000..f7d99d7081 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_extend_float_int64_t.cu @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_flat::index* index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); +instantiate_raft_neighbors_ivf_flat_extend(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_extend diff --git a/cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu new file mode 100644 index 0000000000..9eec4f9648 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_extend_int8_t_int64_t.cu @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_flat::index* index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); +instantiate_raft_neighbors_ivf_flat_extend(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_extend diff --git a/cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu new file mode 100644 index 0000000000..fc24cbff74 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_extend_uint8_t_int64_t.cu @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_extend(T, IdxT) \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::index& orig_index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template auto raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_flat::index& orig_index) \ + ->raft::neighbors::ivf_flat::index; \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_flat::index* index, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); \ + \ + template void raft::neighbors::ivf_flat::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_flat::index* index); +instantiate_raft_neighbors_ivf_flat_extend(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_extend diff --git a/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu new file mode 100644 index 0000000000..5a1fae6d5a --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_search_float_int64_t.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); +instantiate_raft_neighbors_ivf_flat_search(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu new file mode 100644 index 0000000000..bc84159a41 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_search_int8_t_int64_t.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); +instantiate_raft_neighbors_ivf_flat_search(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu new file mode 100644 index 0000000000..9e70e21af4 --- /dev/null +++ b/cpp/src/neighbors/ivf_flat_search_uint8_t_int64_t.cu @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_flat_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_flat_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_search(T, IdxT) \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr); \ + \ + template void raft::neighbors::ivf_flat::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_flat::search_params& params, \ + const raft::neighbors::ivf_flat::index& index, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); +instantiate_raft_neighbors_ivf_flat_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_search diff --git a/cpp/src/neighbors/ivfpq_build_float_int64_t.cu b/cpp/src/neighbors/ivfpq_build_float_int64_t.cu new file mode 100644 index 0000000000..6771964cae --- /dev/null +++ b/cpp/src/neighbors/ivfpq_build_float_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + template auto raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/src/neighbors/ivfpq_build_int8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_build_int8_t_int64_t.cu new file mode 100644 index 0000000000..759045faa7 --- /dev/null +++ b/cpp/src/neighbors/ivfpq_build_int8_t_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + template auto raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/src/neighbors/ivfpq_build_uint8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_build_uint8_t_int64_t.cu new file mode 100644 index 0000000000..62a47e9bcf --- /dev/null +++ b/cpp/src/neighbors/ivfpq_build_uint8_t_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_build(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + raft::device_matrix_view dataset); \ + \ + template auto raft::neighbors::ivf_pq::build( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index_params& params, \ + const T* dataset, \ + IdxT n_rows, \ + uint32_t dim) \ + ->raft::neighbors::ivf_pq::index; + +instantiate_raft_neighbors_ivf_pq_build(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_build diff --git a/cpp/src/neighbors/ivfpq_extend_float_int64_t.cu b/cpp/src/neighbors/ivfpq_extend_float_int64_t.cu new file mode 100644 index 0000000000..3e728be38d --- /dev/null +++ b/cpp/src/neighbors/ivfpq_extend_float_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_extend(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + template auto raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); + +instantiate_raft_neighbors_ivf_pq_extend(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_extend diff --git a/cpp/src/neighbors/ivfpq_extend_int8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_extend_int8_t_int64_t.cu new file mode 100644 index 0000000000..7853e53f63 --- /dev/null +++ b/cpp/src/neighbors/ivfpq_extend_int8_t_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_extend(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + template auto raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); + +instantiate_raft_neighbors_ivf_pq_extend(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_extend diff --git a/cpp/src/neighbors/ivfpq_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_extend_uint8_t_int64_t.cu new file mode 100644 index 0000000000..599a88fc67 --- /dev/null +++ b/cpp/src/neighbors/ivfpq_extend_uint8_t_int64_t.cu @@ -0,0 +1,50 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_extend(T, IdxT) \ + template raft::neighbors::ivf_pq::index raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + const raft::neighbors::ivf_pq::index& idx); \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::device_matrix_view new_vectors, \ + std::optional> new_indices, \ + raft::neighbors::ivf_pq::index* idx); \ + \ + template auto raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows) \ + ->raft::neighbors::ivf_pq::index; \ + \ + template void raft::neighbors::ivf_pq::extend( \ + raft::device_resources const& handle, \ + raft::neighbors::ivf_pq::index* idx, \ + const T* new_vectors, \ + const IdxT* new_indices, \ + IdxT n_rows); + +instantiate_raft_neighbors_ivf_pq_extend(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_extend diff --git a/cpp/src/neighbors/ivfpq_search_float_int64_t.cu b/cpp/src/neighbors/ivfpq_search_float_int64_t.cu new file mode 100644 index 0000000000..ab946d2b65 --- /dev/null +++ b/cpp/src/neighbors/ivfpq_search_float_int64_t.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_pq_search(float, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_search diff --git a/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu new file mode 100644 index 0000000000..af54a9312a --- /dev/null +++ b/cpp/src/neighbors/ivfpq_search_int8_t_int64_t.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_pq_search(int8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_search diff --git a/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu new file mode 100644 index 0000000000..7b49487506 --- /dev/null +++ b/cpp/src/neighbors/ivfpq_search_uint8_t_int64_t.cu @@ -0,0 +1,42 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include // raft::neighbors::ivf_pq::index + +#define instantiate_raft_neighbors_ivf_pq_search(T, IdxT) \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbors, \ + raft::device_matrix_view distances); \ + \ + template void raft::neighbors::ivf_pq::search( \ + raft::device_resources const& handle, \ + const raft::neighbors::ivf_pq::search_params& params, \ + const raft::neighbors::ivf_pq::index& idx, \ + const T* queries, \ + uint32_t n_queries, \ + uint32_t k, \ + IdxT* neighbors, \ + float* distances, \ + rmm::mr::device_memory_resource* mr) + +instantiate_raft_neighbors_ivf_pq_search(uint8_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_pq_search diff --git a/cpp/src/neighbors/refine_00_generate.py b/cpp/src/neighbors/refine_00_generate.py new file mode 100644 index 0000000000..18c8857e3f --- /dev/null +++ b/cpp/src/neighbors/refine_00_generate.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by refine_00_generate.py + * + * Make changes there and run in this directory: + * + * > python refine_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \\ + template void raft::neighbors::refine( \\ + raft::device_resources const& handle, \\ + raft::device_matrix_view dataset, \\ + raft::device_matrix_view queries, \\ + raft::device_matrix_view neighbor_candidates, \\ + raft::device_matrix_view indices, \\ + raft::device_matrix_view distances, \\ + raft::distance::DistanceType metric); \\ + \\ + template void raft::neighbors::refine( \\ + raft::device_resources const& handle, \\ + raft::host_matrix_view dataset, \\ + raft::host_matrix_view queries, \\ + raft::host_matrix_view neighbor_candidates, \\ + raft::host_matrix_view indices, \\ + raft::host_matrix_view distances, \\ + raft::distance::DistanceType metric); + +""" + +types = dict( + float_float= ("float", "float"), + int8_t_float=("int8_t", "float"), + uint8_t_float=("uint8_t", "float"), +) + +for type_path, (data_t, distance_t) in types.items(): + path = f"refine_{type_path}.cu" + with open(path, "w") as f: + f.write(header) + f.write(f"instantiate_raft_neighbors_refine(int64_t, {data_t}, {distance_t}, int64_t);\n\n") + f.write(f"#undef instantiate_raft_neighbors_refine\n") + + # for pasting into CMakeLists.txt + print(f"src/neighbors/{path}") diff --git a/cpp/src/neighbors/refine_float_float.cu b/cpp/src/neighbors/refine_float_float.cu new file mode 100644 index 0000000000..7e811fd7e3 --- /dev/null +++ b/cpp/src/neighbors/refine_float_float.cu @@ -0,0 +1,50 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by refine_00_generate.py + * + * Make changes there and run in this directory: + * + * > python refine_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric); \ + \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, float, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/src/neighbors/refine_int8_t_float.cu b/cpp/src/neighbors/refine_int8_t_float.cu new file mode 100644 index 0000000000..6983c2492c --- /dev/null +++ b/cpp/src/neighbors/refine_int8_t_float.cu @@ -0,0 +1,50 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by refine_00_generate.py + * + * Make changes there and run in this directory: + * + * > python refine_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric); \ + \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, int8_t, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/src/neighbors/refine_uint8_t_float.cu b/cpp/src/neighbors/refine_uint8_t_float.cu new file mode 100644 index 0000000000..f61bc508c0 --- /dev/null +++ b/cpp/src/neighbors/refine_uint8_t_float.cu @@ -0,0 +1,50 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by refine_00_generate.py + * + * Make changes there and run in this directory: + * + * > python refine_00_generate.py + * + */ + +#include + +#define instantiate_raft_neighbors_refine(idx_t, data_t, distance_t, matrix_idx) \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + raft::distance::DistanceType metric); \ + \ + template void raft::neighbors::refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ + raft::distance::DistanceType metric); + +instantiate_raft_neighbors_refine(int64_t, uint8_t, float, int64_t); + +#undef instantiate_raft_neighbors_refine diff --git a/cpp/src/neighbors/specializations/ball_cover_all_knn_query.cu b/cpp/src/neighbors/specializations/ball_cover_all_knn_query.cu deleted file mode 100644 index 305dd6796e..0000000000 --- a/cpp/src/neighbors/specializations/ball_cover_all_knn_query.cu +++ /dev/null @@ -1,33 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include - -namespace raft::neighbors::ball_cover { -template void all_knn_query( - raft::device_resources const& handle, - BallCoverIndex& index, - std::uint32_t k, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - -}; // namespace raft::neighbors::ball_cover diff --git a/cpp/src/neighbors/specializations/ball_cover_build_index.cu b/cpp/src/neighbors/specializations/ball_cover_build_index.cu deleted file mode 100644 index ec7f4bcf52..0000000000 --- a/cpp/src/neighbors/specializations/ball_cover_build_index.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#include - -namespace raft::neighbors::ball_cover { -template class BallCoverIndex; -template class BallCoverIndex; - -template void build_index( - raft::device_resources const& handle, - BallCoverIndex& index); - -}; // namespace raft::neighbors::ball_cover diff --git a/cpp/src/neighbors/specializations/ball_cover_knn_query.cu b/cpp/src/neighbors/specializations/ball_cover_knn_query.cu deleted file mode 100644 index 634427200e..0000000000 --- a/cpp/src/neighbors/specializations/ball_cover_knn_query.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -#include - -namespace raft::neighbors::ball_cover { -template void knn_query( - raft::device_resources const& handle, - const BallCoverIndex& index, - std::uint32_t k, - const float* query, - std::uint32_t n_query_pts, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - -}; // namespace raft::neighbors::ball_cover diff --git a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu b/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu deleted file mode 100644 index b69751a62a..0000000000 --- a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_2d.cu +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void rbc_low_dim_pass_one( - raft::device_resources const& handle, - const BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* dists_counter); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft \ No newline at end of file diff --git a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu b/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu deleted file mode 100644 index ca44ad3165..0000000000 --- a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_one_3d.cu +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void rbc_low_dim_pass_one( - raft::device_resources const& handle, - const BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* dists_counter); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft \ No newline at end of file diff --git a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu b/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu deleted file mode 100644 index ba44327653..0000000000 --- a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_2d.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void rbc_low_dim_pass_two( - raft::device_resources const& handle, - const BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* post_dists_counter); -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft \ No newline at end of file diff --git a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu b/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu deleted file mode 100644 index 59132c1f99..0000000000 --- a/cpp/src/neighbors/specializations/detail/ball_cover_lowdim_pass_two_3d.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void rbc_low_dim_pass_two( - raft::device_resources const& handle, - const BallCoverIndex& index, - const float* query, - const std::uint32_t n_query_rows, - std::uint32_t k, - const std::int64_t* R_knn_inds, - const float* R_knn_dists, - DistFunc& dfunc, - std::int64_t* inds, - float* dists, - float weight, - std::uint32_t* post_dists_counter); -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft \ No newline at end of file diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu deleted file mode 100644 index 04aa42c9f1..0000000000 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_int.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(long, float, int); -#undef RAFT_INST -} // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu deleted file mode 100644 index a8b9d4299a..0000000000 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_long_float_uint.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(long, float, unsigned int); -#undef RAFT_INST -} // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu deleted file mode 100644 index c97e6e936a..0000000000 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_int.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(uint32_t, float, int); -#undef RAFT_INST -} // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu b/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu deleted file mode 100644 index 87451c385a..0000000000 --- a/cpp/src/neighbors/specializations/detail/brute_force_knn_impl_uint_float_uint.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(uint32_t, float, unsigned int); -#undef RAFT_INST -} // namespace raft::neighbors::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu deleted file mode 100644 index 33c4e7ffc0..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_fast.cu +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu deleted file mode 100644 index f543369de5..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu deleted file mode 100644 index 1a0322a722..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu deleted file mode 100644 index c7b5c9ffe9..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, true>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu deleted file mode 100644 index efb2a477a7..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_basediff.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, false, true>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu deleted file mode 100644 index b9051eb011..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8s_no_smem_lut.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, false>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu deleted file mode 100644 index c6b1bad123..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_fast.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, true>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu deleted file mode 100644 index d6033345da..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_basediff.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, false, true>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu deleted file mode 100644 index 1add18cb4a..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_fp8u_no_smem_lut.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, false>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu deleted file mode 100644 index 6020d7035b..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu deleted file mode 100644 index 62be67e1a9..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_basediff.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu deleted file mode 100644 index 145312f334..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_float_half_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu deleted file mode 100644 index c9365e1bb4..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, true>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu deleted file mode 100644 index d5c6934da2..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_basediff.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, false, true>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu deleted file mode 100644 index bac8c8706b..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8s_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, false>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu deleted file mode 100644 index 2809005dd0..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, true>(uint32_t, uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu deleted file mode 100644 index 015ef21a15..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_basediff.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, false, true>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu deleted file mode 100644 index 0ac96c8440..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_fp8u_no_smem_lut.cu +++ /dev/null @@ -1,28 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel, true, false>(uint32_t, - uint32_t) - -> compute_similarity_kernel_t>; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu deleted file mode 100644 index f3501d11c0..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_fast.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu deleted file mode 100644 index 7d10020480..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_basediff.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu deleted file mode 100644 index 91ec2eca3e..0000000000 --- a/cpp/src/neighbors/specializations/detail/compute_similarity_half_half_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/detail/ivfpq_compute_similarity_float_half_no_smem_lut.cu b/cpp/src/neighbors/specializations/detail/ivfpq_compute_similarity_float_half_no_smem_lut.cu deleted file mode 100644 index 145312f334..0000000000 --- a/cpp/src/neighbors/specializations/detail/ivfpq_compute_similarity_float_half_no_smem_lut.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -template auto get_compute_similarity_kernel(uint32_t, uint32_t) - -> compute_similarity_kernel_t; - -} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/src/neighbors/specializations/fused_l2_knn_int_float_false.cu b/cpp/src/neighbors/specializations/fused_l2_knn_int_float_false.cu deleted file mode 100644 index 72fdac9526..0000000000 --- a/cpp/src/neighbors/specializations/fused_l2_knn_int_float_false.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void fusedL2Knn(size_t D, - int* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/neighbors/specializations/fused_l2_knn_int_float_true.cu b/cpp/src/neighbors/specializations/fused_l2_knn_int_float_true.cu deleted file mode 100644 index c7616462fe..0000000000 --- a/cpp/src/neighbors/specializations/fused_l2_knn_int_float_true.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { -template void fusedL2Knn(size_t D, - int* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/neighbors/specializations/fused_l2_knn_long_float_false.cu b/cpp/src/neighbors/specializations/fused_l2_knn_long_float_false.cu deleted file mode 100644 index 16bf058238..0000000000 --- a/cpp/src/neighbors/specializations/fused_l2_knn_long_float_false.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void fusedL2Knn(size_t D, - long* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/neighbors/specializations/fused_l2_knn_long_float_true.cu b/cpp/src/neighbors/specializations/fused_l2_knn_long_float_true.cu deleted file mode 100644 index 06cf55eae3..0000000000 --- a/cpp/src/neighbors/specializations/fused_l2_knn_long_float_true.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -template void fusedL2Knn(size_t D, - long* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/neighbors/specializations/ivfflat_build_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_build_float_int64_t.cu deleted file mode 100644 index 7082873d76..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_build_float_int64_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu deleted file mode 100644 index ebc1a7fefa..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_build_int8_t_int64_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu deleted file mode 100644 index 870db6e97e..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_build_uint8_t_int64_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_extend_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_extend_float_int64_t.cu deleted file mode 100644 index 71af06ad71..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_extend_float_int64_t.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const raft::neighbors::ivf_flat::index& orig_index) \ - ->raft::neighbors::ivf_flat::index; \ - \ - template void extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* idx); - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu deleted file mode 100644 index bb7bb6e7eb..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_extend_int8_t_int64_t.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const raft::neighbors::ivf_flat::index& orig_index) \ - ->raft::neighbors::ivf_flat::index; \ - \ - template void extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* idx); - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu deleted file mode 100644 index 607b4b0913..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_extend_uint8_t_int64_t.cu +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const raft::neighbors::ivf_flat::index& orig_index) \ - ->raft::neighbors::ivf_flat::index; \ - \ - template void extend(raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* idx); - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu deleted file mode 100644 index dce7083139..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_search_float_int64_t.cu +++ /dev/null @@ -1,58 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan -// function is used in both raft::neighbors::ivf_flat::search and -// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation -// of this function (which defines ~270 kernels) in the refine specializations, -// an extern template definition is provided. To make sure -// ivfflat_interleaved_scan is actually compiled here, we explicitly instantiate -// it below. Please check related function calls after editing template -// definition below. Search for `greppable-id-specializations-ivf-flat-search` -// to find them. -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ - T, \ - typename raft::spatial::knn::detail::utils::config::value_t, \ - IdxT>(const index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream); \ - \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu deleted file mode 100644 index b03d878bae..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_search_int8_t_int64_t.cu +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ - T, \ - typename raft::spatial::knn::detail::utils::config::value_t, \ - IdxT>(const index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream); \ - \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu deleted file mode 100644 index 2d42bae0d1..0000000000 --- a/cpp/src/neighbors/specializations/ivfflat_search_uint8_t_int64_t.cu +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors::ivf_flat { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ - T, \ - typename raft::spatial::knn::detail::utils::config::value_t, \ - IdxT>(const index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream); \ - \ - template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_flat diff --git a/cpp/src/neighbors/specializations/ivfpq_build_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_build_float_int64_t.cu deleted file mode 100644 index d559291b93..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_build_float_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu deleted file mode 100644 index c8b31e1fff..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu deleted file mode 100644 index 5fc62969f0..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_extend_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_extend_float_int64_t.cu deleted file mode 100644 index 584bbfc45c..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_extend_float_int64_t.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const index& idx) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - index* idx); - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu deleted file mode 100644 index 00311a77e4..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const index& idx) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - index* idx); - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu deleted file mode 100644 index 11524886f0..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu +++ /dev/null @@ -1,39 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template auto extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const index& idx) \ - ->index; \ - template void extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - index* idx); - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_search_float_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_search_float_int64_t.cu deleted file mode 100644 index 92a4d89e6b..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_search_float_int64_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const& handle, \ - const search_params& params, \ - const index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); - -RAFT_MAKE_INSTANCE(float, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu deleted file mode 100644 index 62a8b48ad5..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const& handle, \ - const search_params& params, \ - const index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); - -RAFT_MAKE_INSTANCE(int8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu deleted file mode 100644 index 3bcf134a22..0000000000 --- a/cpp/src/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors::ivf_pq { - -#define RAFT_MAKE_INSTANCE(T, IdxT) \ - template void search(raft::device_resources const& handle, \ - const search_params& params, \ - const index& idx, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbors, \ - raft::device_matrix_view distances); - -RAFT_MAKE_INSTANCE(uint8_t, int64_t); - -#undef RAFT_MAKE_INSTANCE - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/neighbors/specializations/refine_d_int64_t_float.cu b/cpp/src/neighbors/specializations/refine_d_int64_t_float.cu deleted file mode 100644 index 0b0125459d..0000000000 --- a/cpp/src/neighbors/specializations/refine_d_int64_t_float.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_d_int64_t_int8_t.cu b/cpp/src/neighbors/specializations/refine_d_int64_t_int8_t.cu deleted file mode 100644 index d6c817b971..0000000000 --- a/cpp/src/neighbors/specializations/refine_d_int64_t_int8_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_d_int64_t_uint8_t.cu b/cpp/src/neighbors/specializations/refine_d_int64_t_uint8_t.cu deleted file mode 100644 index 3e0ca627a6..0000000000 --- a/cpp/src/neighbors/specializations/refine_d_int64_t_uint8_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_h_int64_t_float.cu b/cpp/src/neighbors/specializations/refine_h_int64_t_float.cu deleted file mode 100644 index 66a6bace53..0000000000 --- a/cpp/src/neighbors/specializations/refine_h_int64_t_float.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_h_int64_t_int8_t.cu b/cpp/src/neighbors/specializations/refine_h_int64_t_int8_t.cu deleted file mode 100644 index 22824b3a8e..0000000000 --- a/cpp/src/neighbors/specializations/refine_h_int64_t_int8_t.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { -template void refine( - raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/neighbors/specializations/refine_h_int64_t_uint8_t.cu b/cpp/src/neighbors/specializations/refine_h_int64_t_uint8_t.cu deleted file mode 100644 index 58dcfc87c9..0000000000 --- a/cpp/src/neighbors/specializations/refine_h_int64_t_uint8_t.cu +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu deleted file mode 100644 index 2c21d1ec64..0000000000 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_int.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - long* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu deleted file mode 100644 index 7e6e7e80d0..0000000000 --- a/cpp/src/nn/specializations/brute_force_knn_long_float_uint.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - long* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu deleted file mode 100644 index e94c12d579..0000000000 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_int.cu +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - int D, - float* search_items, - int n, - uint32_t* res_I, - float* res_D, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu b/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu deleted file mode 100644 index 95cf8a1eb3..0000000000 --- a/cpp/src/nn/specializations/brute_force_knn_uint32_t_float_uint.cu +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -namespace raft { -namespace spatial { -namespace knn { - -template void brute_force_knn(raft::device_resources const& handle, - std::vector& input, - std::vector& sizes, - unsigned int D, - float* search_items, - unsigned int n, - uint32_t* res_I, - float* res_D, - unsigned int k, - bool rowMajorIndex, - bool rowMajorQuery, - std::vector* translations, - distance::DistanceType metric, - float metric_arg); - -}; // namespace knn -}; // namespace spatial -}; // namespace raft diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers.cu b/cpp/src/spatial/knn/detail/ball_cover/registers.cu new file mode 100644 index 0000000000..0bb6d123a9 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers.cu @@ -0,0 +1,60 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + raft::spatial::knn::detail::DistFunc& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + raft::spatial::knn::detail::DistFunc& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(std::int64_t, float, std::uint32_t, 2); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(std::int64_t, float, std::uint32_t, 3); + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(std::int64_t, float, std::uint32_t, 2); +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(std::int64_t, float, std::uint32_t, 3); + +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py new file mode 100644 index 0000000000..f8ce27728b --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_00_generate.py @@ -0,0 +1,112 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +""" + + +macro_pass_one = """ +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \\ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \\ + template void \\ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \\ + raft::device_resources const& handle, \\ + const BallCoverIndex& index, \\ + const Mvalue_t* query, \\ + const Mvalue_int n_query_rows, \\ + Mvalue_int k, \\ + const Mvalue_idx* R_knn_inds, \\ + const Mvalue_t* R_knn_dists, \\ + Mdist_func& dfunc, \\ + Mvalue_idx* inds, \\ + Mvalue_t* dists, \\ + float weight, \\ + Mvalue_int* dists_counter) + +""" + +macro_pass_two = """ +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \\ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \\ + template void \\ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \\ + raft::device_resources const& handle, \\ + const BallCoverIndex& index, \\ + const Mvalue_t* query, \\ + const Mvalue_int n_query_rows, \\ + Mvalue_int k, \\ + const Mvalue_idx* R_knn_inds, \\ + const Mvalue_t* R_knn_dists, \\ + Mdist_func& dfunc, \\ + Mvalue_idx* inds, \\ + Mvalue_t* dists, \\ + float weight, \\ + Mvalue_int* dists_counter) + +""" + +distances = dict( + haversine="raft::spatial::knn::detail::HaversineFunc", + euclidean="raft::spatial::knn::detail::EuclideanFunc", + dist="raft::spatial::knn::detail::DistFunc", +) + +for k, v in distances.items(): + for dim in [2, 3]: + path = f"registers_pass_one_{dim}d_{k}.cu" + with open(path, "w") as f: + f.write(header) + f.write(macro_pass_one) + f.write(f"instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one(\n") + f.write(f" std::int64_t, float, std::uint32_t, {dim}, {v});\n") + f.write("#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one\n") + print(f"src/spatial/knn/detail/ball_cover/{path}") + +for k, v in distances.items(): + for dim in [2, 3]: + path = f"registers_pass_two_{dim}d_{k}.cu" + with open(path, "w") as f: + f.write(header) + f.write(macro_pass_two) + f.write(f"instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two(\n") + f.write(f" std::int64_t, float, std::uint32_t, {dim}, {v});\n") + f.write("#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two\n") + print(f"src/spatial/knn/detail/ball_cover/{path}") diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu new file mode 100644 index 0000000000..b4ecac06e6 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_dist.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu new file mode 100644 index 0000000000..31628d8b82 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_euclidean.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu new file mode 100644 index 0000000000..80fda1bf9d --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_2d_haversine.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu new file mode 100644 index 0000000000..40aa89aa39 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_dist.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu new file mode 100644 index 0000000000..be159932a6 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_euclidean.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu new file mode 100644 index 0000000000..a9fe8f355f --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_one_3d_haversine.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_one( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_one diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu new file mode 100644 index 0000000000..b20df46a4f --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_dist.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::DistFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu new file mode 100644 index 0000000000..d5042b0142 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_euclidean.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::EuclideanFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu new file mode 100644 index 0000000000..01002d356e --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_2d_haversine.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 2, raft::spatial::knn::detail::HaversineFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu new file mode 100644 index 0000000000..5746ab99fb --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_dist.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::DistFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu new file mode 100644 index 0000000000..fad007a2d4 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_euclidean.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::EuclideanFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu new file mode 100644 index 0000000000..93083da5c6 --- /dev/null +++ b/cpp/src/spatial/knn/detail/ball_cover/registers_pass_two_3d_haversine.cu @@ -0,0 +1,48 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by registers_00_generate.py + * + * Make changes there and run in this directory: + * + * > python registers_00_generate.py + * + */ + +#include // int64_t +#include + +#define instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( \ + Mvalue_idx, Mvalue_t, Mvalue_int, Mdims, Mdist_func) \ + template void \ + raft::spatial::knn::detail::rbc_low_dim_pass_two( \ + raft::device_resources const& handle, \ + const BallCoverIndex& index, \ + const Mvalue_t* query, \ + const Mvalue_int n_query_rows, \ + Mvalue_int k, \ + const Mvalue_idx* R_knn_inds, \ + const Mvalue_t* R_knn_dists, \ + Mdist_func& dfunc, \ + Mvalue_idx* inds, \ + Mvalue_t* dists, \ + float weight, \ + Mvalue_int* dists_counter) + +instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two( + std::int64_t, float, std::uint32_t, 3, raft::spatial::knn::detail::HaversineFunc); +#undef instantiate_raft_spatial_knn_detail_rbc_low_dim_pass_two diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu new file mode 100644 index 0000000000..67b08655e6 --- /dev/null +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // size_t +#include // int_Xt +#include // DistanceType +#include + +#define instantiate_raft_spatial_knn_detail_fusedL2Knn(Mvalue_idx, Mvalue_t, MusePrevTopKs) \ + template void raft::spatial::knn::detail::fusedL2Knn( \ + size_t D, \ + Mvalue_idx * out_inds, \ + Mvalue_t * out_dists, \ + const Mvalue_t* index, \ + const Mvalue_t* query, \ + size_t n_index_rows, \ + size_t n_query_rows, \ + int k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + cudaStream_t stream, \ + raft::distance::DistanceType metric) + +instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int32_t, float, false); + +#undef instantiate_raft_spatial_knn_detail_fusedL2Knn diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu new file mode 100644 index 0000000000..3c0d13710e --- /dev/null +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu @@ -0,0 +1,40 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // size_t +#include // int_Xt +#include // DistanceType +#include + +#define instantiate_raft_spatial_knn_detail_fusedL2Knn(Mvalue_idx, Mvalue_t, MusePrevTopKs) \ + template void raft::spatial::knn::detail::fusedL2Knn( \ + size_t D, \ + Mvalue_idx * out_inds, \ + Mvalue_t * out_dists, \ + const Mvalue_t* index, \ + const Mvalue_t* query, \ + size_t n_index_rows, \ + size_t n_query_rows, \ + int k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + cudaStream_t stream, \ + raft::distance::DistanceType metric) + +instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(int64_t, float, false); + +#undef instantiate_raft_spatial_knn_detail_fusedL2Knn diff --git a/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu b/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu new file mode 100644 index 0000000000..e799c5181f --- /dev/null +++ b/cpp/src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu @@ -0,0 +1,41 @@ +/* + * Copyright (c) 2021-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // size_t +#include // int_Xt +#include // DistanceType +#include + +#define instantiate_raft_spatial_knn_detail_fusedL2Knn(Mvalue_idx, Mvalue_t, MusePrevTopKs) \ + template void raft::spatial::knn::detail::fusedL2Knn( \ + size_t D, \ + Mvalue_idx * out_inds, \ + Mvalue_t * out_dists, \ + const Mvalue_t* index, \ + const Mvalue_t* query, \ + size_t n_index_rows, \ + size_t n_query_rows, \ + int k, \ + bool rowMajorIndex, \ + bool rowMajorQuery, \ + cudaStream_t stream, \ + raft::distance::DistanceType metric) + +// These are used by brute_force_knn: +instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, true); +instantiate_raft_spatial_knn_detail_fusedL2Knn(uint32_t, float, false); + +#undef instantiate_raft_spatial_knn_detail_fusedL2Knn From ff79abf0860321985cee7228bef0afadaf4f3d5d Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 15:55:02 +0200 Subject: [PATCH 11/32] Deprecate specialization headers --- cpp/include/raft/cluster/specializations.cuh | 12 +- cpp/include/raft/distance/specializations.cuh | 14 +- .../detail/00_write_template.py | 148 ------------------ .../specializations/detail/canberra.cuh | 40 ----- .../specializations/detail/correlation.cuh | 40 ----- .../specializations/detail/cosine.cuh | 40 ----- .../detail/hamming_unexpanded.cuh | 40 ----- .../detail/hellinger_expanded.cuh | 40 ----- .../specializations/detail/inner_product.cuh | 52 ------ .../specializations/detail/jensen_shannon.cuh | 40 ----- .../specializations/detail/kernels.cuh | 31 ---- .../specializations/detail/kl_divergence.cuh | 40 ----- .../distance/specializations/detail/l1.cuh | 40 ----- .../specializations/detail/l2_expanded.cuh | 40 ----- .../specializations/detail/l2_unexpanded.cuh | 40 ----- .../distance/specializations/detail/l_inf.cuh | 40 ----- .../specializations/detail/lp_unexpanded.cuh | 40 ----- .../specializations/detail/russel_rao.cuh | 40 ----- .../distance/specializations/distance.cuh | 22 +-- .../specializations/fused_l2_nn_min.cuh | 117 +------------- cpp/include/raft/matrix/specializations.cuh | 7 +- .../specializations/detail/select_k.cuh | 35 +---- .../raft/neighbors/specializations.cuh | 17 +- .../neighbors/specializations/ball_cover.cuh | 41 +---- .../neighbors/specializations/brute_force.cuh | 34 +--- .../detail/ivf_pq_compute_similarity.cuh | 38 +---- .../specializations/fused_l2_knn.cuh | 72 +-------- .../neighbors/specializations/ivf_flat.cuh | 65 +------- .../raft/neighbors/specializations/ivf_pq.cuh | 63 +------- .../raft/neighbors/specializations/refine.cuh | 39 +---- .../raft/sparse/neighbors/specializations.cuh | 10 +- .../raft/spatial/knn/specializations.cuh | 9 +- .../raft/spatial/knn/specializations/knn.cuh | 31 +--- cpp/include/raft/spectral/specializations.cuh | 12 +- cpp/include/raft/stats/specializations.cuh | 12 +- 35 files changed, 100 insertions(+), 1301 deletions(-) delete mode 100644 cpp/include/raft/distance/specializations/detail/00_write_template.py delete mode 100644 cpp/include/raft/distance/specializations/detail/canberra.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/correlation.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/cosine.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/inner_product.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/kernels.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/kl_divergence.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/l1.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/l2_expanded.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/l_inf.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh delete mode 100644 cpp/include/raft/distance/specializations/detail/russel_rao.cuh diff --git a/cpp/include/raft/cluster/specializations.cuh b/cpp/include/raft/cluster/specializations.cuh index 9b68d7adc9..14cab6b56b 100644 --- a/cpp/include/raft/cluster/specializations.cuh +++ b/cpp/include/raft/cluster/specializations.cuh @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __CLUSTER_SPECIALIZATIONS_H -#define __CLUSTER_SPECIALIZATIONS_H - #pragma once -#include -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/distance/specializations.cuh b/cpp/include/raft/distance/specializations.cuh index 5944534be7..07b14d7307 100644 --- a/cpp/include/raft/distance/specializations.cuh +++ b/cpp/include/raft/distance/specializations.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - -#ifndef __DISTANCE_SPECIALIZATIONS_H -#define __DISTANCE_SPECIALIZATIONS_H - #pragma once -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/distance/specializations/detail/00_write_template.py b/cpp/include/raft/distance/specializations/detail/00_write_template.py deleted file mode 100644 index 63ae6580b4..0000000000 --- a/cpp/include/raft/distance/specializations/detail/00_write_template.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 - -# This template manages all files in this directory, apart from -# inner_product.cuh and kernels.cuh. - - -# NOTE: this template is not perfectly formatted. Use pre-commit to get -# everything in shape again. -start_template = """/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -""" - -extern_template = """ -extern template void pairwise_matrix_instantiation_point( - OpT, - pairwise_matrix_params, - cudaStream_t); -""" - -end_template = """} // namespace raft::distance::detail -""" - -data_type_instances = [ - dict( - DataT="float", - AccT="float", - OutT="float", - IdxT="int", - ), - dict( - DataT="double", - AccT="double", - OutT="double", - IdxT="int", - ), -] - - - - -op_instances = [ - dict( - path_prefix="canberra", - OpT="ops::canberra_distance_op", - ), - dict( - path_prefix="correlation", - OpT="ops::correlation_distance_op", - ), - dict( - path_prefix="cosine", - OpT="ops::cosine_distance_op", - # cosine uses CUTLASS for SM80+ - ), - dict( - path_prefix="hamming_unexpanded", - OpT="ops::hamming_distance_op", - ), - dict( - path_prefix="hellinger_expanded", - OpT="ops::hellinger_distance_op", - ), - # inner product is handled by cublas. - dict( - path_prefix="jensen_shannon", - OpT="ops::jensen_shannon_distance_op", - ), - dict( - path_prefix="kl_divergence", - OpT="ops::kl_divergence_op", - ), - dict( - path_prefix="l1", - OpT="ops::l1_distance_op", - ), - dict( - path_prefix="l2_expanded", - OpT="ops::l2_exp_distance_op", - # L2 expanded uses CUTLASS for SM80+ - ), - dict( - path_prefix="l2_unexpanded", - OpT="ops::l2_unexp_distance_op", - ), - dict( - path_prefix="l_inf", - OpT="ops::l_inf_distance_op", - ), - dict( - path_prefix="lp_unexpanded", - OpT="ops::lp_unexp_distance_op", - ), - dict( - path_prefix="russel_rao", - OpT="ops::russel_rao_distance_op", - ), -] - -def fill_in(s, template): - for k, v in template.items(): - s = s.replace(k, v) - return s - -for op_instance in op_instances: - path = fill_in("path_prefix.cuh", op_instance) - with open(path, "w") as f: - f.write(start_template) - - for data_type_instance in data_type_instances: - op_data_instance = { - k : fill_in(v, data_type_instance) - for k, v in op_instance.items() - } - instance = { - **op_data_instance, - **data_type_instance, - "FinopT": "raft::identity_op", - } - - text = fill_in(extern_template, instance) - - f.write(text) - - f.write(end_template) diff --git a/cpp/include/raft/distance/specializations/detail/canberra.cuh b/cpp/include/raft/distance/specializations/detail/canberra.cuh deleted file mode 100644 index 276c85e5f6..0000000000 --- a/cpp/include/raft/distance/specializations/detail/canberra.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::canberra_distance_op, - int, - float, - float, - raft::identity_op>(ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::canberra_distance_op, - int, - double, - double, - raft::identity_op>(ops::canberra_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/correlation.cuh b/cpp/include/raft/distance/specializations/detail/correlation.cuh deleted file mode 100644 index f019f678df..0000000000 --- a/cpp/include/raft/distance/specializations/detail/correlation.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::correlation_distance_op, - int, - float, - float, - raft::identity_op>(ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::correlation_distance_op, - int, - double, - double, - raft::identity_op>(ops::correlation_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/cosine.cuh b/cpp/include/raft/distance/specializations/detail/cosine.cuh deleted file mode 100644 index dcde4ec286..0000000000 --- a/cpp/include/raft/distance/specializations/detail/cosine.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::cosine_distance_op, - int, - double, - double, - raft::identity_op>(ops::cosine_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh deleted file mode 100644 index 1d6964fbce..0000000000 --- a/cpp/include/raft/distance/specializations/detail/hamming_unexpanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::hamming_distance_op, - int, - float, - float, - raft::identity_op>(ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::hamming_distance_op, - int, - double, - double, - raft::identity_op>(ops::hamming_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh b/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh deleted file mode 100644 index f96a06f919..0000000000 --- a/cpp/include/raft/distance/specializations/detail/hellinger_expanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::hellinger_distance_op, - int, - float, - float, - raft::identity_op>(ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::hellinger_distance_op, - int, - double, - double, - raft::identity_op>(ops::hellinger_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/inner_product.cuh b/cpp/include/raft/distance/specializations/detail/inner_product.cuh deleted file mode 100644 index d97d678928..0000000000 --- a/cpp/include/raft/distance/specializations/detail/inner_product.cuh +++ /dev/null @@ -1,52 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft { -namespace distance { -namespace detail { -extern template void distance( - raft::resources const& handle, - const float* x, - const float* y, - float* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - float metric_arg); - -extern template void -distance( - raft::resources const& handle, - const double* x, - const double* y, - double* dist, - int m, - int n, - int k, - void* workspace, - size_t worksize, - bool isRowMajor, - double metric_arg); -} // namespace detail -} // namespace distance -} // namespace raft diff --git a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh b/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh deleted file mode 100644 index 0b58646582..0000000000 --- a/cpp/include/raft/distance/specializations/detail/jensen_shannon.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::jensen_shannon_distance_op, - int, - float, - float, - raft::identity_op>(ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::jensen_shannon_distance_op, - int, - double, - double, - raft::identity_op>(ops::jensen_shannon_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/kernels.cuh b/cpp/include/raft/distance/specializations/detail/kernels.cuh deleted file mode 100644 index 75c9c023e8..0000000000 --- a/cpp/include/raft/distance/specializations/detail/kernels.cuh +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -extern template class raft::distance::kernels::detail::GramMatrixBase; -extern template class raft::distance::kernels::detail::GramMatrixBase; - -extern template class raft::distance::kernels::detail::PolynomialKernel; -extern template class raft::distance::kernels::detail::PolynomialKernel; - -extern template class raft::distance::kernels::detail::TanhKernel; -extern template class raft::distance::kernels::detail::TanhKernel; - -// These are somehow missing a kernel definition which is causing a compile error -// extern template class raft::distance::kernels::detail::RBFKernel; -// extern template class raft::distance::kernels::detail::RBFKernel; \ No newline at end of file diff --git a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh b/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh deleted file mode 100644 index 5c164e0fd4..0000000000 --- a/cpp/include/raft/distance/specializations/detail/kl_divergence.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point, - int, - double, - double, - raft::identity_op>( - ops::kl_divergence_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l1.cuh b/cpp/include/raft/distance/specializations/detail/l1.cuh deleted file mode 100644 index 870627d909..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l1.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point, - int, - double, - double, - raft::identity_op>( - ops::l1_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh deleted file mode 100644 index ee3207bcce..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_expanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::l2_exp_distance_op, - int, - double, - double, - raft::identity_op>(ops::l2_exp_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh deleted file mode 100644 index 1fbf57632b..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l2_unexpanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::l2_unexp_distance_op, - int, - float, - float, - raft::identity_op>(ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::l2_unexp_distance_op, - int, - double, - double, - raft::identity_op>(ops::l2_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/l_inf.cuh b/cpp/include/raft/distance/specializations/detail/l_inf.cuh deleted file mode 100644 index 388d3bf439..0000000000 --- a/cpp/include/raft/distance/specializations/detail/l_inf.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point, - int, - float, - float, - raft::identity_op>( - ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::l_inf_distance_op, - int, - double, - double, - raft::identity_op>(ops::l_inf_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh b/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh deleted file mode 100644 index d8e86ce6f2..0000000000 --- a/cpp/include/raft/distance/specializations/detail/lp_unexpanded.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::lp_unexp_distance_op, - int, - float, - float, - raft::identity_op>(ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::lp_unexp_distance_op, - int, - double, - double, - raft::identity_op>(ops::lp_unexp_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh b/cpp/include/raft/distance/specializations/detail/russel_rao.cuh deleted file mode 100644 index 4803fb8ab0..0000000000 --- a/cpp/include/raft/distance/specializations/detail/russel_rao.cuh +++ /dev/null @@ -1,40 +0,0 @@ -/* - * Copyright (c) 2021-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -#include - -namespace raft::distance::detail { - -extern template void pairwise_matrix_instantiation_point< - ops::russel_rao_distance_op, - int, - float, - float, - raft::identity_op>(ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); - -extern template void pairwise_matrix_instantiation_point< - ops::russel_rao_distance_op, - int, - double, - double, - raft::identity_op>(ops::russel_rao_distance_op, - pairwise_matrix_params, - cudaStream_t); -} // namespace raft::distance::detail diff --git a/cpp/include/raft/distance/specializations/distance.cuh b/cpp/include/raft/distance/specializations/distance.cuh index a34f696e9e..07b14d7307 100644 --- a/cpp/include/raft/distance/specializations/distance.cuh +++ b/cpp/include/raft/distance/specializations/distance.cuh @@ -13,22 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh index 88e1216635..14cab6b56b 100644 --- a/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh +++ b/cpp/include/raft/distance/specializations/fused_l2_nn_min.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,115 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include - -namespace raft { -namespace distance { - -extern template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce, int>( - raft::KeyValuePair* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce, int64_t>( - raft::KeyValuePair* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce(float* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce(float* min, - const float* x, - const float* y, - const float* xn, - const float* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce(double* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int m, - int n, - int k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); -extern template void fusedL2NNMinReduce(double* min, - const double* x, - const double* y, - const double* xn, - const double* yn, - int64_t m, - int64_t n, - int64_t k, - void* workspace, - bool sqrt, - bool initOutBuffer, - cudaStream_t stream); - -} // namespace distance -} // namespace raft \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/matrix/specializations.cuh b/cpp/include/raft/matrix/specializations.cuh index 07bdeab507..7ea4aed5c5 100644 --- a/cpp/include/raft/matrix/specializations.cuh +++ b/cpp/include/raft/matrix/specializations.cuh @@ -13,7 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/matrix/specializations/detail/select_k.cuh b/cpp/include/raft/matrix/specializations/detail/select_k.cuh index 3cb1a2d8dc..7ea4aed5c5 100644 --- a/cpp/include/raft/matrix/specializations/detail/select_k.cuh +++ b/cpp/include/raft/matrix/specializations/detail/select_k.cuh @@ -13,35 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -#include - -namespace raft::matrix::detail { - -#define RAFT_INST(T, IdxT) \ - extern template void select_k(const T*, \ - const IdxT*, \ - size_t, \ - size_t, \ - int, \ - T*, \ - IdxT*, \ - bool, \ - rmm::cuda_stream_view, \ - rmm::mr::device_memory_resource*); - -// Commonly used types -RAFT_INST(float, int64_t); -RAFT_INST(half, int64_t); - -// These instances are used in the ivf_pq::search parameterized by the internal_distance_dtype -RAFT_INST(float, uint32_t); -RAFT_INST(half, uint32_t); - -#undef RAFT_INST - -} // namespace raft::matrix::detail +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations.cuh b/cpp/include/raft/neighbors/specializations.cuh index 9da5649ef8..07b14d7307 100644 --- a/cpp/include/raft/neighbors/specializations.cuh +++ b/cpp/include/raft/neighbors/specializations.cuh @@ -13,17 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include - -#include -#include -#include - -#include -#include -#include +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/ball_cover.cuh b/cpp/include/raft/neighbors/specializations/ball_cover.cuh index d6a6b2e296..07b14d7307 100644 --- a/cpp/include/raft/neighbors/specializations/ball_cover.cuh +++ b/cpp/include/raft/neighbors/specializations/ball_cover.cuh @@ -13,41 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include - -#include - -namespace raft::neighbors::ball_cover { -extern template class BallCoverIndex; -extern template class BallCoverIndex; - -extern template void build_index( - raft::device_resources const& handle, - BallCoverIndex& index); - -extern template void knn_query( - raft::device_resources const& handle, - const BallCoverIndex& index, - std::uint32_t k, - const float* query, - std::uint32_t n_query_pts, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - -extern template void all_knn_query( - raft::device_resources const& handle, - BallCoverIndex& index, - std::uint32_t k, - std::int64_t* inds, - float* dists, - bool perform_post_filtering, - float weight); - -}; // namespace raft::neighbors::ball_cover \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/brute_force.cuh b/cpp/include/raft/neighbors/specializations/brute_force.cuh index 1337beb68a..07b14d7307 100644 --- a/cpp/include/raft/neighbors/specializations/brute_force.cuh +++ b/cpp/include/raft/neighbors/specializations/brute_force.cuh @@ -13,34 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -// also define the detail api, which is used by raft::neighbors::brute_force -// (not doing the public api, since has extra template params on index_layout, matrix_index, -// search_layout etc - and isn't clear what the defaults here should be) -namespace raft::neighbors::detail { -#define RAFT_INST(IdxT, T, IntT) \ - extern template void brute_force_knn_impl(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - raft::distance::DistanceType metric, \ - float metricArg, \ - raft::identity_op); -RAFT_INST(long, float, int); -RAFT_INST(long, float, unsigned int); -RAFT_INST(uint32_t, float, int); -RAFT_INST(uint32_t, float, unsigned int); -#undef RAFT_INST -} // namespace raft::neighbors::detail +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh b/cpp/include/raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh index f1c46b1225..14cab6b56b 100644 --- a/cpp/include/raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh +++ b/cpp/include/raft/neighbors/specializations/detail/ivf_pq_compute_similarity.cuh @@ -13,38 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -#include - -namespace raft::neighbors::ivf_pq::detail { - -namespace { -using fp8s_t = fp_8bit<5, true>; -using fp8u_t = fp_8bit<5, false>; -} // namespace - -#define RAFT_INST(OutT, LutT) \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; \ - extern template auto get_compute_similarity_kernel(uint32_t, uint32_t) \ - ->compute_similarity_kernel_t; - -#define RAFT_INST_ALL_OUT_T(LutT) \ - RAFT_INST(float, LutT) \ - RAFT_INST(half, LutT) - -RAFT_INST_ALL_OUT_T(float) -RAFT_INST_ALL_OUT_T(half) -RAFT_INST_ALL_OUT_T(fp8s_t) -RAFT_INST_ALL_OUT_T(fp8u_t) - -#undef RAFT_INST -#undef RAFT_INST_ALL_OUT_T - -} // namespace raft::neighbors::ivf_pq::detail +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/fused_l2_knn.cuh b/cpp/include/raft/neighbors/specializations/fused_l2_knn.cuh index 916db8f0a2..07b14d7307 100644 --- a/cpp/include/raft/neighbors/specializations/fused_l2_knn.cuh +++ b/cpp/include/raft/neighbors/specializations/fused_l2_knn.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021-2022, NVIDIA CORPORATION. + * Copyright (c) 2021-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,68 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#pragma once -#include -#include - -namespace raft { -namespace spatial { -namespace knn { -namespace detail { - -extern template void fusedL2Knn(size_t D, - long* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -extern template void fusedL2Knn(size_t D, - long* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -extern template void fusedL2Knn(size_t D, - int* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -extern template void fusedL2Knn(size_t D, - int* out_inds, - float* out_dists, - const float* index, - const float* query, - size_t n_index_rows, - size_t n_query_rows, - int k, - bool rowMajorIndex, - bool rowMajorQuery, - cudaStream_t stream, - raft::distance::DistanceType metric); - -}; // namespace detail -}; // namespace knn -}; // namespace spatial -}; // namespace raft +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh index 161f3462c9..7ea4aed5c5 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_flat.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_flat.cuh @@ -13,65 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -namespace raft::neighbors::ivf_flat { - -// greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan -// function is used in both raft::neighbors::ivf_flat::search and -// raft::neighbors::detail::refine_device. To prevent a duplicate instantiation -// of this function (which defines ~270 kernels) in the refine specializations, -// an extern template definition is provided here. Please check related function -// calls after editing template definition below. Search for -// `greppable-id-specializations-ivf-flat-search` to find them. -#define RAFT_INST(T, IdxT) \ - extern template auto build(raft::device_resources const& handle, \ - const index_params& params, \ - raft::device_matrix_view dataset) \ - ->index; \ - \ - extern template auto extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - const index& orig_index) \ - ->index; \ - \ - extern template void extend( \ - raft::device_resources const& handle, \ - raft::device_matrix_view new_vectors, \ - std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* idx); \ - \ - extern template void search(raft::device_resources const&, \ - raft::neighbors::ivf_flat::search_params const&, \ - const raft::neighbors::ivf_flat::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); \ - \ - extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< \ - T, \ - typename raft::spatial::knn::detail::utils::config::value_t, \ - IdxT>(const index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream); - -RAFT_INST(float, int64_t); -RAFT_INST(int8_t, int64_t); -RAFT_INST(uint8_t, int64_t); - -#undef RAFT_INST -} // namespace raft::neighbors::ivf_flat +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh index 9209f5095d..14cab6b56b 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh @@ -13,63 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include -#include -#include - -namespace raft::neighbors::ivf_pq { - -#ifdef RAFT_DECL_BUILD_EXTEND -#undef RAFT_DECL_BUILD_EXTEND -#endif - -#ifdef RAFT_DECL_SEARCH -#undef RAFT_DECL_SEARCH -#endif - -// We define overloads for build and extend with void return type. This is used in the Cython -// wrappers, where exception handling is not compatible with return type that has nontrivial -// constructor. -#define RAFT_DECL_BUILD_EXTEND(T, IdxT) \ - extern template auto build(raft::device_resources const&, \ - const raft::neighbors::ivf_pq::index_params&, \ - raft::device_matrix_view) \ - ->raft::neighbors::ivf_pq::index; \ - \ - extern template auto extend(raft::device_resources const&, \ - raft::device_matrix_view, \ - std::optional>, \ - const raft::neighbors::ivf_pq::index&) \ - ->raft::neighbors::ivf_pq::index; \ - \ - extern template void extend(raft::device_resources const&, \ - raft::device_matrix_view, \ - std::optional>, \ - raft::neighbors::ivf_pq::index*); - -RAFT_DECL_BUILD_EXTEND(float, int64_t) -RAFT_DECL_BUILD_EXTEND(int8_t, int64_t) -RAFT_DECL_BUILD_EXTEND(uint8_t, int64_t) - -#undef RAFT_DECL_BUILD_EXTEND - -#define RAFT_DECL_SEARCH(T, IdxT) \ - extern template void search(raft::device_resources const&, \ - const raft::neighbors::ivf_pq::search_params&, \ - const raft::neighbors::ivf_pq::index&, \ - raft::device_matrix_view, \ - raft::device_matrix_view, \ - raft::device_matrix_view); - -RAFT_DECL_SEARCH(float, int64_t); -RAFT_DECL_SEARCH(int8_t, int64_t); -RAFT_DECL_SEARCH(uint8_t, int64_t); - -#undef RAFT_DECL_SEARCH - -} // namespace raft::neighbors::ivf_pq +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/neighbors/specializations/refine.cuh b/cpp/include/raft/neighbors/specializations/refine.cuh index aef4834c9f..14cab6b56b 100644 --- a/cpp/include/raft/neighbors/specializations/refine.cuh +++ b/cpp/include/raft/neighbors/specializations/refine.cuh @@ -13,39 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -namespace raft::neighbors { - -#ifdef RAFT_INST -#undef RAFT_INST -#endif - -#define RAFT_INST(T, IdxT) \ - extern template void refine( \ - raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbor_candidates, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - distance::DistanceType metric); \ - \ - extern template void refine( \ - raft::device_resources const& handle, \ - raft::host_matrix_view dataset, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbor_candidates, \ - raft::host_matrix_view indices, \ - raft::host_matrix_view distances, \ - distance::DistanceType metric); - -RAFT_INST(float, int64_t); -RAFT_INST(uint8_t, int64_t); -RAFT_INST(int8_t, int64_t); - -#undef RAFT_INST -} // namespace raft::neighbors +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/sparse/neighbors/specializations.cuh b/cpp/include/raft/sparse/neighbors/specializations.cuh index 23ba38ccda..14cab6b56b 100644 --- a/cpp/include/raft/sparse/neighbors/specializations.cuh +++ b/cpp/include/raft/sparse/neighbors/specializations.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -13,8 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/spatial/knn/specializations.cuh b/cpp/include/raft/spatial/knn/specializations.cuh index 5f0a39a61b..07b14d7307 100644 --- a/cpp/include/raft/spatial/knn/specializations.cuh +++ b/cpp/include/raft/spatial/knn/specializations.cuh @@ -13,9 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include -#include -#include +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/spatial/knn/specializations/knn.cuh b/cpp/include/raft/spatial/knn/specializations/knn.cuh index e045487597..07b14d7307 100644 --- a/cpp/include/raft/spatial/knn/specializations/knn.cuh +++ b/cpp/include/raft/spatial/knn/specializations/knn.cuh @@ -13,31 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - #pragma once -#include - -namespace raft::spatial::knn { -#define RAFT_INST(IdxT, T, IntT) \ - extern template void brute_force_knn(raft::device_resources const& handle, \ - std::vector& input, \ - std::vector& sizes, \ - IntT D, \ - T* search_items, \ - IntT n, \ - IdxT* res_I, \ - T* res_D, \ - IntT k, \ - bool rowMajorIndex, \ - bool rowMajorQuery, \ - std::vector* translations, \ - distance::DistanceType metric, \ - float metric_arg); - -RAFT_INST(long, float, int); -RAFT_INST(long, float, unsigned int); -RAFT_INST(uint32_t, float, int); -RAFT_INST(uint32_t, float, unsigned int); -#undef RAFT_INST -}; // namespace raft::spatial::knn +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/spectral/specializations.cuh b/cpp/include/raft/spectral/specializations.cuh index 0ce5f0c653..14cab6b56b 100644 --- a/cpp/include/raft/spectral/specializations.cuh +++ b/cpp/include/raft/spectral/specializations.cuh @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __SPECTRAL_SPECIALIZATIONS_H -#define __SPECTRAL_SPECIALIZATIONS_H - #pragma once -#include -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") diff --git a/cpp/include/raft/stats/specializations.cuh b/cpp/include/raft/stats/specializations.cuh index e6622469d3..14cab6b56b 100644 --- a/cpp/include/raft/stats/specializations.cuh +++ b/cpp/include/raft/stats/specializations.cuh @@ -13,12 +13,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#ifndef __STATS_SPECIALIZATIONS_H -#define __STATS_SPECIALIZATIONS_H - #pragma once -#include -#include - -#endif \ No newline at end of file +#pragma message( \ + __FILE__ \ + " is deprecated and will be removed." \ + " Including specializations is not necessary any more." \ + " For more information, see: https://docs.rapids.ai/api/raft/nightly/using_libraft.html") From c9e741342b9382a60fe78fc1a9bc11c00c15cb09 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 16:11:44 +0200 Subject: [PATCH 12/32] Add interleaved scan instances --- .../neighbors/detail/ivf_flat_search-ext.cuh | 36 +++++++++++++++++++ ...at_interleaved_scan_float_float_int64_t.cu | 36 +++++++++++++++++++ ...interleaved_scan_int8_t_int32_t_int64_t.cu | 36 +++++++++++++++++++ ...terleaved_scan_uint8_t_uint32_t_int64_t.cu | 36 +++++++++++++++++++ 4 files changed, 144 insertions(+) create mode 100644 cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu create mode 100644 cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu create mode 100644 cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh index 3bb3a4308d..be71c9eb11 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -19,6 +19,7 @@ #include // uintX_t #include // raft::neighbors::ivf_flat::index #include // RAFT_EXPLICIT +#include // rmm:cuda_stream_view #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -35,6 +36,20 @@ void search(raft::device_resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr); +template +void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& index, + const T* queries, + const uint32_t* coarse_query_results, + const uint32_t n_queries, + const raft::distance::DistanceType metric, + const uint32_t n_probes, + const uint32_t k, + const bool select_min, + IdxT* neighbors, + float* distances, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream) RAFT_EXPLICIT; + } // namespace raft::neighbors::ivf_flat::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -56,3 +71,24 @@ instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); #undef instantiate_raft_neighbors_ivf_flat_detail_search + +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ + extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream) + +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu new file mode 100644 index 0000000000..2c34d50a8c --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream) + +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu new file mode 100644 index 0000000000..77aea885fb --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream) + +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu new file mode 100644 index 0000000000..57d09b7d52 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ + template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream) + +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan From 0c889dc64a676287f43db60b884e26d882a2a278 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 16:50:44 +0200 Subject: [PATCH 13/32] Separate fused_l2_nn_helpers These types are not used in the ext header, but are useful to have. --- cpp/include/raft/distance/fused_l2_nn-ext.cuh | 16 ++---- cpp/include/raft/distance/fused_l2_nn-inl.cuh | 26 +--------- .../raft/distance/fused_l2_nn_helpers.cuh | 49 +++++++++++++++++++ 3 files changed, 55 insertions(+), 36 deletions(-) create mode 100644 cpp/include/raft/distance/fused_l2_nn_helpers.cuh diff --git a/cpp/include/raft/distance/fused_l2_nn-ext.cuh b/cpp/include/raft/distance/fused_l2_nn-ext.cuh index a0af04c4e8..05732c1f3f 100644 --- a/cpp/include/raft/distance/fused_l2_nn-ext.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-ext.cuh @@ -16,23 +16,17 @@ #pragma once -#include // int64_t -#include // raft::device_resources -#include // raft::KeyValuePair -#include // RAFT_EXPLICIT +#include // int64_t +#include // raft::device_resources +#include // raft::KeyValuePair +#include // include initialize and reduce operations +#include // RAFT_EXPLICIT #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY namespace raft { namespace distance { -template -void initialize(raft::device_resources const& handle, - OutT* min, - IdxT m, - DataT maxVal, - ReduceOpT redOp) RAFT_EXPLICIT; - template void fusedL2NNMinReduce(OutT* min, const DataT* x, diff --git a/cpp/include/raft/distance/fused_l2_nn-inl.cuh b/cpp/include/raft/distance/fused_l2_nn-inl.cuh index e832bcb020..698d287f87 100644 --- a/cpp/include/raft/distance/fused_l2_nn-inl.cuh +++ b/cpp/include/raft/distance/fused_l2_nn-inl.cuh @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -30,31 +31,6 @@ namespace raft { namespace distance { -/** - * \defgroup fused_l2_nn Fused 1-nearest neighbors - * @{ - */ - -template -using KVPMinReduce = detail::KVPMinReduceImpl; - -template -using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl; - -template -using MinReduceOp = detail::MinReduceOpImpl; - -/** @} */ - -/** - * Initialize array using init value from reduction op - */ -template -void initialize( - raft::device_resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) -{ - detail::initialize(min, m, maxVal, redOp, handle.get_stream()); -} /** * \ingroup fused_l2_nn diff --git a/cpp/include/raft/distance/fused_l2_nn_helpers.cuh b/cpp/include/raft/distance/fused_l2_nn_helpers.cuh new file mode 100644 index 0000000000..1bcd7d8dba --- /dev/null +++ b/cpp/include/raft/distance/fused_l2_nn_helpers.cuh @@ -0,0 +1,49 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +namespace raft::distance { + +/** + * \defgroup fused_l2_nn Fused 1-nearest neighbors + * @{ + */ + +template +using KVPMinReduce = detail::KVPMinReduceImpl; + +template +using MinAndDistanceReduceOp = detail::MinAndDistanceReduceOpImpl; + +template +using MinReduceOp = detail::MinReduceOpImpl; + +/** @} */ + +/** + * Initialize array using init value from reduction op + */ +template +void initialize( + raft::device_resources const& handle, OutT* min, IdxT m, DataT maxVal, ReduceOpT redOp) +{ + detail::initialize(min, m, maxVal, redOp, handle.get_stream()); +} + +} // namespace raft::distance From f97b2a834a5cdc67894a2488c6850fd9cd09d2df Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 16:53:32 +0200 Subject: [PATCH 14/32] Remove pairwise_matrix_instantiation_point --- .../detail/pairwise_matrix/dispatch-inl.cuh | 69 +++++++++---------- 1 file changed, 31 insertions(+), 38 deletions(-) diff --git a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh index e04b56ee8a..bb4422735b 100644 --- a/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh +++ b/cpp/include/raft/distance/detail/pairwise_matrix/dispatch-inl.cuh @@ -24,7 +24,7 @@ * architectures. * * 2. Provide concise function templates that can be instantiated in - * src/distance/distance/specializations/detail/. Previously, + * src/distance/detail/pairwise_matrix/. Previously, * raft::distance::detail::distance was instantiated. The function * necessarily required a large set of include files, which slowed down the * build. The raft::distance::detail::pairwise_matrix_arch_dispatch functions @@ -40,13 +40,13 @@ // Including dispatch_sm80.cuh can slow down compile times (due to CUTLASS). // Therefore, it is the including file's responsibility to include the correct // dispatch_smXX.cuh headers, as is done in raft/distance/detail/distance.cuh -// and the specializations in src/distance/distance/specializations/detail/. +// and src/distance/detail/pairwise_matrix/dispatch_*.cu. namespace raft::distance::detail { // This forward-declaration ensures that we do not need to include // dispatch_sm80.cuh if we are not calling it in practice. This makes compiling -// all the non-CUTLASS based distance specializations faster. For CUTLASS-based +// all the non-CUTLASS based distance instantiations faster. For CUTLASS-based // distances, dispatch_sm80.cuh has to be included by the file including this // file. template -void pairwise_matrix_instantiation_point(OpT distance_op, - pairwise_matrix_params params, - cudaStream_t stream) +template +void pairwise_matrix_dispatch(OpT distance_op, + IdxT m, + IdxT n, + IdxT k, + const DataT* x, + const DataT* y, + const DataT* x_norm, + const DataT* y_norm, + OutT* out, + FinOpT fin_op, + cudaStream_t stream, + bool is_row_major) { + // Create kernel parameter struct. Flip x and y if column major. + IdxT ldx = is_row_major ? k : m; + IdxT ldy = is_row_major ? k : n; + IdxT ld_out = is_row_major ? n : m; + + pairwise_matrix_params params{ + m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; + + if (!params.is_row_major) { params.flip_x_and_y(); } + // On CUDA 12: // - always execute normal kernel // @@ -103,35 +127,4 @@ void pairwise_matrix_instantiation_point(OpT distance_op, } } -template -void pairwise_matrix_dispatch(OpT distance_op, - IdxT m, - IdxT n, - IdxT k, - const DataT* x, - const DataT* y, - const DataT* x_norm, - const DataT* y_norm, - OutT* out, - FinOpT fin_op, - cudaStream_t stream, - bool is_row_major) -{ - // Create kernel parameter struct. Flip x and y if column major. - IdxT ldx = is_row_major ? k : m; - IdxT ldy = is_row_major ? k : n; - IdxT ld_out = is_row_major ? n : m; - - pairwise_matrix_params params{ - m, n, k, ldx, ldy, ld_out, x, y, x_norm, y_norm, out, fin_op, is_row_major}; - - if (!params.is_row_major) { params.flip_x_and_y(); } - pairwise_matrix_instantiation_point(distance_op, params, stream); -} - }; // namespace raft::distance::detail From fb637f7a3cd0d78941f38f3afd67653ff2949b6c Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 17:01:20 +0200 Subject: [PATCH 15/32] Rename specialization => instantiation --- .../raft/neighbors/detail/ivf_flat_search-inl.cuh | 8 -------- cpp/include/raft/neighbors/detail/knn_brute_force.cuh | 2 +- cpp/include/raft/neighbors/detail/refine.cuh | 9 --------- cpp/test/neighbors/ann_ivf_pq.cuh | 2 +- 4 files changed, 2 insertions(+), 19 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index dcd7c4f9b7..526ee2b7b0 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -1066,14 +1066,6 @@ void ivfflat_interleaved_scan(const index& index, uint32_t& grid_dim_x, rmm::cuda_stream_view stream) { - // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan - // function is used in both raft::neighbors::ivf_flat::search and - // raft::neighbors::detail::refine_device. To prevent a duplicate - // instantiation of this function (which defines ~270 kernels) in the refine - // specializations, an extern template definition is provided. Please check - // related function calls after editing this function definition. Search for - // `greppable-id-specializations-ivf-flat-search` to find them. - const int capacity = bound_by_power_of_two(k); select_interleaved_scan_kernel::run(capacity, index.veclen(), diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 0148a1a887..5d099b8d67 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -142,7 +142,7 @@ void tiled_brute_force_knn(const raft::device_resources& handle, // calculate the top-k elements for the current tile, by calculating the // full pairwise distance for the tile - and then selecting the top-k from that // note: we're using a int32 IndexType here on purpose in order to - // use the pairwise_distance specializations. Since the tile size will ensure + // use the pairwise_distance instantiations. Since the tile size will ensure // that the total memory is < 1GB per tile, this will not cause any issues distance::pairwise_distance(handle, search + i * d, diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index 722a4b15f9..b85bbd0e9c 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -117,15 +117,6 @@ void refine_device(raft::device_resources const& handle, neighbor_candidates.data_handle(), n_queries, n_candidates); - - // greppable-id-specializations-ivf-flat-search: The ivfflat_interleaved_scan - // function is used in both raft::neighbors::ivf_flat::search and - // raft::neighbors::detail::refine_device. To prevent a duplicate - // instantiation of this function (which defines ~270 kernels) in the refine - // specializations, an extern template definition is provided. Please check - // and adjust the extern template definition and the instantiation when the - // below function call is edited. Search for - // `greppable-id-specializations-ivf-flat-search` to find them. uint32_t grid_dim_x = 1; raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan< data_t, diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 95f08ad3d1..5181805999 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -32,7 +32,7 @@ #ifdef RAFT_COMPILED #include #else -#pragma message("NN specializations are not enabled; expect very long building times.") +#pragma message("NN instantiations are not enabled; expect very long building times.") #endif #include From 7b065aff81a0b1976e2a9e2f3de6690361a1111b Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 17:32:58 +0200 Subject: [PATCH 16/32] test/neighbors/selection.cu: Expose kFaissMaxK --- .../neighbors/detail/selection_faiss-ext.cuh | 7 +++-- .../neighbors/detail/selection_faiss-inl.cuh | 8 +---- .../detail/selection_faiss_helpers.cuh | 31 +++++++++++++++++++ cpp/test/neighbors/selection.cu | 2 ++ 4 files changed, 38 insertions(+), 10 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/selection_faiss_helpers.cuh diff --git a/cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh b/cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh index 62b2b25261..7cf12251d9 100644 --- a/cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh +++ b/cpp/include/raft/neighbors/detail/selection_faiss-ext.cuh @@ -16,9 +16,10 @@ #pragma once -#include // size_t -#include // uint32_t -#include // RAFT_EXPLICIT +#include // size_t +#include // uint32_t +#include // kFaissMaxK +#include // RAFT_EXPLICIT #if defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) diff --git a/cpp/include/raft/neighbors/detail/selection_faiss-inl.cuh b/cpp/include/raft/neighbors/detail/selection_faiss-inl.cuh index 5df42e94b9..d2e3206993 100644 --- a/cpp/include/raft/neighbors/detail/selection_faiss-inl.cuh +++ b/cpp/include/raft/neighbors/detail/selection_faiss-inl.cuh @@ -20,16 +20,10 @@ #include #include +#include // kFaissMaxK namespace raft::neighbors::detail { -template -constexpr int kFaissMaxK() -{ - if (sizeof(key_t) >= 8) { return sizeof(payload_t) >= 8 ? 512 : 1024; } - return 2048; -} - template __global__ void select_k_kernel(const key_t* inK, const payload_t* inV, diff --git a/cpp/include/raft/neighbors/detail/selection_faiss_helpers.cuh b/cpp/include/raft/neighbors/detail/selection_faiss_helpers.cuh new file mode 100644 index 0000000000..c4b69f21ec --- /dev/null +++ b/cpp/include/raft/neighbors/detail/selection_faiss_helpers.cuh @@ -0,0 +1,31 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +namespace raft::neighbors::detail { + +// This function is used in cpp/test/neighbors/select.cu. We want to make it +// available through both the selection_faiss-inl.cuh and +// selection_faiss-ext.cuh headers. +template +constexpr int kFaissMaxK() +{ + if (sizeof(key_t) >= 8) { return sizeof(payload_t) >= 8 ? 512 : 1024; } + return 2048; +} + +} // namespace raft::neighbors::detail diff --git a/cpp/test/neighbors/selection.cu b/cpp/test/neighbors/selection.cu index 9f13de357c..1b4b7cb832 100644 --- a/cpp/test/neighbors/selection.cu +++ b/cpp/test/neighbors/selection.cu @@ -17,6 +17,8 @@ #include #include #include +#include +#include // kFaissMax #include #include From d5b567397ca30c9e1f38e170adc52450626d97c5 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 18:29:45 +0200 Subject: [PATCH 17/32] Remove includes of specialization headers --- cpp/bench/ann/src/raft/raft_benchmark.cu | 4 ---- cpp/bench/ann/src/raft/raft_ivf_flat.cu | 6 +----- cpp/bench/ann/src/raft/raft_ivf_pq.cu | 4 ---- cpp/bench/prims/cluster/kmeans.cu | 4 ---- cpp/bench/prims/cluster/kmeans_balanced.cu | 4 ---- cpp/bench/prims/distance/distance_common.cuh | 3 --- cpp/bench/prims/distance/fused_l2_nn.cu | 3 --- cpp/bench/prims/distance/kernels.cu | 4 ---- cpp/bench/prims/distance/masked_nn.cu | 4 ---- cpp/bench/prims/matrix/select_k.cu | 4 ---- cpp/bench/prims/neighbors/knn.cuh | 4 ---- cpp/bench/prims/neighbors/refine_float_int64_t.cu | 5 ----- cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu | 4 ---- cpp/doxygen/Doxyfile | 1 + cpp/internal/raft_internal/matrix/select_k.cuh | 6 ------ cpp/internal/raft_internal/neighbors/naive_knn.cuh | 4 ---- cpp/src/raft_runtime/cluster/cluster_cost_double.cu | 1 - cpp/src/raft_runtime/cluster/cluster_cost_float.cu | 1 - cpp/src/raft_runtime/cluster/kmeans_fit_double.cu | 1 - cpp/src/raft_runtime/cluster/kmeans_fit_float.cu | 1 - .../raft_runtime/cluster/kmeans_init_plus_plus_double.cu | 1 - cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu | 1 - cpp/src/raft_runtime/cluster/update_centroids.cuh | 1 - cpp/src/raft_runtime/cluster/update_centroids_double.cu | 1 - cpp/src/raft_runtime/cluster/update_centroids_float.cu | 1 - cpp/src/raft_runtime/distance/fused_l2_min_arg.cu | 1 - cpp/src/raft_runtime/distance/pairwise_distance.cu | 1 - cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu | 1 - .../raft_runtime/neighbors/brute_force_knn_int64_t_float.cu | 2 -- cpp/src/raft_runtime/neighbors/ivf_flat_build.cu | 1 - cpp/src/raft_runtime/neighbors/ivf_flat_search.cu | 1 - cpp/src/raft_runtime/neighbors/ivfpq_build.cu | 1 - cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu | 2 +- .../raft_runtime/neighbors/ivfpq_search_float_int64_t.cu | 1 - .../raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu | 1 - .../raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu | 1 - cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu | 1 - cpp/src/raft_runtime/neighbors/refine_d_int64_t_float.cu | 1 - cpp/src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu | 1 - cpp/src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu | 1 - cpp/src/raft_runtime/neighbors/refine_h_int64_t_float.cu | 1 - cpp/src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu | 1 - cpp/src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu | 1 - cpp/template/src/test_distance.cu | 4 ---- cpp/test/cluster/cluster_solvers.cu | 4 ---- cpp/test/cluster/kmeans.cu | 4 ---- cpp/test/cluster/kmeans_balanced.cu | 4 ---- cpp/test/cluster/kmeans_find_k.cu | 4 ---- cpp/test/cluster/linkage.cu | 4 ---- cpp/test/distance/fused_l2_nn.cu | 4 ---- cpp/test/distance/gram.cu | 4 ---- cpp/test/distance/masked_nn.cu | 4 ---- cpp/test/matrix/select_k.cu | 4 ---- cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu | 4 ---- cpp/test/neighbors/ann_ivf_flat.cuh | 4 ---- cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu | 4 ---- cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu | 4 ---- cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu | 4 ---- cpp/test/neighbors/ann_ivf_pq.cuh | 5 ----- cpp/test/neighbors/ball_cover.cu | 4 ---- cpp/test/neighbors/epsilon_neighborhood.cu | 4 ---- cpp/test/neighbors/fused_l2_knn.cu | 4 ---- cpp/test/neighbors/knn.cu | 4 ---- cpp/test/neighbors/refine.cu | 4 ---- cpp/test/neighbors/selection.cu | 3 --- cpp/test/neighbors/tiled_knn.cu | 4 ---- cpp/test/sparse/neighbors/knn_graph.cu | 3 --- cpp/test/stats/silhouette_score.cu | 4 ---- cpp/test/stats/trustworthiness.cu | 4 ---- 69 files changed, 3 insertions(+), 193 deletions(-) diff --git a/cpp/bench/ann/src/raft/raft_benchmark.cu b/cpp/bench/ann/src/raft/raft_benchmark.cu index d8e98ce2a9..baff1b1c45 100644 --- a/cpp/bench/ann/src/raft/raft_benchmark.cu +++ b/cpp/bench/ann/src/raft/raft_benchmark.cu @@ -22,10 +22,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - #include "../common/ann_types.hpp" #include "../common/benchmark_util.hpp" #undef WARP_SIZE diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat.cu b/cpp/bench/ann/src/raft/raft_ivf_flat.cu index ff108080b5..bcd23723a4 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat.cu +++ b/cpp/bench/ann/src/raft/raft_ivf_flat.cu @@ -15,12 +15,8 @@ */ #include "raft_ivf_flat_wrapper.h" -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::bench::ann { template class RaftIvfFlatGpu; template class RaftIvfFlatGpu; template class RaftIvfFlatGpu; -} // namespace raft::bench::ann \ No newline at end of file +} // namespace raft::bench::ann diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq.cu b/cpp/bench/ann/src/raft/raft_ivf_pq.cu index 338bc9a32f..2efe14631b 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq.cu +++ b/cpp/bench/ann/src/raft/raft_ivf_pq.cu @@ -15,10 +15,6 @@ */ #include "raft_ivf_pq_wrapper.h" -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::bench::ann { template class RaftIvfPQ; template class RaftIvfPQ; diff --git a/cpp/bench/prims/cluster/kmeans.cu b/cpp/bench/prims/cluster/kmeans.cu index af7afb8037..3147960f72 100644 --- a/cpp/bench/prims/cluster/kmeans.cu +++ b/cpp/bench/prims/cluster/kmeans.cu @@ -18,10 +18,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft::bench::cluster { struct KMeansBenchParams { diff --git a/cpp/bench/prims/cluster/kmeans_balanced.cu b/cpp/bench/prims/cluster/kmeans_balanced.cu index 6bda43bdb2..42a8f7967c 100644 --- a/cpp/bench/prims/cluster/kmeans_balanced.cu +++ b/cpp/bench/prims/cluster/kmeans_balanced.cu @@ -18,10 +18,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft::bench::cluster { struct KMeansBalancedBenchParams { diff --git a/cpp/bench/prims/distance/distance_common.cuh b/cpp/bench/prims/distance/distance_common.cuh index 9b5d67a46f..dff3401b62 100644 --- a/cpp/bench/prims/distance/distance_common.cuh +++ b/cpp/bench/prims/distance/distance_common.cuh @@ -17,9 +17,6 @@ #include #include #include -#if defined RAFT_COMPILED -#include -#endif #include namespace raft::bench::distance { diff --git a/cpp/bench/prims/distance/fused_l2_nn.cu b/cpp/bench/prims/distance/fused_l2_nn.cu index a5115407dd..24c0cbf8f9 100644 --- a/cpp/bench/prims/distance/fused_l2_nn.cu +++ b/cpp/bench/prims/distance/fused_l2_nn.cu @@ -18,9 +18,6 @@ #include #include #include -#if defined RAFT_COMPILED -#include -#endif #include namespace raft::bench::distance { diff --git a/cpp/bench/prims/distance/kernels.cu b/cpp/bench/prims/distance/kernels.cu index 4407bdcf83..53d97c1fc7 100644 --- a/cpp/bench/prims/distance/kernels.cu +++ b/cpp/bench/prims/distance/kernels.cu @@ -13,10 +13,6 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/bench/prims/distance/masked_nn.cu b/cpp/bench/prims/distance/masked_nn.cu index f9f234187d..c804ecb3a1 100644 --- a/cpp/bench/prims/distance/masked_nn.cu +++ b/cpp/bench/prims/distance/masked_nn.cu @@ -30,10 +30,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::bench::distance::masked_nn { // Introduce various sparsity patterns diff --git a/cpp/bench/prims/matrix/select_k.cu b/cpp/bench/prims/matrix/select_k.cu index 870119db52..eb2b09cc4a 100644 --- a/cpp/bench/prims/matrix/select_k.cu +++ b/cpp/bench/prims/matrix/select_k.cu @@ -23,10 +23,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/bench/prims/neighbors/knn.cuh b/cpp/bench/prims/neighbors/knn.cuh index 8f0b1cb5d9..a987cdc4a2 100644 --- a/cpp/bench/prims/neighbors/knn.cuh +++ b/cpp/bench/prims/neighbors/knn.cuh @@ -24,10 +24,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/bench/prims/neighbors/refine_float_int64_t.cu b/cpp/bench/prims/neighbors/refine_float_int64_t.cu index 43be330e9b..bbedc1ae64 100644 --- a/cpp/bench/prims/neighbors/refine_float_int64_t.cu +++ b/cpp/bench/prims/neighbors/refine_float_int64_t.cu @@ -17,11 +17,6 @@ #include "refine.cuh" #include -#if defined RAFT_COMPILED -#include -#include -#endif - using namespace raft::neighbors; namespace raft::bench::neighbors { diff --git a/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu b/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu index 1d7cb8c8aa..4952361f03 100644 --- a/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu +++ b/cpp/bench/prims/neighbors/refine_uint8_t_int64_t.cu @@ -17,10 +17,6 @@ #include "refine.cuh" #include -#if defined RAFT_COMPILED -#include -#endif - using namespace raft::neighbors; namespace raft::bench::neighbors { diff --git a/cpp/doxygen/Doxyfile b/cpp/doxygen/Doxyfile index 17a1e0caca..1948169c91 100644 --- a/cpp/doxygen/Doxyfile +++ b/cpp/doxygen/Doxyfile @@ -918,6 +918,7 @@ EXCLUDE_SYMLINKS = NO # Note that the wildcards are matched against the file with absolute path, so to # exclude all test directories for example use the pattern */test/* +# TODO: remove specializations from exclude patterns when headers have been removed. EXCLUDE_PATTERNS = */detail/* \ */specializations/* \ */thirdparty/* diff --git a/cpp/internal/raft_internal/matrix/select_k.cuh b/cpp/internal/raft_internal/matrix/select_k.cuh index cd19ad9781..3d7a11e91e 100644 --- a/cpp/internal/raft_internal/matrix/select_k.cuh +++ b/cpp/internal/raft_internal/matrix/select_k.cuh @@ -21,12 +21,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - -#include - namespace raft::matrix::select { struct params { diff --git a/cpp/internal/raft_internal/neighbors/naive_knn.cuh b/cpp/internal/raft_internal/neighbors/naive_knn.cuh index 47d6f068e3..3ad055272b 100644 --- a/cpp/internal/raft_internal/neighbors/naive_knn.cuh +++ b/cpp/internal/raft_internal/neighbors/naive_knn.cuh @@ -21,10 +21,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/src/raft_runtime/cluster/cluster_cost_double.cu b/cpp/src/raft_runtime/cluster/cluster_cost_double.cu index 2244ba4ed3..b6df92c839 100644 --- a/cpp/src/raft_runtime/cluster/cluster_cost_double.cu +++ b/cpp/src/raft_runtime/cluster/cluster_cost_double.cu @@ -15,7 +15,6 @@ */ #include "cluster_cost.cuh" -#include #include #include diff --git a/cpp/src/raft_runtime/cluster/cluster_cost_float.cu b/cpp/src/raft_runtime/cluster/cluster_cost_float.cu index 4164265b55..2c26b69984 100644 --- a/cpp/src/raft_runtime/cluster/cluster_cost_float.cu +++ b/cpp/src/raft_runtime/cluster/cluster_cost_float.cu @@ -15,7 +15,6 @@ */ #include "cluster_cost.cuh" -#include #include #include diff --git a/cpp/src/raft_runtime/cluster/kmeans_fit_double.cu b/cpp/src/raft_runtime/cluster/kmeans_fit_double.cu index 12f4fba318..0b8b458042 100644 --- a/cpp/src/raft_runtime/cluster/kmeans_fit_double.cu +++ b/cpp/src/raft_runtime/cluster/kmeans_fit_double.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::cluster::kmeans { diff --git a/cpp/src/raft_runtime/cluster/kmeans_fit_float.cu b/cpp/src/raft_runtime/cluster/kmeans_fit_float.cu index 48505dcc3e..a2831c2cf0 100644 --- a/cpp/src/raft_runtime/cluster/kmeans_fit_float.cu +++ b/cpp/src/raft_runtime/cluster/kmeans_fit_float.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::cluster::kmeans { diff --git a/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu index 5bb0835595..d2ec26f882 100644 --- a/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu +++ b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_double.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::cluster::kmeans { diff --git a/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu index f211afd06e..bacab3b7d6 100644 --- a/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu +++ b/cpp/src/raft_runtime/cluster/kmeans_init_plus_plus_float.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::cluster::kmeans { diff --git a/cpp/src/raft_runtime/cluster/update_centroids.cuh b/cpp/src/raft_runtime/cluster/update_centroids.cuh index 7c13252384..de219329df 100644 --- a/cpp/src/raft_runtime/cluster/update_centroids.cuh +++ b/cpp/src/raft_runtime/cluster/update_centroids.cuh @@ -15,7 +15,6 @@ */ #include -#include #include #include #include diff --git a/cpp/src/raft_runtime/cluster/update_centroids_double.cu b/cpp/src/raft_runtime/cluster/update_centroids_double.cu index 0f38c7dd53..d967c503ff 100644 --- a/cpp/src/raft_runtime/cluster/update_centroids_double.cu +++ b/cpp/src/raft_runtime/cluster/update_centroids_double.cu @@ -15,7 +15,6 @@ */ #include "update_centroids.cuh" -#include #include #include diff --git a/cpp/src/raft_runtime/cluster/update_centroids_float.cu b/cpp/src/raft_runtime/cluster/update_centroids_float.cu index 8f0e79b438..b141a1ef20 100644 --- a/cpp/src/raft_runtime/cluster/update_centroids_float.cu +++ b/cpp/src/raft_runtime/cluster/update_centroids_float.cu @@ -15,7 +15,6 @@ */ #include "update_centroids.cuh" -#include #include #include diff --git a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu index 487c7e3a4a..bec71ae698 100644 --- a/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu +++ b/cpp/src/raft_runtime/distance/fused_l2_min_arg.cu @@ -19,7 +19,6 @@ #include #include #include -#include #include #include #include diff --git a/cpp/src/raft_runtime/distance/pairwise_distance.cu b/cpp/src/raft_runtime/distance/pairwise_distance.cu index dfdfa553e9..62597a4799 100644 --- a/cpp/src/raft_runtime/distance/pairwise_distance.cu +++ b/cpp/src/raft_runtime/distance/pairwise_distance.cu @@ -17,7 +17,6 @@ #include #include #include -#include namespace raft::runtime::distance { diff --git a/cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu b/cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu index 309ac50c6b..8814a8aafc 100644 --- a/cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu +++ b/cpp/src/raft_runtime/matrix/select_k_float_int64_t.cu @@ -17,7 +17,6 @@ #include #include #include -#include #include diff --git a/cpp/src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu b/cpp/src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu index 88545b3607..ea6002eab0 100644 --- a/cpp/src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu +++ b/cpp/src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu @@ -18,8 +18,6 @@ #include #include -#include - #include #include diff --git a/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu index 79743662bd..48a40ab56e 100644 --- a/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu +++ b/cpp/src/raft_runtime/neighbors/ivf_flat_build.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::neighbors::ivf_flat { diff --git a/cpp/src/raft_runtime/neighbors/ivf_flat_search.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_search.cu index 27f9ab0c5b..eefc7f2932 100644 --- a/cpp/src/raft_runtime/neighbors/ivf_flat_search.cu +++ b/cpp/src/raft_runtime/neighbors/ivf_flat_search.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::neighbors::ivf_flat { diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_build.cu b/cpp/src/raft_runtime/neighbors/ivfpq_build.cu index 7f91e34969..5bfb546060 100644 --- a/cpp/src/raft_runtime/neighbors/ivfpq_build.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_build.cu @@ -15,7 +15,6 @@ */ #include -#include #include namespace raft::runtime::neighbors::ivf_pq { diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu b/cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu index 76a1839084..45b731fdcf 100644 --- a/cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_deserialize.cu @@ -16,7 +16,7 @@ #include #include -#include + #include namespace raft::runtime::neighbors::ivf_pq { diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu b/cpp/src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu index 91093d3a39..d55d726671 100644 --- a/cpp/src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu @@ -15,7 +15,6 @@ */ #include -#include #include diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu b/cpp/src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu index e1552c0b27..b73cbc0751 100644 --- a/cpp/src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_search_int8_t_int64_t.cu @@ -15,7 +15,6 @@ */ #include -#include #include diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu b/cpp/src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu index 85195a7551..2b3dfe585d 100644 --- a/cpp/src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_search_uint8_t_int64_t.cu @@ -15,7 +15,6 @@ */ #include -#include #include diff --git a/cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu b/cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu index 7a10fc6caa..21bd221c45 100644 --- a/cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/ivfpq_serialize.cu @@ -16,7 +16,6 @@ #include #include -#include #include diff --git a/cpp/src/raft_runtime/neighbors/refine_d_int64_t_float.cu b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_float.cu index 8ad8f9e8f1..79cec55294 100644 --- a/cpp/src/raft_runtime/neighbors/refine_d_int64_t_float.cu +++ b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_float.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu index 817369ed6a..f8a7a8c9c8 100644 --- a/cpp/src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu +++ b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_int8_t.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu index fb426b2c02..8f68f9f88e 100644 --- a/cpp/src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu +++ b/cpp/src/raft_runtime/neighbors/refine_d_int64_t_uint8_t.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/raft_runtime/neighbors/refine_h_int64_t_float.cu b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_float.cu index 1f950dc3b6..7f19d44700 100644 --- a/cpp/src/raft_runtime/neighbors/refine_h_int64_t_float.cu +++ b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_float.cu @@ -16,7 +16,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu index da99df3618..bd21c6b198 100644 --- a/cpp/src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu +++ b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_int8_t.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu index 990754b033..f10d01cc09 100644 --- a/cpp/src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu +++ b/cpp/src/raft_runtime/neighbors/refine_h_int64_t_uint8_t.cu @@ -15,7 +15,6 @@ */ #include -#include namespace raft::runtime::neighbors { diff --git a/cpp/template/src/test_distance.cu b/cpp/template/src/test_distance.cu index b86dde70e5..e165cd8f14 100644 --- a/cpp/template/src/test_distance.cu +++ b/cpp/template/src/test_distance.cu @@ -20,10 +20,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - int main() { raft::device_resources handle; diff --git a/cpp/test/cluster/cluster_solvers.cu b/cpp/test/cluster/cluster_solvers.cu index f26c598a2b..60e5f62dc0 100644 --- a/cpp/test/cluster/cluster_solvers.cu +++ b/cpp/test/cluster/cluster_solvers.cu @@ -19,10 +19,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/test/cluster/kmeans.cu b/cpp/test/cluster/kmeans.cu index cfec84256b..20110eed11 100644 --- a/cpp/test/cluster/kmeans.cu +++ b/cpp/test/cluster/kmeans.cu @@ -29,10 +29,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft { template diff --git a/cpp/test/cluster/kmeans_balanced.cu b/cpp/test/cluster/kmeans_balanced.cu index 220eba4186..a34f2f3b59 100644 --- a/cpp/test/cluster/kmeans_balanced.cu +++ b/cpp/test/cluster/kmeans_balanced.cu @@ -30,10 +30,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - /* This test takes advantage of the fact that make_blobs generates balanced clusters. * It doesn't currently test whether the algorithm can make balanced clusters with an imbalanced * dataset. diff --git a/cpp/test/cluster/kmeans_find_k.cu b/cpp/test/cluster/kmeans_find_k.cu index a865651f56..bb41d4fafc 100644 --- a/cpp/test/cluster/kmeans_find_k.cu +++ b/cpp/test/cluster/kmeans_find_k.cu @@ -25,10 +25,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft { template diff --git a/cpp/test/cluster/linkage.cu b/cpp/test/cluster/linkage.cu index 4946d52f26..fc69e0fe6f 100644 --- a/cpp/test/cluster/linkage.cu +++ b/cpp/test/cluster/linkage.cu @@ -20,10 +20,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/test/distance/fused_l2_nn.cu b/cpp/test/distance/fused_l2_nn.cu index 383ad39319..c4ccd55f69 100644 --- a/cpp/test/distance/fused_l2_nn.cu +++ b/cpp/test/distance/fused_l2_nn.cu @@ -24,10 +24,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - namespace raft { namespace distance { diff --git a/cpp/test/distance/gram.cu b/cpp/test/distance/gram.cu index f99d02dc7f..32a7493930 100644 --- a/cpp/test/distance/gram.cu +++ b/cpp/test/distance/gram.cu @@ -14,10 +14,6 @@ * limitations under the License. */ -#if defined RAFT_COMPILED -#include -#endif - #include "../test_utils.cuh" #include #include diff --git a/cpp/test/distance/masked_nn.cu b/cpp/test/distance/masked_nn.cu index d01911206b..66d5a77dbf 100644 --- a/cpp/test/distance/masked_nn.cu +++ b/cpp/test/distance/masked_nn.cu @@ -28,10 +28,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - namespace raft::distance::masked_nn { // The adjacency pattern determines what distances get computed. diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index 2a40d70abc..857305b882 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -18,10 +18,6 @@ #include -#ifdef RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu index 71a83e2cca..1497a515d2 100644 --- a/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_cagra/test_float_uint32_t.cu @@ -18,10 +18,6 @@ #include "../ann_cagra.cuh" -// #if defined RAFT_DISTANCE_COMPILED -// #include -// #endif - namespace raft::neighbors::experimental::cagra { typedef AnnCagraTest AnnCagraTestF; diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index fe6f9163a0..4d90c3d7e4 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -36,10 +36,6 @@ #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu index e430af89df..f0988ca988 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_float_int64_t.cu @@ -18,10 +18,6 @@ #include "../ann_ivf_flat.cuh" -#if defined RAFT_COMPILED -#include -#endif - namespace raft::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF; diff --git a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu index e4e7a207fb..2f542bd6ec 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_int8_t_int64_t.cu @@ -18,10 +18,6 @@ #include "../ann_ivf_flat.cuh" -#if defined RAFT_COMPILED -#include -#endif - namespace raft::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF_int8; diff --git a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu index ef7980401a..7659707089 100644 --- a/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu @@ -18,10 +18,6 @@ #include "../ann_ivf_flat.cuh" -#if defined RAFT_COMPILED -#include -#endif - namespace raft::neighbors::ivf_flat { typedef AnnIVFFlatTest AnnIVFFlatTestF_uint8; diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 5181805999..90c66ace06 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -29,11 +29,6 @@ #include #include #include -#ifdef RAFT_COMPILED -#include -#else -#pragma message("NN instantiations are not enabled; expect very long building times.") -#endif #include #include diff --git a/cpp/test/neighbors/ball_cover.cu b/cpp/test/neighbors/ball_cover.cu index 46ef3a9150..19935154df 100644 --- a/cpp/test/neighbors/ball_cover.cu +++ b/cpp/test/neighbors/ball_cover.cu @@ -23,10 +23,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/test/neighbors/epsilon_neighborhood.cu b/cpp/test/neighbors/epsilon_neighborhood.cu index 769cb7ec2d..c78a15dd2d 100644 --- a/cpp/test/neighbors/epsilon_neighborhood.cu +++ b/cpp/test/neighbors/epsilon_neighborhood.cu @@ -23,10 +23,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - namespace raft { namespace spatial { namespace knn { diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu index ab05b41cc9..04a6ff0578 100644 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ b/cpp/test/neighbors/fused_l2_knn.cu @@ -23,10 +23,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/test/neighbors/knn.cu b/cpp/test/neighbors/knn.cu index bcd4b9cb0b..e0f2c2e58e 100644 --- a/cpp/test/neighbors/knn.cu +++ b/cpp/test/neighbors/knn.cu @@ -21,10 +21,6 @@ #include #include -#ifdef RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index dd3491673e..d868ba06cf 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -31,10 +31,6 @@ #include -#if defined RAFT_COMPILED -#include -#endif - #include namespace raft::neighbors { diff --git a/cpp/test/neighbors/selection.cu b/cpp/test/neighbors/selection.cu index 1b4b7cb832..a21ff9f99e 100644 --- a/cpp/test/neighbors/selection.cu +++ b/cpp/test/neighbors/selection.cu @@ -26,9 +26,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif namespace raft::spatial::selection { diff --git a/cpp/test/neighbors/tiled_knn.cu b/cpp/test/neighbors/tiled_knn.cu index 03ddba1c64..aa46fc29f1 100644 --- a/cpp/test/neighbors/tiled_knn.cu +++ b/cpp/test/neighbors/tiled_knn.cu @@ -28,10 +28,6 @@ #include // raft::neighbors::detail::brute_force_knn_impl #include // raft::neighbors::detail::select_k -#if defined RAFT_COMPILED -#include -#endif - #include #include diff --git a/cpp/test/sparse/neighbors/knn_graph.cu b/cpp/test/sparse/neighbors/knn_graph.cu index 8873445c37..aadb00879b 100644 --- a/cpp/test/sparse/neighbors/knn_graph.cu +++ b/cpp/test/sparse/neighbors/knn_graph.cu @@ -22,9 +22,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif #include diff --git a/cpp/test/stats/silhouette_score.cu b/cpp/test/stats/silhouette_score.cu index 40b7e59d81..9ad89d59c0 100644 --- a/cpp/test/stats/silhouette_score.cu +++ b/cpp/test/stats/silhouette_score.cu @@ -20,10 +20,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include #include diff --git a/cpp/test/stats/trustworthiness.cu b/cpp/test/stats/trustworthiness.cu index 2fde6b29c1..15b27c7669 100644 --- a/cpp/test/stats/trustworthiness.cu +++ b/cpp/test/stats/trustworthiness.cu @@ -20,10 +20,6 @@ #include #include -#if defined RAFT_COMPILED -#include -#endif - #include #include From 94d8117dab13b4b04ad50ae8333a6d6c44667430 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 17:12:16 +0200 Subject: [PATCH 18/32] test/distance/dist_adj.cu: Add instance --- cpp/test/CMakeLists.txt | 1 + cpp/test/distance/dist_adj.cu | 16 +---- cpp/test/distance/dist_adj.cuh | 71 +++++++++++++++++++ .../distance/dist_adj_distance_instance.cu | 63 ++++++++++++++++ cpp/test/distance/dist_adj_threshold.cuh | 36 ++++++++++ cpp/test/distance/distance_base.cuh | 33 ++------- 6 files changed, 178 insertions(+), 42 deletions(-) create mode 100644 cpp/test/distance/dist_adj.cuh create mode 100644 cpp/test/distance/dist_adj_distance_instance.cu create mode 100644 cpp/test/distance/dist_adj_threshold.cuh diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 22e8a9d73c..7dafb430fe 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -119,6 +119,7 @@ if(BUILD_TESTS) DISTANCE_TEST PATH test/distance/dist_adj.cu + test/distance/dist_adj_distance_instance.cu test/distance/dist_canberra.cu test/distance/dist_correlation.cu test/distance/dist_cos.cu diff --git a/cpp/test/distance/dist_adj.cu b/cpp/test/distance/dist_adj.cu index ce802e5138..bb63cc9be3 100644 --- a/cpp/test/distance/dist_adj.cu +++ b/cpp/test/distance/dist_adj.cu @@ -22,6 +22,8 @@ #include #include +#include "dist_adj.cuh" + namespace raft { namespace distance { @@ -74,18 +76,6 @@ struct DistanceAdjInputs { unsigned long long int seed; }; -template -struct threshold_final_op { - DataT threshold_val; - - __device__ __host__ threshold_final_op() noexcept : threshold_val(0.0) {} - __device__ __host__ threshold_final_op(DataT val) noexcept : threshold_val(val) {} - __device__ __host__ OutT operator()(AccT d_val, Index g_idx) const noexcept - { - return d_val <= threshold_val; - } -}; - template ::std::ostream& operator<<(::std::ostream& os, const DistanceAdjInputs& dims) { @@ -140,7 +130,7 @@ class DistanceAdjTest : public ::testing::TestWithParam + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + extern template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, + float, + float, + uint8_t, + raft::distance::threshold_float, + int); + +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, + double, + double, + uint8_t, + raft::distance::threshold_double, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + extern template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, uint8_t, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, uint8_t, int); + +#undef instantiate_raft_distance_getWorkspaceSize + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + extern template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, uint8_t, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, uint8_t, int); + +#undef instantiate_raft_distance_getWorkspaceSize diff --git a/cpp/test/distance/dist_adj_distance_instance.cu b/cpp/test/distance/dist_adj_distance_instance.cu new file mode 100644 index 0000000000..d4685d8095 --- /dev/null +++ b/cpp/test/distance/dist_adj_distance_instance.cu @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + +#include "dist_adj_threshold.cuh" +#include +#include + +#define instantiate_raft_distance_distance(DT, DataT, AccT, OutT, FinalLambda, IdxT) \ + template void raft::distance::distance( \ + raft::resources const& handle, \ + const DataT* x, \ + const DataT* y, \ + OutT* dist, \ + IdxT m, \ + IdxT n, \ + IdxT k, \ + void* workspace, \ + size_t worksize, \ + FinalLambda fin_op, \ + bool isRowMajor, \ + DataT metric_arg) + +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, + float, + float, + uint8_t, + raft::distance::threshold_float, + int); + +instantiate_raft_distance_distance(raft::distance::DistanceType::L2Expanded, + double, + double, + uint8_t, + raft::distance::threshold_double, + int); + +#undef instantiate_raft_distance_distance + +#define instantiate_raft_distance_getWorkspaceSize(DistT, DataT, AccT, OutT, IdxT) \ + template size_t raft::distance::getWorkspaceSize( \ + const DataT* x, const DataT* y, IdxT m, IdxT n, IdxT k) + +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, float, float, uint8_t, int); +instantiate_raft_distance_getWorkspaceSize( + raft::distance::DistanceType::L2Expanded, double, double, uint8_t, int); + +#undef instantiate_raft_distance_getWorkspaceSize diff --git a/cpp/test/distance/dist_adj_threshold.cuh b/cpp/test/distance/dist_adj_threshold.cuh new file mode 100644 index 0000000000..78663b3cd1 --- /dev/null +++ b/cpp/test/distance/dist_adj_threshold.cuh @@ -0,0 +1,36 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include // uint8_t + +namespace raft::distance { + +template +struct threshold_final_op { + DataT threshold_val; + + __device__ __host__ threshold_final_op() noexcept : threshold_val(0.0) {} + __device__ __host__ threshold_final_op(DataT val) noexcept : threshold_val(val) {} + __device__ __host__ OutT operator()(AccT d_val, Index g_idx) const noexcept + { + return d_val <= threshold_val; + } +}; + +using threshold_float = threshold_final_op; +using threshold_double = threshold_final_op; + +} // namespace raft::distance diff --git a/cpp/test/distance/distance_base.cuh b/cpp/test/distance/distance_base.cuh index 438e212fbd..45c2685001 100644 --- a/cpp/test/distance/distance_base.cuh +++ b/cpp/test/distance/distance_base.cuh @@ -18,23 +18,14 @@ #include #include // common::nvtx::range -#include // make_device_matrix_view -#include // raft::device_resources -#include // raft::sqrt +#include // make_device_matrix_view +#include // raft::device_resources +#include // raft::sqrt +#include #include // raft::distance::DistanceType #include #include // rmm::device_uvector -// When the distance library is precompiled, include only the raft_runtime -// headers. This way, a small change in one of the kernel internals does not -// trigger a rebuild of the test files (it of course still triggers a rebuild of -// the raft specializations) -#if defined RAFT_COMPILED -#include -#else -#include -#endif - namespace raft { namespace distance { @@ -449,23 +440,12 @@ void distanceLauncher(raft::device_resources const& handle, DataType threshold, DataType metric_arg = 2.0f) { -#if defined RAFT_COMPILED - // TODO: Implement and use mdspan-based - // raft::runtime::distance::pairwise_distance here. - // - // Context: - // https://github.com/rapidsai/raft/issues/1338 - bool row_major = layout_to_row_major(); - raft::runtime::distance::pairwise_distance( - handle, x, y, dist, m, n, k, distanceType, row_major, metric_arg); -#else auto x_v = make_device_matrix_view(x, m, k); auto y_v = make_device_matrix_view(y, n, k); auto dist_v = make_device_matrix_view(dist, m, n); raft::distance::distance( handle, x_v, y_v, dist_v, metric_arg); -#endif } template @@ -573,13 +553,8 @@ class BigMatrixDistanceTest : public ::testing::Test { float metric_arg); constexpr bool row_major = true; constexpr float metric_arg = 0.0f; -#if defined RAFT_COMPILED - raft::runtime::distance::pairwise_distance( - handle, x.data(), x.data(), dist.data(), m, n, k, distanceType, row_major, metric_arg); -#else raft::distance::distance( handle, x.data(), x.data(), dist.data(), m, n, k, row_major, metric_arg); -#endif RAFT_CUDA_TRY(cudaStreamSynchronize(handle.get_stream())); } From 361570b7e3a41f8fa7e00a09876ea121d21b31e1 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 17:12:54 +0200 Subject: [PATCH 19/32] test/cluster/linkage.cu: Allow instance --- cpp/test/cluster/linkage.cu | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cpp/test/cluster/linkage.cu b/cpp/test/cluster/linkage.cu index fc69e0fe6f..b2b177dde6 100644 --- a/cpp/test/cluster/linkage.cu +++ b/cpp/test/cluster/linkage.cu @@ -14,6 +14,15 @@ * limitations under the License. */ +// XXX: We allow the instantiation of fused_l2_nn here: +// raft::linkage::FixConnectivitiesRedOp red_op(colors.data(), params.n_row); +// raft::linkage::connect_components( +// handle, out_edges, data.data(), colors.data(), params.n_row, params.n_col, red_op); +// +// TODO: consider adding this to libraft.so or creating an instance in a +// separate translation unit for this test. +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + #include "../test_utils.cuh" #include From 5171de333cf4b406c3d572f6d24d00c3f50fa4cd Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 17:33:27 +0200 Subject: [PATCH 20/32] test/sparse/neighbors/connect_components.cu: Allow instance --- cpp/test/sparse/neighbors/connect_components.cu | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/cpp/test/sparse/neighbors/connect_components.cu b/cpp/test/sparse/neighbors/connect_components.cu index d200744329..e14cd9a180 100644 --- a/cpp/test/sparse/neighbors/connect_components.cu +++ b/cpp/test/sparse/neighbors/connect_components.cu @@ -14,6 +14,15 @@ * limitations under the License. */ +// XXX: We allow the instantiation of fused_l2_nn here: +// raft::linkage::FixConnectivitiesRedOp red_op(colors.data(), params.n_row); +// raft::linkage::connect_components( +// handle, out_edges, data.data(), colors.data(), params.n_row, params.n_col, red_op); +// +// TODO: consider adding this to libraft.so or creating an instance in a +// separate translation unit for this test. +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + #include #include From 4426c5001dda807224348608b20fc838f7e0a064 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 17:34:07 +0200 Subject: [PATCH 21/32] test/neighbors/ann_ivf_pq/test_float_uint32_t.cu: Allow instance --- cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu index c14afe4d70..3d362a5261 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_float_uint32_t.cu @@ -14,6 +14,13 @@ * limitations under the License. */ +// XXX: the uint32_t instance is not compiled in libraft.so. So we allow +// instantiating the template here. +// +// TODO: consider removing this test or consider adding an instantiation to the +// library. +#undef RAFT_EXPLICIT_INSTANTIATE_ONLY + #include "../ann_ivf_pq.cuh" namespace raft::neighbors::ivf_pq { From e527efc56e407d7208ae6bfb4ae2fb7bbafed2f7 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 17:34:59 +0200 Subject: [PATCH 22/32] test/matrix/select_k.cu: Change index type --- cpp/test/matrix/select_k.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/cpp/test/matrix/select_k.cu b/cpp/test/matrix/select_k.cu index 857305b882..cbee243c92 100644 --- a/cpp/test/matrix/select_k.cu +++ b/cpp/test/matrix/select_k.cu @@ -228,9 +228,9 @@ struct SelectK // NOLINT auto& in_dists = ref.get_in_dists(); auto compare_ids = [&in_ids, &in_dists](const IdxT& i, const IdxT& j) { if (i == j) return true; - auto ix_i = uint64_t(std::find(in_ids.begin(), in_ids.end(), i) - in_ids.begin()); - auto ix_j = uint64_t(std::find(in_ids.begin(), in_ids.end(), j) - in_ids.begin()); - if (ix_i >= in_ids.size() || ix_j >= in_ids.size()) return false; + auto ix_i = int64_t(std::find(in_ids.begin(), in_ids.end(), i) - in_ids.begin()); + auto ix_j = int64_t(std::find(in_ids.begin(), in_ids.end(), j) - in_ids.begin()); + if (size_t(ix_i) >= in_ids.size() || size_t(ix_j) >= in_ids.size()) return false; auto dist_i = in_dists[ix_i]; auto dist_j = in_dists[ix_j]; if (dist_i == dist_j) return true; @@ -430,7 +430,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kWarpDistributedShm))); using ReferencedRandomDoubleSizeT = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomDoubleSizeT, Run) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P( // NOLINT SelectK, @@ -457,7 +457,7 @@ INSTANTIATE_TEST_CASE_P( // NOLINT select::Algo::kRadix11bitsExtraPass))); using ReferencedRandomFloatSizeT = - SelectK::params_random>; + SelectK::params_random>; TEST_P(ReferencedRandomFloatSizeT, LargeK) { run(); } // NOLINT INSTANTIATE_TEST_CASE_P(SelectK, // NOLINT ReferencedRandomFloatSizeT, From 17902e95d700426b355538ec408fc1da77f1451e Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 17:58:44 +0200 Subject: [PATCH 23/32] test/neighbors/fused_l2_knn.cu: Change index type --- cpp/test/neighbors/fused_l2_knn.cu | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/cpp/test/neighbors/fused_l2_knn.cu b/cpp/test/neighbors/fused_l2_knn.cu index 04a6ff0578..9fbccf681d 100644 --- a/cpp/test/neighbors/fused_l2_knn.cu +++ b/cpp/test/neighbors/fused_l2_knn.cu @@ -77,9 +77,9 @@ class FusedL2KNNTest : public ::testing::TestWithParam { rmm::device_uvector temp_distances(num_db_vecs * num_queries, stream_); distance::pairwise_distance( handle_, - raft::make_device_matrix_view(search_queries.data(), num_queries, dim), - raft::make_device_matrix_view(database.data(), num_db_vecs, dim), - raft::make_device_matrix_view(temp_distances.data(), num_queries, num_db_vecs), + raft::make_device_matrix_view(search_queries.data(), num_queries, dim), + raft::make_device_matrix_view(database.data(), num_db_vecs, dim), + raft::make_device_matrix_view(temp_distances.data(), num_queries, num_db_vecs), metric); spatial::knn::select_k(temp_distances.data(), From 976189b48447bb3c76ef12159d48e61c203c5770 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 18:50:58 +0200 Subject: [PATCH 24/32] Force explicit instantiations in tests --- cpp/test/CMakeLists.txt | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 7dafb430fe..78cf4af99a 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -17,7 +17,7 @@ function(ConfigureTest) - set(options OPTIONAL LIB) + set(options OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY) set(oneValueArgs NAME) set(multiValueArgs PATH TARGETS CONFIGURATIONS) @@ -59,6 +59,10 @@ function(ConfigureTest) "$<$:${RAFT_CUDA_FLAGS}>" ) + if(ConfigureTest_EXPLICIT_INSTANTIATE_ONLY) + target_compile_definitions(${TEST_NAME} PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") + endif() + target_include_directories(${TEST_NAME} PUBLIC "$") install( @@ -88,6 +92,7 @@ if(BUILD_TESTS) test/cluster/kmeans_find_k.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -112,6 +117,9 @@ if(BUILD_TESTS) test/core/span.cu test/core/temporary_device_buffer.cu test/test.cpp + OPTIONAL + LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -141,6 +149,7 @@ if(BUILD_TESTS) test/distance/gram.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest(NAME LABEL_TEST PATH test/label/label.cu test/label/merge_labels.cu) @@ -202,6 +211,7 @@ if(BUILD_TESTS) test/sparse/spectral_matrix.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -221,7 +231,7 @@ if(BUILD_TESTS) ConfigureTest( NAME SOLVERS_TEST PATH test/cluster/cluster_solvers_deprecated.cu test/linalg/eigen_solvers.cu - test/lap/lap.cu test/sparse/mst.cu OPTIONAL LIB + test/lap/lap.cu test/sparse/mst.cu OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -246,11 +256,19 @@ if(BUILD_TESTS) ConfigureTest( NAME SPARSE_DIST_TEST PATH test/sparse/dist_coo_spmv.cu test/sparse/distance.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( - NAME SPARSE_NEIGHBORS_TEST PATH test/sparse/neighbors/connect_components.cu - test/sparse/neighbors/brute_force.cu test/sparse/neighbors/knn_graph.cu OPTIONAL LIB + NAME + SPARSE_NEIGHBORS_TEST + PATH + test/sparse/neighbors/connect_components.cu + test/sparse/neighbors/brute_force.cu + test/sparse/neighbors/knn_graph.cu + OPTIONAL + LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -276,6 +294,7 @@ if(BUILD_TESTS) test/neighbors/selection.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( @@ -309,6 +328,7 @@ if(BUILD_TESTS) test/stats/v_measure.cu OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureTest( From bdae61d76172bf53b2d2f94431076044fcd493d8 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 18:46:34 +0200 Subject: [PATCH 25/32] Force explicit instantiations in benchmarks --- cpp/bench/prims/CMakeLists.txt | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/cpp/bench/prims/CMakeLists.txt b/cpp/bench/prims/CMakeLists.txt index f6499623dd..505ca32886 100644 --- a/cpp/bench/prims/CMakeLists.txt +++ b/cpp/bench/prims/CMakeLists.txt @@ -17,7 +17,7 @@ function(ConfigureBench) - set(options OPTIONAL LIB) + set(options OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY) set(oneValueArgs NAME) set(multiValueArgs PATH TARGETS CONFIGURATIONS) @@ -55,6 +55,10 @@ function(ConfigureBench) "$<$:${RAFT_CUDA_FLAGS}>" ) + if(ConfigureTest_EXPLICIT_INSTANTIATE_ONLY) + target_compile_definitions(${BENCH_NAME} PRIVATE "RAFT_EXPLICIT_INSTANTIATE_ONLY") + endif() + target_include_directories( ${BENCH_NAME} PUBLIC "$" ) @@ -71,7 +75,7 @@ endfunction() if(BUILD_PRIMS_BENCH) ConfigureBench( NAME CLUSTER_BENCH PATH bench/prims/cluster/kmeans_balanced.cu bench/prims/cluster/kmeans.cu - bench/prims/main.cpp OPTIONAL LIB + bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( @@ -93,6 +97,7 @@ if(BUILD_PRIMS_BENCH) bench/prims/main.cpp OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( @@ -112,7 +117,7 @@ if(BUILD_PRIMS_BENCH) ConfigureBench( NAME MATRIX_BENCH PATH bench/prims/matrix/argmin.cu bench/prims/matrix/gather.cu - bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB + bench/prims/matrix/select_k.cu bench/prims/main.cpp OPTIONAL LIB EXPLICIT_INSTANTIATE_ONLY ) ConfigureBench( @@ -139,5 +144,6 @@ if(BUILD_PRIMS_BENCH) bench/prims/main.cpp OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY ) endif() From b0b8fe5293964c185916cd783930464082e3c8d0 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 18:51:53 +0200 Subject: [PATCH 26/32] Test that headers are free standing Under multiple combinations of RAFT_EXPLICIT_INSTANTIATE_ONLY and RAFT_COMPILED --- cpp/test/CMakeLists.txt | 33 ++++++++ cpp/test/ext_headers/00_generate.py | 75 +++++++++++++++++++ ...istance_detail_pairwise_matrix_dispatch.cu | 27 +++++++ .../ext_headers/raft_distance_distance.cu | 27 +++++++ .../ext_headers/raft_distance_fused_l2_nn.cu | 27 +++++++ .../raft_linalg_detail_coalesced_reduction.cu | 27 +++++++ .../raft_matrix_detail_select_k.cu | 27 +++++++ .../ext_headers/raft_neighbors_ball_cover.cu | 27 +++++++ .../ext_headers/raft_neighbors_brute_force.cu | 27 +++++++ .../raft_neighbors_detail_ivf_flat_search.cu | 27 +++++++ .../raft_neighbors_detail_selection_faiss.cu | 27 +++++++ .../ext_headers/raft_neighbors_ivf_flat.cu | 27 +++++++ cpp/test/ext_headers/raft_neighbors_ivf_pq.cu | 27 +++++++ cpp/test/ext_headers/raft_neighbors_refine.cu | 27 +++++++ ...spatial_knn_detail_ball_cover_registers.cu | 27 +++++++ .../raft_spatial_knn_detail_fused_l2_knn.cu | 27 +++++++ 16 files changed, 486 insertions(+) create mode 100644 cpp/test/ext_headers/00_generate.py create mode 100644 cpp/test/ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu create mode 100644 cpp/test/ext_headers/raft_distance_distance.cu create mode 100644 cpp/test/ext_headers/raft_distance_fused_l2_nn.cu create mode 100644 cpp/test/ext_headers/raft_linalg_detail_coalesced_reduction.cu create mode 100644 cpp/test/ext_headers/raft_matrix_detail_select_k.cu create mode 100644 cpp/test/ext_headers/raft_neighbors_ball_cover.cu create mode 100644 cpp/test/ext_headers/raft_neighbors_brute_force.cu create mode 100644 cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu create mode 100644 cpp/test/ext_headers/raft_neighbors_detail_selection_faiss.cu create mode 100644 cpp/test/ext_headers/raft_neighbors_ivf_flat.cu create mode 100644 cpp/test/ext_headers/raft_neighbors_ivf_pq.cu create mode 100644 cpp/test/ext_headers/raft_neighbors_refine.cu create mode 100644 cpp/test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu create mode 100644 cpp/test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 78cf4af99a..0772640d3f 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -152,6 +152,39 @@ if(BUILD_TESTS) EXPLICIT_INSTANTIATE_ONLY ) + list( + APPEND + EXT_HEADER_TEST_SOURCES + test/ext_headers/raft_neighbors_brute_force.cu + test/ext_headers/raft_distance_distance.cu + test/ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu + test/ext_headers/raft_matrix_detail_select_k.cu + test/ext_headers/raft_neighbors_ball_cover.cu + test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu + test/ext_headers/raft_distance_fused_l2_nn.cu + test/ext_headers/raft_neighbors_ivf_pq.cu + test/ext_headers/raft_neighbors_ivf_flat.cu + test/ext_headers/raft_neighbors_refine.cu + test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu + test/ext_headers/raft_neighbors_detail_selection_faiss.cu + test/ext_headers/raft_linalg_detail_coalesced_reduction.cu + test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu + ) + + # Test that the split headers compile in isolation with: + # + # * EXT_HEADERS_TEST_COMPILED_EXPLICIT: RAFT_COMPILED, RAFT_EXPLICIT_INSTANTIATE_ONLY defined + # * EXT_HEADERS_TEST_COMPILED_IMPLICIT: RAFT_COMPILED defined + # * EXT_HEADERS_TEST_IMPLICIT: no macros defined. + ConfigureTest( + NAME EXT_HEADERS_TEST_COMPILED_EXPLICIT PATH ${EXT_HEADER_TEST_SOURCES} OPTIONAL LIB + EXPLICIT_INSTANTIATE_ONLY + ) + ConfigureTest( + NAME EXT_HEADERS_TEST_COMPILED_IMPLICIT PATH ${EXT_HEADER_TEST_SOURCES} OPTIONAL LIB + ) + ConfigureTest(NAME EXT_HEADERS_TEST_IMPLICIT PATH ${EXT_HEADER_TEST_SOURCES}) + ConfigureTest(NAME LABEL_TEST PATH test/label/label.cu test/label/merge_labels.cu) ConfigureTest( diff --git a/cpp/test/ext_headers/00_generate.py b/cpp/test/ext_headers/00_generate.py new file mode 100644 index 0000000000..4e719e272c --- /dev/null +++ b/cpp/test/ext_headers/00_generate.py @@ -0,0 +1,75 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +copyright_notice = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +""" + +ext_headers = [ + "raft/neighbors/brute_force-ext.cuh", + "raft/distance/distance-ext.cuh", + "raft/distance/detail/pairwise_matrix/dispatch-ext.cuh", + "raft/matrix/detail/select_k-ext.cuh", + "raft/neighbors/ball_cover-ext.cuh", + "raft/spatial/knn/detail/fused_l2_knn-ext.cuh", + "raft/distance/fused_l2_nn-ext.cuh", + "raft/neighbors/ivf_pq-ext.cuh", + "raft/neighbors/ivf_flat-ext.cuh", + "raft/neighbors/refine-ext.cuh", + "raft/neighbors/detail/ivf_flat_search-ext.cuh", + "raft/neighbors/detail/selection_faiss-ext.cuh", + "raft/linalg/detail/coalesced_reduction-ext.cuh", + "raft/spatial/knn/detail/ball_cover/registers-ext.cuh", +] + +for ext_header in ext_headers: + header = ext_header.replace("-ext", "") + + path = ( + header + .replace("/", "_") + .replace(".cuh", ".cu") + .replace(".hpp", ".cpp") + ) + + with open(path, "w") as f: + f.write(copyright_notice) + f.write(f"#include <{header}>\n") + + # For in CMakeLists.txt + print(f"test/ext_headers/{path}") diff --git a/cpp/test/ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu b/cpp/test/ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu new file mode 100644 index 0000000000..02e4c8e331 --- /dev/null +++ b/cpp/test/ext_headers/raft_distance_detail_pairwise_matrix_dispatch.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_distance_distance.cu b/cpp/test/ext_headers/raft_distance_distance.cu new file mode 100644 index 0000000000..458d6385ed --- /dev/null +++ b/cpp/test/ext_headers/raft_distance_distance.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_distance_fused_l2_nn.cu b/cpp/test/ext_headers/raft_distance_fused_l2_nn.cu new file mode 100644 index 0000000000..23ab58a67b --- /dev/null +++ b/cpp/test/ext_headers/raft_distance_fused_l2_nn.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_linalg_detail_coalesced_reduction.cu b/cpp/test/ext_headers/raft_linalg_detail_coalesced_reduction.cu new file mode 100644 index 0000000000..7f94824287 --- /dev/null +++ b/cpp/test/ext_headers/raft_linalg_detail_coalesced_reduction.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_matrix_detail_select_k.cu b/cpp/test/ext_headers/raft_matrix_detail_select_k.cu new file mode 100644 index 0000000000..adb10f5bbb --- /dev/null +++ b/cpp/test/ext_headers/raft_matrix_detail_select_k.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_ball_cover.cu b/cpp/test/ext_headers/raft_neighbors_ball_cover.cu new file mode 100644 index 0000000000..8aaabe1872 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_ball_cover.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_brute_force.cu b/cpp/test/ext_headers/raft_neighbors_brute_force.cu new file mode 100644 index 0000000000..2c37799ae6 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_brute_force.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu b/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu new file mode 100644 index 0000000000..a6274c1c80 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_detail_selection_faiss.cu b/cpp/test/ext_headers/raft_neighbors_detail_selection_faiss.cu new file mode 100644 index 0000000000..f8bd21e86f --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_detail_selection_faiss.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_ivf_flat.cu b/cpp/test/ext_headers/raft_neighbors_ivf_flat.cu new file mode 100644 index 0000000000..ab38e4c02c --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_ivf_flat.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_ivf_pq.cu b/cpp/test/ext_headers/raft_neighbors_ivf_pq.cu new file mode 100644 index 0000000000..43a66bde18 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_ivf_pq.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_refine.cu b/cpp/test/ext_headers/raft_neighbors_refine.cu new file mode 100644 index 0000000000..6152f83aab --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_refine.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu b/cpp/test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu new file mode 100644 index 0000000000..39320a40c0 --- /dev/null +++ b/cpp/test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu b/cpp/test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu new file mode 100644 index 0000000000..f884d1b062 --- /dev/null +++ b/cpp/test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include From 7d1aa77858620cc37378b4a0a0c0a6d74835eff1 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 19:42:12 +0200 Subject: [PATCH 27/32] Split ivf-pq and ivf-flat search The compute_similarity and interleaved_scan kernel are quite expensive to compile. Splitting the headers in this commit. --- cpp/CMakeLists.txt | 7 + .../detail/ivf_flat_interleaved_scan-ext.cuh | 65 + .../detail/ivf_flat_interleaved_scan-inl.cuh | 1076 +++++++++++++++++ .../detail/ivf_flat_interleaved_scan.cuh | 25 + .../neighbors/detail/ivf_flat_search-ext.cuh | 36 - .../neighbors/detail/ivf_flat_search-inl.cuh | 1075 +--------------- .../detail/ivf_pq_compute_similarity-ext.cuh | 183 +++ .../detail/ivf_pq_compute_similarity-inl.cuh | 845 +++++++++++++ .../detail/ivf_pq_compute_similarity.cuh | 25 + .../detail/ivf_pq_dummy_block_sort.cuh | 39 + .../raft/neighbors/detail/ivf_pq_fp_8bit.cuh | 113 ++ .../raft/neighbors/detail/ivf_pq_search.cuh | 907 +------------- cpp/include/raft/neighbors/detail/refine.cuh | 1 + cpp/include/raft/neighbors/ivf_flat-inl.cuh | 82 +- cpp/include/raft/neighbors/ivf_flat_types.hpp | 10 +- ...at_interleaved_scan_float_float_int64_t.cu | 2 +- ...interleaved_scan_int8_t_int32_t_int64_t.cu | 2 +- ...terleaved_scan_uint8_t_uint32_t_int64_t.cu | 2 +- .../ivf_pq_compute_similarity_00_generate.py | 107 ++ .../ivf_pq_compute_similarity_float_float.cu | 73 ++ ...f_pq_compute_similarity_float_fp8_false.cu | 74 ++ ...vf_pq_compute_similarity_float_fp8_true.cu | 74 ++ .../ivf_pq_compute_similarity_float_half.cu | 73 ++ ...vf_pq_compute_similarity_half_fp8_false.cu | 74 ++ ...ivf_pq_compute_similarity_half_fp8_true.cu | 74 ++ .../ivf_pq_compute_similarity_half_half.cu | 73 ++ cpp/test/ext_headers/00_generate.py | 2 + ...ghbors_detail_ivf_flat_interleaved_scan.cu | 27 + ...ghbors_detail_ivf_pq_compute_similarity.cu | 27 + 29 files changed, 3153 insertions(+), 2020 deletions(-) create mode 100644 cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh create mode 100644 cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh create mode 100644 cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh create mode 100644 cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh create mode 100644 cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh create mode 100644 cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh create mode 100644 cpp/include/raft/neighbors/detail/ivf_pq_dummy_block_sort.cuh create mode 100644 cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu create mode 100644 cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu create mode 100644 cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_interleaved_scan.cu create mode 100644 cpp/test/ext_headers/raft_neighbors_detail_ivf_pq_compute_similarity.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 55ad024a56..dfed15f6a4 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -309,6 +309,13 @@ if(RAFT_COMPILE_LIBRARY) src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu src/neighbors/detail/ivf_flat_search.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu + src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu + src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu src/neighbors/detail/selection_faiss_int32_t_float.cu src/neighbors/detail/selection_faiss_int_double.cu src/neighbors/detail/selection_faiss_long_float.cu diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh new file mode 100644 index 0000000000..46f72c4005 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // uintX_t +#include // raft::neighbors::ivf_flat::index +#include // RAFT_EXPLICIT +#include // rmm:cuda_stream_view + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ivf_flat::detail { + +template +void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& index, + const T* queries, + const uint32_t* coarse_query_results, + const uint32_t n_queries, + const raft::distance::DistanceType metric, + const uint32_t n_probes, + const uint32_t k, + const bool select_min, + IdxT* neighbors, + float* distances, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream) RAFT_EXPLICIT; + +} // namespace raft::neighbors::ivf_flat::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ + extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ + const raft::neighbors::ivf_flat::index& index, \ + const T* queries, \ + const uint32_t* coarse_query_results, \ + const uint32_t n_queries, \ + const raft::distance::DistanceType metric, \ + const uint32_t n_probes, \ + const uint32_t k, \ + const bool select_min, \ + IdxT* neighbors, \ + float* distances, \ + uint32_t& grid_dim_x, \ + rmm::cuda_stream_view stream) + +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); +instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); + +#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh new file mode 100644 index 0000000000..4eed2aa453 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan-inl.cuh @@ -0,0 +1,1076 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // RAFT_LOG_TRACE +#include +#include +#include +#include +#include +#include // RAFT_CUDA_TRY +#include +#include +#include +#include +#include + +namespace raft::neighbors::ivf_flat::detail { + +using namespace raft::spatial::knn::detail; // NOLINT + +constexpr int kThreadsPerBlock = 128; + +/** + * @brief Copy `n` elements per block from one place to another. + * + * @param[out] out target pointer (unique per block) + * @param[in] in source pointer + * @param n number of elements to copy + */ +template +__device__ inline void copy_vectorized(T* out, const T* in, uint32_t n) +{ + constexpr int VecElems = VecBytes / sizeof(T); // NOLINT + using align_bytes = Pow2<(size_t)VecBytes>; + if constexpr (VecElems > 1) { + using align_elems = Pow2; + if (!align_bytes::areSameAlignOffsets(out, in)) { + return copy_vectorized<(VecBytes >> 1), T>(out, in, n); + } + { // process unaligned head + uint32_t head = align_bytes::roundUp(in) - in; + if (head > 0) { + copy_vectorized(out, in, head); + n -= head; + in += head; + out += head; + } + } + { // process main part vectorized + using vec_t = typename IOType::Type; + copy_vectorized( + reinterpret_cast(out), reinterpret_cast(in), align_elems::div(n)); + } + { // process unaligned tail + uint32_t tail = align_elems::mod(n); + if (tail > 0) { + n -= tail; + copy_vectorized(out + n, in + n, tail); + } + } + } + if constexpr (VecElems <= 1) { + for (int i = threadIdx.x; i < n; i += blockDim.x) { + out[i] = in[i]; + } + } +} + +/** + * @brief Load a part of a vector from the index and from query, compute the (part of the) distance + * between them, and aggregate it using the provided Lambda; one structure per thread, per query, + * and per index item. + * + * @tparam kUnroll elements per loop (normally, kUnroll = WarpSize / Veclen) + * @tparam Lambda computing the part of the distance for one dimension and aggregating it: + * void (AccT& acc, AccT x, AccT y) + * @tparam Veclen size of the vectorized load + * @tparam T type of the data in the query and the index + * @tparam AccT type of the accumulated value (an optimization for 8bit values to be loaded as 32bit + * values) + */ +template +struct loadAndComputeDist { + Lambda compute_dist; + AccT& dist; + + __device__ __forceinline__ loadAndComputeDist(AccT& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + /** + * Load parts of vectors from the index and query and accumulates the partial distance. + * This version assumes the query is stored in shared memory. + * Every thread here processes exactly kUnroll * Veclen elements independently of others. + */ + template + __device__ __forceinline__ void runLoadShmemCompute(const T* const& data, + const T* query_shared, + IdxT loadIndex, + IdxT shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + T encV[Veclen]; + ldg(encV, data + (loadIndex + j * kIndexGroupSize) * Veclen); + T queryRegs[Veclen]; + lds(queryRegs, &query_shared[shmemIndex + j * Veclen]); +#pragma unroll + for (int k = 0; k < Veclen; ++k) { + compute_dist(dist, queryRegs[k], encV[k]); + } + } + } + + /** + * Load parts of vectors from the index and query and accumulates the partial distance. + * This version assumes the query is stored in the global memory and is different for every + * thread. One warp loads exactly WarpSize query elements at once and then reshuffles them into + * corresponding threads (`WarpSize / (kUnroll * Veclen)` elements per thread at once). + */ + template + __device__ __forceinline__ void runLoadShflAndCompute(const T*& data, + const T* query, + IdxT baseLoadIndex, + const int lane_id) + { + T queryReg = query[baseLoadIndex + lane_id]; + constexpr int stride = kUnroll * Veclen; + constexpr int totalIter = WarpSize / stride; + constexpr int gmemStride = stride * kIndexGroupSize; +#pragma unroll + for (int i = 0; i < totalIter; ++i, data += gmemStride) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + T encV[Veclen]; + ldg(encV, data + (lane_id + j * kIndexGroupSize) * Veclen); + const int d = (i * kUnroll + j) * Veclen; +#pragma unroll + for (int k = 0; k < Veclen; ++k) { + compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); + } + } + } + } + + /** + * Load parts of vectors from the index and query and accumulates the partial distance. + * This version augments `runLoadShflAndCompute` when `dim` is not a multiple of `WarpSize`. + */ + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const T*& data, const T* query, const int lane_id, const int dim, const int dimBlocks) + { + const int loadDim = dimBlocks + lane_id; + T queryReg = loadDim < dim ? query[loadDim] : 0; + const int loadDataIdx = lane_id * Veclen; + for (int d = 0; d < dim - dimBlocks; d += Veclen, data += kIndexGroupSize * Veclen) { + T enc[Veclen]; + ldg(enc, data + loadDataIdx); +#pragma unroll + for (int k = 0; k < Veclen; k++) { + compute_dist(dist, shfl(queryReg, d + k, WarpSize), enc[k]); + } + } + } +}; + +// This handles uint8_t 8, 16 Veclens +template +struct loadAndComputeDist { + Lambda compute_dist; + uint32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { + constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int + loadIndex = loadIndex * veclen_int; +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV[veclen_int]; + ldg(encV, + reinterpret_cast(data) + loadIndex + j * kIndexGroupSize * veclen_int); + uint32_t queryRegs[veclen_int]; + lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + compute_dist(dist, queryRegs[k], encV[k]); + } + } + } + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int + uint32_t queryReg = + (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int stride = kUnroll * uint8_veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV[veclen_int]; + ldg(encV, + reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); + const int d = (i * kUnroll + j) * veclen_int; +#pragma unroll + for (int k = 0; k < veclen_int; ++k) { + compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); + } + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen_int = uint8_veclen / 4; + const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int + uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; + d += uint8_veclen, data += kIndexGroupSize * uint8_veclen) { + uint32_t enc[veclen_int]; + ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + uint32_t q = shfl(queryReg, (d / 4) + k, WarpSize); + compute_dist(dist, q, enc[k]); + } + } + } +}; + +// Keep this specialized uint8 Veclen = 4, because compiler is generating suboptimal code while +// using above common template of int2/int4 +template +struct loadAndComputeDist { + Lambda compute_dist; + uint32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + compute_dist(dist, queryRegs, encV); + } + } + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + uint32_t queryReg = + (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int veclen = 4; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); + compute_dist(dist, q, encV); + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen = 4; + const int loadDim = dimBlocks + lane_id; + uint32_t queryReg = loadDim < dim ? reinterpret_cast(query)[loadDim] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + uint32_t enc = reinterpret_cast(data)[lane_id]; + uint32_t q = shfl(queryReg, d / veclen, WarpSize); + compute_dist(dist, q, enc); + } + } +}; + +template +struct loadAndComputeDist { + Lambda compute_dist; + uint32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + compute_dist(dist, queryRegs, encV); + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + uint32_t queryReg = + (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int veclen = 2; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); + compute_dist(dist, q, encV); + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen = 2; + int loadDim = dimBlocks + lane_id * veclen; + uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + uint32_t enc = reinterpret_cast(data)[lane_id]; + uint32_t q = shfl(queryReg, d / veclen, WarpSize); + compute_dist(dist, q, enc); + } + } +}; + +template +struct loadAndComputeDist { + Lambda compute_dist; + uint32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, + const uint8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = data[loadIndex + j * kIndexGroupSize]; + uint32_t queryRegs = query_shared[shmemIndex + j]; + compute_dist(dist, queryRegs, encV); + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, + const uint8_t* query, + int baseLoadIndex, + const int lane_id) + { + uint32_t queryReg = query[baseLoadIndex + lane_id]; + constexpr int veclen = 1; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + uint32_t encV = data[lane_id + j * kIndexGroupSize]; + uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); + compute_dist(dist, q, encV); + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, + const uint8_t* query, + const int lane_id, + const int dim, + const int dimBlocks) + { + constexpr int veclen = 1; + int loadDim = dimBlocks + lane_id; + uint32_t queryReg = loadDim < dim ? query[loadDim] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + uint32_t enc = data[lane_id]; + uint32_t q = shfl(queryReg, d, WarpSize); + compute_dist(dist, q, enc); + } + } +}; + +// This device function is for int8 veclens 4, 8 and 16 +template +struct loadAndComputeDist { + Lambda compute_dist; + int32_t& dist; + + __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, + const int8_t* query_shared, + int loadIndex, + int shmemIndex) + { + constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int + +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV[veclen_int]; + ldg(encV, + reinterpret_cast(data) + (loadIndex + j * kIndexGroupSize) * veclen_int); + int32_t queryRegs[veclen_int]; + lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + compute_dist(dist, queryRegs[k], encV[k]); + } + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, + const int8_t* query, + int baseLoadIndex, + const int lane_id) + { + constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int + + int32_t queryReg = + (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int stride = kUnroll * int8_veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV[veclen_int]; + ldg(encV, + reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); + const int d = (i * kUnroll + j) * veclen_int; +#pragma unroll + for (int k = 0; k < veclen_int; ++k) { + int32_t q = shfl(queryReg, d + k, WarpSize); + compute_dist(dist, q, encV[k]); + } + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) + { + constexpr int veclen_int = int8_veclen / 4; + const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int; + int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; d += int8_veclen, data += kIndexGroupSize * int8_veclen) { + int32_t enc[veclen_int]; + ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); +#pragma unroll + for (int k = 0; k < veclen_int; k++) { + int32_t q = shfl(queryReg, (d / 4) + k, WarpSize); // Here 4 is for 1 - int; + compute_dist(dist, q, enc[k]); + } + } + } +}; + +template +struct loadAndComputeDist { + Lambda compute_dist; + int32_t& dist; + __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, + const int8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; + int32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; + compute_dist(dist, queryRegs, encV); + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, + const int8_t* query, + int baseLoadIndex, + const int lane_id) + { + int32_t queryReg = + (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; + constexpr int veclen = 2; + constexpr int stride = kUnroll * veclen; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + int32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; + int32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); + compute_dist(dist, q, encV); + } + } + } + + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) + { + constexpr int veclen = 2; + int loadDim = dimBlocks + lane_id * veclen; + int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + int32_t enc = reinterpret_cast(data + lane_id * veclen)[0]; + int32_t q = shfl(queryReg, d / veclen, WarpSize); + compute_dist(dist, q, enc); + } + } +}; + +template +struct loadAndComputeDist { + Lambda compute_dist; + int32_t& dist; + __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) + : dist(dist), compute_dist(op) + { + } + + __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, + const int8_t* query_shared, + int loadIndex, + int shmemIndex) + { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + compute_dist(dist, query_shared[shmemIndex + j], data[loadIndex + j * kIndexGroupSize]); + } + } + + __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, + const int8_t* query, + int baseLoadIndex, + const int lane_id) + { + constexpr int veclen = 1; + constexpr int stride = kUnroll * veclen; + int32_t queryReg = query[baseLoadIndex + lane_id]; + +#pragma unroll + for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { +#pragma unroll + for (int j = 0; j < kUnroll; ++j) { + compute_dist( + dist, shfl(queryReg, i * kUnroll + j, WarpSize), data[lane_id + j * kIndexGroupSize]); + } + } + } + __device__ __forceinline__ void runLoadShflAndComputeRemainder( + const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) + { + constexpr int veclen = 1; + const int loadDim = dimBlocks + lane_id; + int32_t queryReg = loadDim < dim ? query[loadDim] : 0; + for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { + compute_dist(dist, shfl(queryReg, d, WarpSize), data[lane_id]); + } + } +}; + +/** + * Scan clusters for nearest neighbors of the query vectors. + * See `ivfflat_interleaved_scan` for more information. + * + * The clusters are stored in the interleaved index format described in ivf_flat_types.hpp. + * For each query vector, a set of clusters is probed: the distance to each vector in the cluster is + * calculated, and the top-k nearest neighbors are selected. + * + * @param compute_dist distance function + * @param query_smem_elems number of dimensions of the query vector to fit in a shared memory of a + * block; this number must be a multiple of `WarpSize * Veclen`. + * @param[in] query a pointer to all queries in a row-major contiguous format [gridDim.y, dim] + * @param[in] coarse_index a pointer to the cluster indices to search through [n_probes] + * @param[in] list_indices index.indices + * @param[in] list_data index.data + * @param[in] list_sizes index.list_sizes + * @param[in] list_offsets index.list_offsets + * @param n_probes + * @param k + * @param dim + * @param[out] neighbors + * @param[out] distances + */ +template +__global__ void __launch_bounds__(kThreadsPerBlock) + interleaved_scan_kernel(Lambda compute_dist, + PostLambda post_process, + const uint32_t query_smem_elems, + const T* query, + const uint32_t* coarse_index, + const IdxT* const* list_indices_ptrs, + const T* const* list_data_ptrs, + const uint32_t* list_sizes, + const uint32_t n_probes, + const uint32_t k, + const uint32_t dim, + IdxT* neighbors, + float* distances) +{ + extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; + // Using shared memory for the (part of the) query; + // This allows to save on global memory bandwidth when reading index and query + // data at the same time. + // Its size is `query_smem_elems`. + T* query_shared = reinterpret_cast(interleaved_scan_kernel_smem); + // Make the query input and output point to this block's shared query + { + const int query_id = blockIdx.y; + query += query_id * dim; + neighbors += query_id * k * gridDim.x + blockIdx.x * k; + distances += query_id * k * gridDim.x + blockIdx.x * k; + coarse_index += query_id * n_probes; + } + + // Copy a part of the query into shared memory for faster processing + copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); + __syncthreads(); + + using block_sort_t = matrix::detail::select::warpsort::block_sort< + matrix::detail::select::warpsort::warp_sort_filtered, + Capacity, + Ascending, + float, + IdxT>; + block_sort_t queue(k); + + { + using align_warp = Pow2; + const int lane_id = align_warp::mod(threadIdx.x); + + // How many full warps needed to compute the distance (without remainder) + const uint32_t full_warps_along_dim = align_warp::roundDown(dim); + + const uint32_t shm_assisted_dim = + (dim > query_smem_elems) ? query_smem_elems : full_warps_along_dim; + + // Every CUDA block scans one cluster at a time. + for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) { + const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) + + // The number of vectors in each cluster(list); [nlist] + const uint32_t list_length = list_sizes[list_id]; + + // The number of interleaved groups to be processed + const uint32_t num_groups = + align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 + + constexpr int kUnroll = WarpSize / Veclen; + constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize; + // Every warp reads WarpSize vectors and computes the distances to them. + // Then, the distances and corresponding ids are distributed among the threads, + // and each thread adds one (id, dist) pair to the filtering queue. + for (uint32_t group_id = align_warp::div(threadIdx.x); group_id < num_groups; + group_id += kNumWarps) { + AccT dist = 0; + // This is where this warp begins reading data (start position of an interleaved group) + const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; + + // This is the vector a given lane/thread handles + const uint32_t vec_id = group_id * WarpSize + lane_id; + const bool valid = vec_id < list_length; + + // Process first shm_assisted_dim dimensions (always using shared memory) + if (valid) { + loadAndComputeDist lc(dist, + compute_dist); + for (int pos = 0; pos < shm_assisted_dim; + pos += WarpSize, data += kIndexGroupSize * WarpSize) { + lc.runLoadShmemCompute(data, query_shared, lane_id, pos); + } + } + + if (dim > query_smem_elems) { + // The default path - using shfl ops - for dimensions beyond query_smem_elems + loadAndComputeDist lc(dist, + compute_dist); + for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += WarpSize) { + lc.runLoadShflAndCompute(data, query, pos, lane_id); + } + lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); + } else { + // when shm_assisted_dim == full_warps_along_dim < dim + if (valid) { + loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT> lc(dist, compute_dist); + for (int pos = full_warps_along_dim; pos < dim; + pos += Veclen, data += kIndexGroupSize * Veclen) { + lc.runLoadShmemCompute(data, query_shared, lane_id, pos); + } + } + } + + // Enqueue one element per thread + const float val = valid ? static_cast(dist) : block_sort_t::queue_t::kDummy; + const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; + queue.add(val, idx); + } + } + } + + // finalize and store selected neighbours + __syncthreads(); + queue.done(interleaved_scan_kernel_smem); + queue.store(distances, neighbors, post_process); +} + +/** + * Configure the gridDim.x to maximize GPU occupancy, but reduce the output size + */ +template +uint32_t configure_launch_x(uint32_t numQueries, uint32_t n_probes, int32_t sMemSize, T func) +{ + int dev_id; + RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); + int num_sms; + RAFT_CUDA_TRY(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); + int num_blocks_per_sm = 0; + RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &num_blocks_per_sm, func, kThreadsPerBlock, sMemSize)); + + size_t min_grid_size = num_sms * num_blocks_per_sm; + size_t min_grid_x = ceildiv(min_grid_size, numQueries); + return min_grid_x > n_probes ? n_probes : static_cast(min_grid_x); +} + +template +void launch_kernel(Lambda lambda, + PostLambda post_process, + const index& index, + const T* queries, + const uint32_t* coarse_index, + const uint32_t num_queries, + const uint32_t n_probes, + const uint32_t k, + IdxT* neighbors, + float* distances, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream) +{ + RAFT_EXPECTS(Veclen == index.veclen(), + "Configured Veclen does not match the index interleaving pattern."); + constexpr auto kKernel = + interleaved_scan_kernel; + const int max_query_smem = 16384; + int query_smem_elems = + std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); + int smem_size = query_smem_elems * sizeof(T); + constexpr int kSubwarpSize = std::min(Capacity, WarpSize); + auto block_merge_mem = + raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( + kThreadsPerBlock / kSubwarpSize, k); + smem_size += std::max(smem_size, block_merge_mem); + + // power-of-two less than cuda limit (for better addr alignment) + constexpr uint32_t kMaxGridY = 32768; + + if (grid_dim_x == 0) { + grid_dim_x = configure_launch_x(std::min(kMaxGridY, num_queries), n_probes, smem_size, kKernel); + return; + } + + for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) { + uint32_t grid_dim_y = std::min(kMaxGridY, num_queries - query_offset); + dim3 grid_dim(grid_dim_x, grid_dim_y, 1); + dim3 block_dim(kThreadsPerBlock); + RAFT_LOG_TRACE( + "Launching the ivf-flat interleaved_scan_kernel (%d, %d, 1) x (%d, 1, 1), n_probes = %d, " + "smem_size = %d", + grid_dim.x, + grid_dim.y, + block_dim.x, + n_probes, + smem_size); + kKernel<<>>(lambda, + post_process, + query_smem_elems, + queries, + coarse_index, + index.inds_ptrs().data_handle(), + index.data_ptrs().data_handle(), + index.list_sizes().data_handle(), + n_probes, + k, + index.dim(), + neighbors, + distances); + queries += grid_dim_y * index.dim(); + neighbors += grid_dim_y * grid_dim_x * k; + distances += grid_dim_y * grid_dim_x * k; + } +} + +template +struct euclidean_dist { + __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) + { + const auto diff = x - y; + acc += diff * diff; + } +}; + +template +struct euclidean_dist { + __device__ __forceinline__ void operator()(uint32_t& acc, uint32_t x, uint32_t y) + { + if constexpr (Veclen > 1) { + const auto diff = __vabsdiffu4(x, y); + acc = dp4a(diff, diff, acc); + } else { + const auto diff = __usad(x, y, 0u); + acc += diff * diff; + } + } +}; + +template +struct euclidean_dist { + __device__ __forceinline__ void operator()(int32_t& acc, int32_t x, int32_t y) + { + if constexpr (Veclen > 1) { + // Note that we enforce here that the unsigned version of dp4a is used, because the difference + // between two int8 numbers can be greater than 127 and therefore represented as a negative + // number in int8. Casting from int8 to int32 would yield incorrect results, while casting + // from uint8 to uint32 is correct. + const auto diff = __vabsdiffs4(x, y); + acc = dp4a(diff, diff, static_cast(acc)); + } else { + const auto diff = x - y; + acc += diff * diff; + } + } +}; + +template +struct inner_prod_dist { + __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) + { + if constexpr (Veclen > 1 && (std::is_same_v || std::is_same_v)) { + acc = dp4a(x, y, acc); + } else { + acc += x * y; + } + } +}; + +/** Select the distance computation function and forward the rest of the arguments. */ +template +void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... args) +{ + switch (metric) { + case raft::distance::DistanceType::L2Expanded: + case raft::distance::DistanceType::L2Unexpanded: + return launch_kernel, + raft::identity_op>({}, {}, std::forward(args)...); + case raft::distance::DistanceType::L2SqrtExpanded: + case raft::distance::DistanceType::L2SqrtUnexpanded: + return launch_kernel, + raft::sqrt_op>({}, {}, std::forward(args)...); + case raft::distance::DistanceType::InnerProduct: + return launch_kernel, + raft::identity_op>({}, {}, std::forward(args)...); + // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. + default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); + } +} + +/** + * Lift the `capacity` and `veclen` parameters to the template level, + * forward the rest of the arguments unmodified to `launch_interleaved_scan_kernel`. + */ +template (1, 16 / sizeof(T))> +struct select_interleaved_scan_kernel { + /** + * Recursively reduce the `Capacity` and `Veclen` parameters until they match the + * corresponding runtime arguments. + * By default, this recursive process starts with maximum possible values of the + * two parameters and ends with both values equal to 1. + */ + template + static inline void run(int capacity, int veclen, bool select_min, Args&&... args) + { + if constexpr (Capacity > 1) { + if (capacity * 2 <= Capacity) { + return select_interleaved_scan_kernel::run( + capacity, veclen, select_min, std::forward(args)...); + } + } + if constexpr (Veclen > 1) { + if (veclen % Veclen != 0) { + return select_interleaved_scan_kernel::run( + capacity, 1, select_min, std::forward(args)...); + } + } + // NB: this is the limitation of the warpsort structures that use a huge number of + // registers (used in the main kernel here). + RAFT_EXPECTS(capacity == Capacity, + "Capacity must be power-of-two not bigger than the maximum allowed size " + "matrix::detail::select::warpsort::kMaxCapacity (%d).", + matrix::detail::select::warpsort::kMaxCapacity); + RAFT_EXPECTS( + veclen == Veclen, + "Veclen must be power-of-two not bigger than the maximum allowed size for this data type."); + if (select_min) { + launch_with_fixed_consts(std::forward(args)...); + } else { + launch_with_fixed_consts(std::forward(args)...); + } + } +}; + +/** + * @brief Configure and launch an appropriate template instance of the interleaved scan kernel. + * + * @tparam T value type + * @tparam AccT accumulated type + * @tparam IdxT type of the indices + * + * @param index previously built ivf-flat index + * @param[in] queries device pointer to the query vectors [batch_size, dim] + * @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes] + * @param n_queries batch size + * @param metric type of the measured distance + * @param n_probes number of nearest clusters to query + * @param k number of nearest neighbors. + * NB: the maximum value of `k` is limited statically by `kMaxCapacity`. + * @param select_min whether to select nearest (true) or furthest (false) points w.r.t. the given + * metric. + * @param[out] neighbors device pointer to the result indices for each query and cluster + * [batch_size, grid_dim_x, k] + * @param[out] distances device pointer to the result distances for each query and cluster + * [batch_size, grid_dim_x, k] + * @param[inout] grid_dim_x number of blocks launched across all n_probes clusters; + * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) + * @param stream + */ +template +void ivfflat_interleaved_scan(const index& index, + const T* queries, + const uint32_t* coarse_query_results, + const uint32_t n_queries, + const raft::distance::DistanceType metric, + const uint32_t n_probes, + const uint32_t k, + const bool select_min, + IdxT* neighbors, + float* distances, + uint32_t& grid_dim_x, + rmm::cuda_stream_view stream) +{ + const int capacity = bound_by_power_of_two(k); + select_interleaved_scan_kernel::run(capacity, + index.veclen(), + select_min, + metric, + index, + queries, + coarse_query_results, + n_queries, + n_probes, + k, + neighbors, + distances, + grid_dim_x, + stream); +} + +} // namespace raft::neighbors::ivf_flat::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh new file mode 100644 index 0000000000..63f341dd9a --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_flat_interleaved_scan.cuh @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "ivf_flat_interleaved_scan-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "ivf_flat_interleaved_scan-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh index be71c9eb11..3bb3a4308d 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-ext.cuh @@ -19,7 +19,6 @@ #include // uintX_t #include // raft::neighbors::ivf_flat::index #include // RAFT_EXPLICIT -#include // rmm:cuda_stream_view #ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -36,20 +35,6 @@ void search(raft::device_resources const& handle, float* distances, rmm::mr::device_memory_resource* mr = nullptr); -template -void ivfflat_interleaved_scan(const raft::neighbors::ivf_flat::index& index, - const T* queries, - const uint32_t* coarse_query_results, - const uint32_t n_queries, - const raft::distance::DistanceType metric, - const uint32_t n_probes, - const uint32_t k, - const bool select_min, - IdxT* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) RAFT_EXPLICIT; - } // namespace raft::neighbors::ivf_flat::detail #endif // RAFT_EXPLICIT_INSTANTIATE_ONLY @@ -71,24 +56,3 @@ instantiate_raft_neighbors_ivf_flat_detail_search(int8_t, int64_t); instantiate_raft_neighbors_ivf_flat_detail_search(uint8_t, int64_t); #undef instantiate_raft_neighbors_ivf_flat_detail_search - -#define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ - extern template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ - const raft::neighbors::ivf_flat::index& index, \ - const T* queries, \ - const uint32_t* coarse_query_results, \ - const uint32_t n_queries, \ - const raft::distance::DistanceType metric, \ - const uint32_t n_probes, \ - const uint32_t k, \ - const bool select_min, \ - IdxT* neighbors, \ - float* distances, \ - uint32_t& grid_dim_x, \ - rmm::cuda_stream_view stream) - -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(float, float, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(int8_t, int32_t, int64_t); -instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(uint8_t, uint32_t, int64_t); - -#undef instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh index 526ee2b7b0..89a4597acf 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_search-inl.cuh @@ -16,1076 +16,25 @@ #pragma once -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include +#include // raft::device_resources +#include // RAFT_LOG_TRACE +#include // is_min_close, DistanceType +#include // raft::linalg::gemm +#include // raft::linalg::norm +#include // raft::linalg::unary_op +#include // matrix::detail::select_k +#include // interleaved_scan +#include // raft::neighbors::ivf_flat::index +#include // utils::mapping +#include // rmm::device_memory_resource namespace raft::neighbors::ivf_flat::detail { using namespace raft::spatial::knn::detail; // NOLINT -constexpr int kThreadsPerBlock = 128; - -/** - * @brief Copy `n` elements per block from one place to another. - * - * @param[out] out target pointer (unique per block) - * @param[in] in source pointer - * @param n number of elements to copy - */ -template -__device__ inline void copy_vectorized(T* out, const T* in, uint32_t n) -{ - constexpr int VecElems = VecBytes / sizeof(T); // NOLINT - using align_bytes = Pow2<(size_t)VecBytes>; - if constexpr (VecElems > 1) { - using align_elems = Pow2; - if (!align_bytes::areSameAlignOffsets(out, in)) { - return copy_vectorized<(VecBytes >> 1), T>(out, in, n); - } - { // process unaligned head - uint32_t head = align_bytes::roundUp(in) - in; - if (head > 0) { - copy_vectorized(out, in, head); - n -= head; - in += head; - out += head; - } - } - { // process main part vectorized - using vec_t = typename IOType::Type; - copy_vectorized( - reinterpret_cast(out), reinterpret_cast(in), align_elems::div(n)); - } - { // process unaligned tail - uint32_t tail = align_elems::mod(n); - if (tail > 0) { - n -= tail; - copy_vectorized(out + n, in + n, tail); - } - } - } - if constexpr (VecElems <= 1) { - for (int i = threadIdx.x; i < n; i += blockDim.x) { - out[i] = in[i]; - } - } -} - -/** - * @brief Load a part of a vector from the index and from query, compute the (part of the) distance - * between them, and aggregate it using the provided Lambda; one structure per thread, per query, - * and per index item. - * - * @tparam kUnroll elements per loop (normally, kUnroll = WarpSize / Veclen) - * @tparam Lambda computing the part of the distance for one dimension and aggregating it: - * void (AccT& acc, AccT x, AccT y) - * @tparam Veclen size of the vectorized load - * @tparam T type of the data in the query and the index - * @tparam AccT type of the accumulated value (an optimization for 8bit values to be loaded as 32bit - * values) - */ -template -struct loadAndComputeDist { - Lambda compute_dist; - AccT& dist; - - __device__ __forceinline__ loadAndComputeDist(AccT& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version assumes the query is stored in shared memory. - * Every thread here processes exactly kUnroll * Veclen elements independently of others. - */ - template - __device__ __forceinline__ void runLoadShmemCompute(const T* const& data, - const T* query_shared, - IdxT loadIndex, - IdxT shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - T encV[Veclen]; - ldg(encV, data + (loadIndex + j * kIndexGroupSize) * Veclen); - T queryRegs[Veclen]; - lds(queryRegs, &query_shared[shmemIndex + j * Veclen]); -#pragma unroll - for (int k = 0; k < Veclen; ++k) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version assumes the query is stored in the global memory and is different for every - * thread. One warp loads exactly WarpSize query elements at once and then reshuffles them into - * corresponding threads (`WarpSize / (kUnroll * Veclen)` elements per thread at once). - */ - template - __device__ __forceinline__ void runLoadShflAndCompute(const T*& data, - const T* query, - IdxT baseLoadIndex, - const int lane_id) - { - T queryReg = query[baseLoadIndex + lane_id]; - constexpr int stride = kUnroll * Veclen; - constexpr int totalIter = WarpSize / stride; - constexpr int gmemStride = stride * kIndexGroupSize; -#pragma unroll - for (int i = 0; i < totalIter; ++i, data += gmemStride) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - T encV[Veclen]; - ldg(encV, data + (lane_id + j * kIndexGroupSize) * Veclen); - const int d = (i * kUnroll + j) * Veclen; -#pragma unroll - for (int k = 0; k < Veclen; ++k) { - compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); - } - } - } - } - - /** - * Load parts of vectors from the index and query and accumulates the partial distance. - * This version augments `runLoadShflAndCompute` when `dim` is not a multiple of `WarpSize`. - */ - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const T*& data, const T* query, const int lane_id, const int dim, const int dimBlocks) - { - const int loadDim = dimBlocks + lane_id; - T queryReg = loadDim < dim ? query[loadDim] : 0; - const int loadDataIdx = lane_id * Veclen; - for (int d = 0; d < dim - dimBlocks; d += Veclen, data += kIndexGroupSize * Veclen) { - T enc[Veclen]; - ldg(enc, data + loadDataIdx); -#pragma unroll - for (int k = 0; k < Veclen; k++) { - compute_dist(dist, shfl(queryReg, d + k, WarpSize), enc[k]); - } - } - } -}; - -// This handles uint8_t 8, 16 Veclens -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { - constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int - loadIndex = loadIndex * veclen_int; -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + loadIndex + j * kIndexGroupSize * veclen_int); - uint32_t queryRegs[veclen_int]; - lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen_int = uint8_veclen / 4; // converting uint8_t veclens to int - uint32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int stride = kUnroll * uint8_veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); - const int d = (i * kUnroll + j) * veclen_int; -#pragma unroll - for (int k = 0; k < veclen_int; ++k) { - compute_dist(dist, shfl(queryReg, d + k, WarpSize), encV[k]); - } - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen_int = uint8_veclen / 4; - const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; - d += uint8_veclen, data += kIndexGroupSize * uint8_veclen) { - uint32_t enc[veclen_int]; - ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - uint32_t q = shfl(queryReg, (d / 4) + k, WarpSize); - compute_dist(dist, q, enc[k]); - } - } - } -}; - -// Keep this specialized uint8 Veclen = 4, because compiler is generating suboptimal code while -// using above common template of int2/int4 -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 4; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 4; - const int loadDim = dimBlocks + lane_id; - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query)[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = reinterpret_cast(data)[lane_id]; - uint32_t q = shfl(queryReg, d / veclen, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = - (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 2; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 2; - int loadDim = dimBlocks + lane_id * veclen; - uint32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = reinterpret_cast(data)[lane_id]; - uint32_t q = shfl(queryReg, d / veclen, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - uint32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(uint32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const uint8_t* const& data, - const uint8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = data[loadIndex + j * kIndexGroupSize]; - uint32_t queryRegs = query_shared[shmemIndex + j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const uint8_t*& data, - const uint8_t* query, - int baseLoadIndex, - const int lane_id) - { - uint32_t queryReg = query[baseLoadIndex + lane_id]; - constexpr int veclen = 1; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - uint32_t encV = data[lane_id + j * kIndexGroupSize]; - uint32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder(const uint8_t*& data, - const uint8_t* query, - const int lane_id, - const int dim, - const int dimBlocks) - { - constexpr int veclen = 1; - int loadDim = dimBlocks + lane_id; - uint32_t queryReg = loadDim < dim ? query[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - uint32_t enc = data[lane_id]; - uint32_t q = shfl(queryReg, d, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -// This device function is for int8 veclens 4, 8 and 16 -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { - constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int - -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + (loadIndex + j * kIndexGroupSize) * veclen_int); - int32_t queryRegs[veclen_int]; - lds(queryRegs, reinterpret_cast(query_shared + shmemIndex) + j * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - compute_dist(dist, queryRegs[k], encV[k]); - } - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen_int = int8_veclen / 4; // converting int8_t veclens to int - - int32_t queryReg = - (lane_id < 8) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int stride = kUnroll * int8_veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV[veclen_int]; - ldg(encV, - reinterpret_cast(data) + (lane_id + j * kIndexGroupSize) * veclen_int); - const int d = (i * kUnroll + j) * veclen_int; -#pragma unroll - for (int k = 0; k < veclen_int; ++k) { - int32_t q = shfl(queryReg, d + k, WarpSize); - compute_dist(dist, q, encV[k]); - } - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen_int = int8_veclen / 4; - const int loadDim = dimBlocks + lane_id * 4; // Here 4 is for 1 - int; - int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += int8_veclen, data += kIndexGroupSize * int8_veclen) { - int32_t enc[veclen_int]; - ldg(enc, reinterpret_cast(data) + lane_id * veclen_int); -#pragma unroll - for (int k = 0; k < veclen_int; k++) { - int32_t q = shfl(queryReg, (d / 4) + k, WarpSize); // Here 4 is for 1 - int; - compute_dist(dist, q, enc[k]); - } - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV = reinterpret_cast(data)[loadIndex + j * kIndexGroupSize]; - int32_t queryRegs = reinterpret_cast(query_shared + shmemIndex)[j]; - compute_dist(dist, queryRegs, encV); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - int32_t queryReg = - (lane_id < 16) ? reinterpret_cast(query + baseLoadIndex)[lane_id] : 0; - constexpr int veclen = 2; - constexpr int stride = kUnroll * veclen; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - int32_t encV = reinterpret_cast(data)[lane_id + j * kIndexGroupSize]; - int32_t q = shfl(queryReg, i * kUnroll + j, WarpSize); - compute_dist(dist, q, encV); - } - } - } - - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen = 2; - int loadDim = dimBlocks + lane_id * veclen; - int32_t queryReg = loadDim < dim ? reinterpret_cast(query + loadDim)[0] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - int32_t enc = reinterpret_cast(data + lane_id * veclen)[0]; - int32_t q = shfl(queryReg, d / veclen, WarpSize); - compute_dist(dist, q, enc); - } - } -}; - -template -struct loadAndComputeDist { - Lambda compute_dist; - int32_t& dist; - __device__ __forceinline__ loadAndComputeDist(int32_t& dist, Lambda op) - : dist(dist), compute_dist(op) - { - } - - __device__ __forceinline__ void runLoadShmemCompute(const int8_t* const& data, - const int8_t* query_shared, - int loadIndex, - int shmemIndex) - { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - compute_dist(dist, query_shared[shmemIndex + j], data[loadIndex + j * kIndexGroupSize]); - } - } - - __device__ __forceinline__ void runLoadShflAndCompute(const int8_t*& data, - const int8_t* query, - int baseLoadIndex, - const int lane_id) - { - constexpr int veclen = 1; - constexpr int stride = kUnroll * veclen; - int32_t queryReg = query[baseLoadIndex + lane_id]; - -#pragma unroll - for (int i = 0; i < WarpSize / stride; ++i, data += stride * kIndexGroupSize) { -#pragma unroll - for (int j = 0; j < kUnroll; ++j) { - compute_dist( - dist, shfl(queryReg, i * kUnroll + j, WarpSize), data[lane_id + j * kIndexGroupSize]); - } - } - } - __device__ __forceinline__ void runLoadShflAndComputeRemainder( - const int8_t*& data, const int8_t* query, const int lane_id, const int dim, const int dimBlocks) - { - constexpr int veclen = 1; - const int loadDim = dimBlocks + lane_id; - int32_t queryReg = loadDim < dim ? query[loadDim] : 0; - for (int d = 0; d < dim - dimBlocks; d += veclen, data += kIndexGroupSize * veclen) { - compute_dist(dist, shfl(queryReg, d, WarpSize), data[lane_id]); - } - } -}; - -/** - * Scan clusters for nearest neighbors of the query vectors. - * See `ivfflat_interleaved_scan` for more information. - * - * The clusters are stored in the interleaved index format described in ivf_flat_types.hpp. - * For each query vector, a set of clusters is probed: the distance to each vector in the cluster is - * calculated, and the top-k nearest neighbors are selected. - * - * @param compute_dist distance function - * @param query_smem_elems number of dimensions of the query vector to fit in a shared memory of a - * block; this number must be a multiple of `WarpSize * Veclen`. - * @param[in] query a pointer to all queries in a row-major contiguous format [gridDim.y, dim] - * @param[in] coarse_index a pointer to the cluster indices to search through [n_probes] - * @param[in] list_indices index.indices - * @param[in] list_data index.data - * @param[in] list_sizes index.list_sizes - * @param[in] list_offsets index.list_offsets - * @param n_probes - * @param k - * @param dim - * @param[out] neighbors - * @param[out] distances - */ -template -__global__ void __launch_bounds__(kThreadsPerBlock) - interleaved_scan_kernel(Lambda compute_dist, - PostLambda post_process, - const uint32_t query_smem_elems, - const T* query, - const uint32_t* coarse_index, - const IdxT* const* list_indices_ptrs, - const T* const* list_data_ptrs, - const uint32_t* list_sizes, - const uint32_t n_probes, - const uint32_t k, - const uint32_t dim, - IdxT* neighbors, - float* distances) -{ - extern __shared__ __align__(256) uint8_t interleaved_scan_kernel_smem[]; - // Using shared memory for the (part of the) query; - // This allows to save on global memory bandwidth when reading index and query - // data at the same time. - // Its size is `query_smem_elems`. - T* query_shared = reinterpret_cast(interleaved_scan_kernel_smem); - // Make the query input and output point to this block's shared query - { - const int query_id = blockIdx.y; - query += query_id * dim; - neighbors += query_id * k * gridDim.x + blockIdx.x * k; - distances += query_id * k * gridDim.x + blockIdx.x * k; - coarse_index += query_id * n_probes; - } - - // Copy a part of the query into shared memory for faster processing - copy_vectorized(query_shared, query, std::min(dim, query_smem_elems)); - __syncthreads(); - - using block_sort_t = matrix::detail::select::warpsort::block_sort< - matrix::detail::select::warpsort::warp_sort_filtered, - Capacity, - Ascending, - float, - IdxT>; - block_sort_t queue(k); - - { - using align_warp = Pow2; - const int lane_id = align_warp::mod(threadIdx.x); - - // How many full warps needed to compute the distance (without remainder) - const uint32_t full_warps_along_dim = align_warp::roundDown(dim); - - const uint32_t shm_assisted_dim = - (dim > query_smem_elems) ? query_smem_elems : full_warps_along_dim; - - // Every CUDA block scans one cluster at a time. - for (int probe_id = blockIdx.x; probe_id < n_probes; probe_id += gridDim.x) { - const uint32_t list_id = coarse_index[probe_id]; // The id of cluster(list) - - // The number of vectors in each cluster(list); [nlist] - const uint32_t list_length = list_sizes[list_id]; - - // The number of interleaved groups to be processed - const uint32_t num_groups = - align_warp::div(list_length + align_warp::Mask); // ceildiv by power of 2 - - constexpr int kUnroll = WarpSize / Veclen; - constexpr uint32_t kNumWarps = kThreadsPerBlock / WarpSize; - // Every warp reads WarpSize vectors and computes the distances to them. - // Then, the distances and corresponding ids are distributed among the threads, - // and each thread adds one (id, dist) pair to the filtering queue. - for (uint32_t group_id = align_warp::div(threadIdx.x); group_id < num_groups; - group_id += kNumWarps) { - AccT dist = 0; - // This is where this warp begins reading data (start position of an interleaved group) - const T* data = list_data_ptrs[list_id] + (group_id * kIndexGroupSize) * dim; - - // This is the vector a given lane/thread handles - const uint32_t vec_id = group_id * WarpSize + lane_id; - const bool valid = vec_id < list_length; - - // Process first shm_assisted_dim dimensions (always using shared memory) - if (valid) { - loadAndComputeDist lc(dist, - compute_dist); - for (int pos = 0; pos < shm_assisted_dim; - pos += WarpSize, data += kIndexGroupSize * WarpSize) { - lc.runLoadShmemCompute(data, query_shared, lane_id, pos); - } - } - - if (dim > query_smem_elems) { - // The default path - using shfl ops - for dimensions beyond query_smem_elems - loadAndComputeDist lc(dist, - compute_dist); - for (int pos = shm_assisted_dim; pos < full_warps_along_dim; pos += WarpSize) { - lc.runLoadShflAndCompute(data, query, pos, lane_id); - } - lc.runLoadShflAndComputeRemainder(data, query, lane_id, dim, full_warps_along_dim); - } else { - // when shm_assisted_dim == full_warps_along_dim < dim - if (valid) { - loadAndComputeDist<1, decltype(compute_dist), Veclen, T, AccT> lc(dist, compute_dist); - for (int pos = full_warps_along_dim; pos < dim; - pos += Veclen, data += kIndexGroupSize * Veclen) { - lc.runLoadShmemCompute(data, query_shared, lane_id, pos); - } - } - } - - // Enqueue one element per thread - const float val = valid ? static_cast(dist) : block_sort_t::queue_t::kDummy; - const size_t idx = valid ? static_cast(list_indices_ptrs[list_id][vec_id]) : 0; - queue.add(val, idx); - } - } - } - - // finalize and store selected neighbours - __syncthreads(); - queue.done(interleaved_scan_kernel_smem); - queue.store(distances, neighbors, post_process); -} - -/** - * Configure the gridDim.x to maximize GPU occupancy, but reduce the output size - */ -template -uint32_t configure_launch_x(uint32_t numQueries, uint32_t n_probes, int32_t sMemSize, T func) -{ - int dev_id; - RAFT_CUDA_TRY(cudaGetDevice(&dev_id)); - int num_sms; - RAFT_CUDA_TRY(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id)); - int num_blocks_per_sm = 0; - RAFT_CUDA_TRY(cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &num_blocks_per_sm, func, kThreadsPerBlock, sMemSize)); - - size_t min_grid_size = num_sms * num_blocks_per_sm; - size_t min_grid_x = ceildiv(min_grid_size, numQueries); - return min_grid_x > n_probes ? n_probes : static_cast(min_grid_x); -} - -template -void launch_kernel(Lambda lambda, - PostLambda post_process, - const index& index, - const T* queries, - const uint32_t* coarse_index, - const uint32_t num_queries, - const uint32_t n_probes, - const uint32_t k, - IdxT* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) -{ - RAFT_EXPECTS(Veclen == index.veclen(), - "Configured Veclen does not match the index interleaving pattern."); - constexpr auto kKernel = - interleaved_scan_kernel; - const int max_query_smem = 16384; - int query_smem_elems = - std::min(max_query_smem / sizeof(T), Pow2::roundUp(index.dim())); - int smem_size = query_smem_elems * sizeof(T); - constexpr int kSubwarpSize = std::min(Capacity, WarpSize); - auto block_merge_mem = - raft::matrix::detail::select::warpsort::calc_smem_size_for_block_wide( - kThreadsPerBlock / kSubwarpSize, k); - smem_size += std::max(smem_size, block_merge_mem); - - // power-of-two less than cuda limit (for better addr alignment) - constexpr uint32_t kMaxGridY = 32768; - - if (grid_dim_x == 0) { - grid_dim_x = configure_launch_x(std::min(kMaxGridY, num_queries), n_probes, smem_size, kKernel); - return; - } - - for (uint32_t query_offset = 0; query_offset < num_queries; query_offset += kMaxGridY) { - uint32_t grid_dim_y = std::min(kMaxGridY, num_queries - query_offset); - dim3 grid_dim(grid_dim_x, grid_dim_y, 1); - dim3 block_dim(kThreadsPerBlock); - RAFT_LOG_TRACE( - "Launching the ivf-flat interleaved_scan_kernel (%d, %d, 1) x (%d, 1, 1), n_probes = %d, " - "smem_size = %d", - grid_dim.x, - grid_dim.y, - block_dim.x, - n_probes, - smem_size); - kKernel<<>>(lambda, - post_process, - query_smem_elems, - queries, - coarse_index, - index.inds_ptrs().data_handle(), - index.data_ptrs().data_handle(), - index.list_sizes().data_handle(), - n_probes, - k, - index.dim(), - neighbors, - distances); - queries += grid_dim_y * index.dim(); - neighbors += grid_dim_y * grid_dim_x * k; - distances += grid_dim_y * grid_dim_x * k; - } -} - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) - { - const auto diff = x - y; - acc += diff * diff; - } -}; - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(uint32_t& acc, uint32_t x, uint32_t y) - { - if constexpr (Veclen > 1) { - const auto diff = __vabsdiffu4(x, y); - acc = dp4a(diff, diff, acc); - } else { - const auto diff = __usad(x, y, 0u); - acc += diff * diff; - } - } -}; - -template -struct euclidean_dist { - __device__ __forceinline__ void operator()(int32_t& acc, int32_t x, int32_t y) - { - if constexpr (Veclen > 1) { - // Note that we enforce here that the unsigned version of dp4a is used, because the difference - // between two int8 numbers can be greater than 127 and therefore represented as a negative - // number in int8. Casting from int8 to int32 would yield incorrect results, while casting - // from uint8 to uint32 is correct. - const auto diff = __vabsdiffs4(x, y); - acc = dp4a(diff, diff, static_cast(acc)); - } else { - const auto diff = x - y; - acc += diff * diff; - } - } -}; - -template -struct inner_prod_dist { - __device__ __forceinline__ void operator()(AccT& acc, AccT x, AccT y) - { - if constexpr (Veclen > 1 && (std::is_same_v || std::is_same_v)) { - acc = dp4a(x, y, acc); - } else { - acc += x * y; - } - } -}; - -/** Select the distance computation function and forward the rest of the arguments. */ -template -void launch_with_fixed_consts(raft::distance::DistanceType metric, Args&&... args) -{ - switch (metric) { - case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2Unexpanded: - return launch_kernel, - raft::identity_op>({}, {}, std::forward(args)...); - case raft::distance::DistanceType::L2SqrtExpanded: - case raft::distance::DistanceType::L2SqrtUnexpanded: - return launch_kernel, - raft::sqrt_op>({}, {}, std::forward(args)...); - case raft::distance::DistanceType::InnerProduct: - return launch_kernel, - raft::identity_op>({}, {}, std::forward(args)...); - // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. - default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); - } -} - -/** - * Lift the `capacity` and `veclen` parameters to the template level, - * forward the rest of the arguments unmodified to `launch_interleaved_scan_kernel`. - */ -template (1, 16 / sizeof(T))> -struct select_interleaved_scan_kernel { - /** - * Recursively reduce the `Capacity` and `Veclen` parameters until they match the - * corresponding runtime arguments. - * By default, this recursive process starts with maximum possible values of the - * two parameters and ends with both values equal to 1. - */ - template - static inline void run(int capacity, int veclen, bool select_min, Args&&... args) - { - if constexpr (Capacity > 1) { - if (capacity * 2 <= Capacity) { - return select_interleaved_scan_kernel::run( - capacity, veclen, select_min, std::forward(args)...); - } - } - if constexpr (Veclen > 1) { - if (veclen * 2 <= Veclen) { - return select_interleaved_scan_kernel::run( - capacity, veclen, select_min, std::forward(args)...); - } - } - // NB: this is the limitation of the warpsort structures that use a huge number of - // registers (used in the main kernel here). - RAFT_EXPECTS(capacity == Capacity, - "Capacity must be power-of-two not bigger than the maximum allowed size " - "matrix::detail::select::warpsort::kMaxCapacity (%d).", - matrix::detail::select::warpsort::kMaxCapacity); - RAFT_EXPECTS( - veclen == Veclen, - "Veclen must be power-of-two not bigger than the maximum allowed size for this data type."); - if (select_min) { - launch_with_fixed_consts(std::forward(args)...); - } else { - launch_with_fixed_consts(std::forward(args)...); - } - } -}; - -/** - * @brief Configure and launch an appropriate template instance of the interleaved scan kernel. - * - * @tparam T value type - * @tparam AccT accumulated type - * @tparam IdxT type of the indices - * - * @param index previously built ivf-flat index - * @param[in] queries device pointer to the query vectors [batch_size, dim] - * @param[in] coarse_query_results device pointer to the cluster (list) ids [batch_size, n_probes] - * @param n_queries batch size - * @param metric type of the measured distance - * @param n_probes number of nearest clusters to query - * @param k number of nearest neighbors. - * NB: the maximum value of `k` is limited statically by `kMaxCapacity`. - * @param select_min whether to select nearest (true) or furthest (false) points w.r.t. the given - * metric. - * @param[out] neighbors device pointer to the result indices for each query and cluster - * [batch_size, grid_dim_x, k] - * @param[out] distances device pointer to the result distances for each query and cluster - * [batch_size, grid_dim_x, k] - * @param[inout] grid_dim_x number of blocks launched across all n_probes clusters; - * (one block processes one or more probes, hence: 1 <= grid_dim_x <= n_probes) - * @param stream - */ -template -void ivfflat_interleaved_scan(const index& index, - const T* queries, - const uint32_t* coarse_query_results, - const uint32_t n_queries, - const raft::distance::DistanceType metric, - const uint32_t n_probes, - const uint32_t k, - const bool select_min, - IdxT* neighbors, - float* distances, - uint32_t& grid_dim_x, - rmm::cuda_stream_view stream) -{ - const int capacity = bound_by_power_of_two(k); - select_interleaved_scan_kernel::run(capacity, - index.veclen(), - select_min, - metric, - index, - queries, - coarse_query_results, - n_queries, - n_probes, - k, - neighbors, - distances, - grid_dim_x, - stream); -} - template void search_impl(raft::device_resources const& handle, - const index& index, + const raft::neighbors::ivf_flat::index& index, const T* queries, uint32_t n_queries, uint32_t k, diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh new file mode 100644 index 0000000000..0d5ca90297 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh @@ -0,0 +1,183 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // __half +#include // raft::distance::DistanceType +#include // raft::neighbors::ivf_pq::detail::fp_8bit +#include // raft::neighbors::ivf_pq::codebook_gen +#include // RAFT_EXPLICIT +#include // rmm::cuda_stream_view + +#ifdef RAFT_EXPLICIT_INSTANTIATE_ONLY + +namespace raft::neighbors::ivf_pq::detail { + +// is_local_topk_feasible is not inline here, because we would have to define it +// here as well. That would run the risk of the definitions here and in the +// -inl.cuh header diverging. +auto is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) -> bool; + +template +__global__ void compute_similarity_kernel(uint32_t n_rows, + uint32_t dim, + uint32_t n_probes, + uint32_t pq_dim, + uint32_t n_queries, + distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t topk, + uint32_t max_samples, + const float* cluster_centers, + const float* pq_centers, + const uint8_t* const* pq_dataset, + const uint32_t* cluster_labels, + const uint32_t* _chunk_indices, + const float* queries, + const uint32_t* index_list, + float* query_kths, + LutT* lut_scores, + OutT* _out_scores, + uint32_t* _out_indices) RAFT_EXPLICIT; + +// The signature of the kernel defined by a minimal set of template parameters +template +using compute_similarity_kernel_t = + decltype(&compute_similarity_kernel); + +template +struct selected { + compute_similarity_kernel_t kernel; + dim3 grid_dim; + dim3 block_dim; + size_t smem_size; + size_t device_lut_size; +}; + +template +void compute_similarity_run(selected s, + rmm::cuda_stream_view stream, + uint32_t n_rows, + uint32_t dim, + uint32_t n_probes, + uint32_t pq_dim, + uint32_t n_queries, + distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t topk, + uint32_t max_samples, + const float* cluster_centers, + const float* pq_centers, + const uint8_t* const* pq_dataset, + const uint32_t* cluster_labels, + const uint32_t* _chunk_indices, + const float* queries, + const uint32_t* index_list, + float* query_kths, + LutT* lut_scores, + OutT* _out_scores, + uint32_t* _out_indices) RAFT_EXPLICIT; + +/** + * Use heuristics to choose an optimal instance of the search kernel. + * It selects among a few kernel variants (with/out using shared mem for + * lookup tables / precomputed distances) and tries to choose the block size + * to maximize kernel occupancy. + * + * @param manage_local_topk + * whether use the fused calculate+select or just calculate the distances for each + * query and probed cluster. + * + * @param locality_hint + * beyond this limit do not consider increasing the number of active blocks per SM + * would improve locality anymore. + */ +template +auto compute_similarity_select(const cudaDeviceProp& dev_props, + bool manage_local_topk, + int locality_hint, + double preferred_shmem_carveout, + uint32_t pq_bits, + uint32_t pq_dim, + uint32_t precomp_data_count, + uint32_t n_queries, + uint32_t n_probes, + uint32_t topk) -> selected RAFT_EXPLICIT; + +} // namespace raft::neighbors::ivf_pq::detail + +#endif // RAFT_EXPLICIT_INSTANTIATE_ONLY + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + extern template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + extern template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(half, half); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, half); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, float); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh new file mode 100644 index 0000000000..7573e2ca13 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-inl.cuh @@ -0,0 +1,845 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // raft::distance::DistanceType +#include // matrix::detail::select::warpsort::warp_sort_distributed +#include // dummy_block_sort_t +#include // codebook_gen +#include // RAFT_CUDA_TRY +#include // raft::atomicMin +#include // raft::Pow2 +#include // raft::TxN_t +#include // rmm::cuda_stream_view + +namespace raft::neighbors::ivf_pq::detail { + +/** + * Maximum value of k for the fused calculate & select in ivfpq. + * + * If runtime value of k is larger than this, the main search operation + * is split into two kernels (per batch, first calculate distance, then select top-k). + */ +static constexpr int kMaxCapacity = 128; +static_assert((kMaxCapacity >= 32) && !(kMaxCapacity & (kMaxCapacity - 1)), + "kMaxCapacity must be a power of two, not smaller than the WarpSize."); + +// using weak attribute here, because it may be compiled multiple times. +auto __attribute__((weak)) is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) + -> bool +{ + if (k > kMaxCapacity) { return false; } // warp_sort not possible + if (n_probes <= 16) { return false; } // too few clusters + if (n_queries * n_probes <= 256) { return false; } // overall amount of work is too small + return true; +} + +template +struct pq_block_sort { + using type = matrix::detail::select::warpsort:: + block_sort; +}; + +template +struct pq_block_sort<0, T, IdxT> : dummy_block_sort_t { + using type = dummy_block_sort_t; +}; + +template +using block_sort_t = typename pq_block_sort::type; + +/** + * Estimate a carveout value as expected by `cudaFuncAttributePreferredSharedMemoryCarveout` + * (which does not take into account `reservedSharedMemPerBlock`), + * given by a desired schmem-L1 split and a per-block memory requirement in bytes. + * + * NB: As per the programming guide, the memory carveout setting is just a hint for the driver; it's + * free to choose any shmem-L1 configuration it deems appropriate. For example, if you set the + * carveout to zero, it will choose a non-zero config that will allow to run at least one active + * block per SM. + * + * @param shmem_fraction + * a fraction representing a desired split (shmem / (shmem + L1)) [0, 1]. + * @param shmem_per_block + * a shared memory usage per block (dynamic + static shared memory sizes), in bytes. + * @param dev_props + * device properties. + * @return + * a carveout value in percents [0, 100]. + */ +constexpr inline auto estimate_carveout(double shmem_fraction, + size_t shmem_per_block, + const cudaDeviceProp& dev_props) -> int +{ + using shmem_unit = Pow2<128>; + size_t m = shmem_unit::roundUp(shmem_per_block); + size_t r = dev_props.reservedSharedMemPerBlock; + size_t s = dev_props.sharedMemPerMultiprocessor; + return (size_t(100 * s * m * shmem_fraction) - (m - 1) * r) / (s * (m + r)); +} + +/* Manually unrolled loop over a chunk of pq_dataset that fits into one VecT. */ +template +__device__ __forceinline__ void ivfpq_compute_chunk(OutT& score /* NOLINT */, + typename VecT::math_t& pq_code, + const VecT& pq_codes, + const LutT*& lut_head, + const LutT*& lut_end) +{ + if constexpr (CheckBounds) { + if (lut_head >= lut_end) { return; } + } + constexpr uint32_t kTotalBits = 8 * sizeof(typename VecT::math_t); + constexpr uint32_t kPqShift = 1u << PqBits; + constexpr uint32_t kPqMask = kPqShift - 1u; + if constexpr (BitsLeft >= PqBits) { + uint8_t code = pq_code & kPqMask; + pq_code >>= PqBits; + score += OutT(lut_head[code]); + lut_head += kPqShift; + return ivfpq_compute_chunk( + score, pq_code, pq_codes, lut_head, lut_end); + } else if constexpr (Ix < VecT::Ratio) { + uint8_t code = pq_code; + pq_code = pq_codes.val.data[Ix]; + constexpr uint32_t kRemBits = PqBits - BitsLeft; + constexpr uint32_t kRemMask = (1u << kRemBits) - 1u; + code |= (pq_code & kRemMask) << BitsLeft; + pq_code >>= kRemBits; + score += OutT(lut_head[code]); + lut_head += kPqShift; + return ivfpq_compute_chunk(score, pq_code, pq_codes, lut_head, lut_end); + } +} + +/* Compute the similarity for one vector in the pq_dataset */ +template +__device__ auto ivfpq_compute_score(uint32_t pq_dim, + const typename VecT::io_t* pq_head, + const LutT* lut_scores, + OutT early_stop_limit) -> OutT +{ + constexpr uint32_t kChunkSize = sizeof(VecT) * 8u / PqBits; + auto lut_head = lut_scores; + auto lut_end = lut_scores + (pq_dim << PqBits); + VecT pq_codes; + OutT score{0}; + for (; pq_dim >= kChunkSize; pq_dim -= kChunkSize) { + *pq_codes.vectorized_data() = *pq_head; + pq_head += kIndexGroupSize; + typename VecT::math_t pq_code = 0; + ivfpq_compute_chunk( + score, pq_code, pq_codes, lut_head, lut_end); + // Early stop when it makes sense (otherwise early_stop_limit is kDummy/infinity). + if (score >= early_stop_limit) { return score; } + } + if (pq_dim > 0) { + *pq_codes.vectorized_data() = *pq_head; + typename VecT::math_t pq_code = 0; + ivfpq_compute_chunk( + score, pq_code, pq_codes, lut_head, lut_end); + } + return score; +} + +/** + * The main kernel that computes similarity scores across multiple queries and probes. + * When `Capacity > 0`, it also selects top K candidates for each query and probe + * (which need to be merged across probes afterwards). + * + * Each block processes a (query, probe) pair: it calculates the distance between the single query + * vector and all the dataset vector in the cluster that we are probing. + * + * @tparam OutT + * The output type - distances. + * @tparam LutT + * The lookup table element type (lut_scores). + * @tparam PqBits + * The bit length of an encoded vector element after compression by PQ + * (NB: pq_book_size = 1 << PqBits). + * @tparam Capacity + * Power-of-two; the maximum possible `k` in top-k. Value zero disables fused top-k search. + * @tparam PrecompBaseDiff + * Defines whether we should precompute part of the distance and keep it in shared memory + * before the main part (score calculation) to increase memory usage efficiency in the latter. + * For L2, this is the distance between the query and the cluster center. + * @tparam EnableSMemLut + * Defines whether to use the shared memory for the lookup table (`lut_scores`). + * Setting this to `false` allows to reduce the shared memory usage (and maximum data dim) + * at the cost of reducing global memory reading throughput. + * + * @param n_rows the number of records in the dataset + * @param dim the dimensionality of the data (NB: after rotation transform, i.e. `index.rot_dim()`). + * @param n_probes the number of clusters to search for each query + * @param pq_dim + * The dimensionality of an encoded vector after compression by PQ. + * @param n_queries the number of queries. + * @param metric the distance type. + * @param codebook_kind Defines the way PQ codebooks have been trained. + * @param topk the `k` in the select top-k. + * @param max_samples the size of the output for a single query. + * @param cluster_centers + * The device pointer to the cluster centers in the original space (NB: after rotation) + * [n_clusters, dim]. + * @param pq_centers + * The device pointer to the cluster centers in the PQ space + * [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len,]. + * @param pq_dataset + * The device pointer to the PQ index (data) [n_rows, ...]. + * @param cluster_labels + * The device pointer to the labels (clusters) for each query and probe [n_queries, n_probes]. + * @param _chunk_indices + * The device pointer to the data offsets for each query and probe [n_queries, n_probes]. + * @param queries + * The device pointer to the queries (NB: after rotation) [n_queries, dim]. + * @param index_list + * An optional device pointer to the enforced order of search [n_queries, n_probes]. + * One can pass reordered indices here to try to improve data reading locality. + * @param lut_scores + * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << PqBits]. + * Ignored when `EnableSMemLut == true`. + * @param _out_scores + * The device pointer to the output scores + * [n_queries, max_samples] or [n_queries, n_probes, topk]. + * @param _out_indices + * The device pointer to the output indices [n_queries, n_probes, topk]. + * These are the indices of the records as they appear in the database view formed by the probed + * clusters / defined by the `_chunk_indices`. + * The indices can have values within the range [0, max_samples). + * Ignored when `Capacity == 0`. + */ +template +__global__ void compute_similarity_kernel(uint32_t n_rows, + uint32_t dim, + uint32_t n_probes, + uint32_t pq_dim, + uint32_t n_queries, + distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t topk, + uint32_t max_samples, + const float* cluster_centers, + const float* pq_centers, + const uint8_t* const* pq_dataset, + const uint32_t* cluster_labels, + const uint32_t* _chunk_indices, + const float* queries, + const uint32_t* index_list, + float* query_kths, + LutT* lut_scores, + OutT* _out_scores, + uint32_t* _out_indices) +{ + /* Shared memory: + + * lut_scores: lookup table (LUT) of size = `pq_dim << PqBits` (when EnableSMemLut) + * base_diff: size = dim (which is equal to `pq_dim * pq_len`) or dim*2 + * topk::block_sort: some amount of shared memory, but overlaps with the rest: + block_sort only needs shared memory for `.done()` operation, which can come very last. + */ + extern __shared__ __align__(256) uint8_t smem_buf[]; // NOLINT + constexpr bool kManageLocalTopK = Capacity > 0; + + constexpr uint32_t PqShift = 1u << PqBits; // NOLINT + constexpr uint32_t PqMask = PqShift - 1u; // NOLINT + + const uint32_t pq_len = dim / pq_dim; + const uint32_t lut_size = pq_dim * PqShift; + + if constexpr (EnableSMemLut) { + lut_scores = reinterpret_cast(smem_buf); + } else { + lut_scores += lut_size * blockIdx.x; + } + + float* base_diff = nullptr; + if constexpr (PrecompBaseDiff) { + if constexpr (EnableSMemLut) { + base_diff = reinterpret_cast(lut_scores + lut_size); + } else { + base_diff = reinterpret_cast(smem_buf); + } + } + + for (int ib = blockIdx.x; ib < n_queries * n_probes; ib += gridDim.x) { + if (ib >= gridDim.x) { + // sync shared memory accesses on the second and further iterations + __syncthreads(); + } + uint32_t query_ix; + uint32_t probe_ix; + if (index_list == nullptr) { + query_ix = ib % n_queries; + probe_ix = ib / n_queries; + } else { + auto ordered_ix = index_list[ib]; + query_ix = ordered_ix / n_probes; + probe_ix = ordered_ix % n_probes; + } + + const uint32_t* chunk_indices = _chunk_indices + (n_probes * query_ix); + const float* query = queries + (dim * query_ix); + OutT* out_scores; + uint32_t* out_indices = nullptr; + if constexpr (kManageLocalTopK) { + // Store topk calculated distances to out_scores (and its indices to out_indices) + out_scores = _out_scores + topk * (probe_ix + (n_probes * query_ix)); + out_indices = _out_indices + topk * (probe_ix + (n_probes * query_ix)); + } else { + // Store all calculated distances to out_scores + out_scores = _out_scores + max_samples * query_ix; + } + uint32_t label = cluster_labels[n_probes * query_ix + probe_ix]; + const float* cluster_center = cluster_centers + (dim * label); + const float* pq_center; + if (codebook_kind == codebook_gen::PER_SUBSPACE) { + pq_center = pq_centers; + } else { + pq_center = pq_centers + (pq_len << PqBits) * label; + } + + if constexpr (PrecompBaseDiff) { + // Reduce number of memory reads later by pre-computing parts of the score + switch (metric) { + case distance::DistanceType::L2SqrtExpanded: + case distance::DistanceType::L2Expanded: { + for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { + base_diff[i] = query[i] - cluster_center[i]; + } + } break; + case distance::DistanceType::InnerProduct: { + float2 pvals; + for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { + pvals.x = query[i]; + pvals.y = cluster_center[i] * pvals.x; + reinterpret_cast(base_diff)[i] = pvals; + } + } break; + default: __builtin_unreachable(); + } + __syncthreads(); + } + + { + // Create a lookup table + // For each subspace, the lookup table stores the distance between the actual query vector + // (projected into the subspace) and all possible pq vectors in that subspace. + for (uint32_t i = threadIdx.x; i < lut_size; i += blockDim.x) { + const uint32_t i_pq = i >> PqBits; + uint32_t j = i_pq * pq_len; + const uint32_t j_end = pq_len + j; + auto cur_pq_center = pq_center + (i & PqMask) + + (codebook_kind == codebook_gen::PER_SUBSPACE ? j * PqShift : 0u); + float score = 0.0; + do { + float pq_c = *cur_pq_center; + cur_pq_center += PqShift; + switch (metric) { + case distance::DistanceType::L2SqrtExpanded: + case distance::DistanceType::L2Expanded: { + float diff; + if constexpr (PrecompBaseDiff) { + diff = base_diff[j]; + } else { + diff = query[j] - cluster_center[j]; + } + diff -= pq_c; + score += diff * diff; + } break; + case distance::DistanceType::InnerProduct: { + // NB: we negate the scores as we hardcoded select-topk to always compute the minimum + float q; + if constexpr (PrecompBaseDiff) { + float2 pvals = reinterpret_cast(base_diff)[j]; + q = pvals.x; + score -= pvals.y; + } else { + q = query[j]; + score -= q * cluster_center[j]; + } + score -= q * pq_c; + } break; + default: __builtin_unreachable(); + } + } while (++j < j_end); + lut_scores[i] = LutT(score); + } + } + + // Define helper types for efficient access to the pq_dataset, which is stored in an interleaved + // format. The chunks of PQ data are stored in kIndexGroupVecLen-bytes-long chunks, interleaved + // in groups of kIndexGroupSize elems (which is normally equal to the warp size) for the fastest + // possible access by thread warps. + // + // Consider one record in the pq_dataset is `pq_dim * pq_bits`-bit-long. + // Assuming `kIndexGroupVecLen = 16`, one chunk of data read by a thread at once is 128-bits. + // Then, such a chunk contains `chunk_size = 128 / pq_bits` record elements, and the record + // consists of `ceildiv(pq_dim, chunk_size)` chunks. The chunks are interleaved in groups of 32, + // so that the warp can achieve the best coalesced read throughput. + using group_align = Pow2; + using vec_align = Pow2; + using local_topk_t = block_sort_t; + using op_t = uint32_t; + using vec_t = TxN_t; + + uint32_t sample_offset = 0; + if (probe_ix > 0) { sample_offset = chunk_indices[probe_ix - 1]; } + uint32_t n_samples = chunk_indices[probe_ix] - sample_offset; + uint32_t n_samples_aligned = group_align::roundUp(n_samples); + constexpr uint32_t kChunkSize = (kIndexGroupVecLen * 8u) / PqBits; + uint32_t pq_line_width = div_rounding_up_unsafe(pq_dim, kChunkSize) * kIndexGroupVecLen; + auto pq_thread_data = pq_dataset[label] + group_align::roundDown(threadIdx.x) * pq_line_width + + group_align::mod(threadIdx.x) * vec_align::Value; + pq_line_width *= blockDim.x; + + constexpr OutT kDummy = upper_bound(); + OutT query_kth = kDummy; + if constexpr (kManageLocalTopK) { query_kth = OutT(query_kths[query_ix]); } + local_topk_t block_topk(topk, nullptr, query_kth); + OutT early_stop_limit = kDummy; + switch (metric) { + // If the metric is non-negative, we can use the query_kth approximation as an early stop + // threshold to skip some iterations when computing the score. Add such metrics here. + case distance::DistanceType::L2SqrtExpanded: + case distance::DistanceType::L2Expanded: { + early_stop_limit = query_kth; + } break; + default: break; + } + + // Ensure lut_scores is written by all threads before using it in ivfpq-compute-score + __threadfence_block(); + __syncthreads(); + + // Compute a distance for each sample + for (uint32_t i = threadIdx.x; i < n_samples_aligned; + i += blockDim.x, pq_thread_data += pq_line_width) { + OutT score = kDummy; + bool valid = i < n_samples; + if (valid) { + score = ivfpq_compute_score( + pq_dim, + reinterpret_cast(pq_thread_data), + lut_scores, + early_stop_limit); + } + if constexpr (kManageLocalTopK) { + block_topk.add(score, sample_offset + i); + } else { + if (valid) { out_scores[sample_offset + i] = score; } + } + } + if constexpr (kManageLocalTopK) { + // sync threads before the topk merging operation, because we reuse smem_buf + __syncthreads(); + block_topk.done(smem_buf); + block_topk.store(out_scores, out_indices); + if (threadIdx.x == 0) { atomicMin(query_kths + query_ix, float(out_scores[topk - 1])); } + } else { + // fill in the rest of the out_scores with dummy values + if (probe_ix + 1 == n_probes) { + for (uint32_t i = threadIdx.x + sample_offset + n_samples; i < max_samples; + i += blockDim.x) { + out_scores[i] = kDummy; + } + } + } + } +} + +// The signature of the kernel defined by a minimal set of template parameters +template +using compute_similarity_kernel_t = + decltype(&compute_similarity_kernel); + +// The config struct lifts the runtime parameters to the template parameters +template +struct compute_similarity_kernel_config { + public: + static auto get(uint32_t pq_bits, uint32_t k_max) -> compute_similarity_kernel_t + { + return kernel_choose_bits(pq_bits, k_max); + } + + private: + static auto kernel_choose_bits(uint32_t pq_bits, uint32_t k_max) + -> compute_similarity_kernel_t + { + switch (pq_bits) { + case 4: return kernel_try_capacity<4, kMaxCapacity>(k_max); + case 5: return kernel_try_capacity<5, kMaxCapacity>(k_max); + case 6: return kernel_try_capacity<6, kMaxCapacity>(k_max); + case 7: return kernel_try_capacity<7, kMaxCapacity>(k_max); + case 8: return kernel_try_capacity<8, kMaxCapacity>(k_max); + default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); + } + } + + template + static auto kernel_try_capacity(uint32_t k_max) -> compute_similarity_kernel_t + { + if constexpr (Capacity > 0) { + if (k_max == 0 || k_max > Capacity) { return kernel_try_capacity(k_max); } + } + if constexpr (Capacity > 1) { + if (k_max * 2 <= Capacity) { return kernel_try_capacity(k_max); } + } + return compute_similarity_kernel; + } +}; + +// A standalone accessor function was necessary to make sure template +// instantiation work correctly. This accessor function is not used anymore and +// may be removed. +template +auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) + -> compute_similarity_kernel_t +{ + return compute_similarity_kernel_config::get(pq_bits, + k_max); +} + +/** Estimate the occupancy for the given kernel on the given device. */ +template +struct occupancy_t { + using shmem_unit = Pow2<128>; + + int blocks_per_sm = 0; + double occupancy = 0.0; + double shmem_use = 1.0; + + inline occupancy_t() = default; + inline occupancy_t(size_t smem, + uint32_t n_threads, + compute_similarity_kernel_t kernel, + const cudaDeviceProp& dev_props) + { + RAFT_CUDA_TRY( + cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, n_threads, smem)); + occupancy = double(blocks_per_sm * n_threads) / double(dev_props.maxThreadsPerMultiProcessor); + shmem_use = double(shmem_unit::roundUp(smem) * blocks_per_sm) / + double(dev_props.sharedMemPerMultiprocessor); + } +}; + +template +struct selected { + compute_similarity_kernel_t kernel; + dim3 grid_dim; + dim3 block_dim; + size_t smem_size; + size_t device_lut_size; +}; + +template +void compute_similarity_run(selected s, + rmm::cuda_stream_view stream, + uint32_t n_rows, + uint32_t dim, + uint32_t n_probes, + uint32_t pq_dim, + uint32_t n_queries, + distance::DistanceType metric, + codebook_gen codebook_kind, + uint32_t topk, + uint32_t max_samples, + const float* cluster_centers, + const float* pq_centers, + const uint8_t* const* pq_dataset, + const uint32_t* cluster_labels, + const uint32_t* _chunk_indices, + const float* queries, + const uint32_t* index_list, + float* query_kths, + LutT* lut_scores, + OutT* _out_scores, + uint32_t* _out_indices) +{ + s.kernel<<>>(n_rows, + dim, + n_probes, + pq_dim, + n_queries, + metric, + codebook_kind, + topk, + max_samples, + cluster_centers, + pq_centers, + pq_dataset, + cluster_labels, + _chunk_indices, + queries, + index_list, + query_kths, + lut_scores, + _out_scores, + _out_indices); + RAFT_CHECK_CUDA(stream); +} + +/** + * Use heuristics to choose an optimal instance of the search kernel. + * It selects among a few kernel variants (with/out using shared mem for + * lookup tables / precomputed distances) and tries to choose the block size + * to maximize kernel occupancy. + * + * @param manage_local_topk + * whether use the fused calculate+select or just calculate the distances for each + * query and probed cluster. + * + * @param locality_hint + * beyond this limit do not consider increasing the number of active blocks per SM + * would improve locality anymore. + */ +template +auto compute_similarity_select(const cudaDeviceProp& dev_props, + bool manage_local_topk, + int locality_hint, + double preferred_shmem_carveout, + uint32_t pq_bits, + uint32_t pq_dim, + uint32_t precomp_data_count, + uint32_t n_queries, + uint32_t n_probes, + uint32_t topk) -> selected +{ + // Shared memory for storing the lookup table + size_t lut_mem = sizeof(LutT) * (pq_dim << pq_bits); + // Shared memory for storing pre-computed pieces to speedup the lookup table construction + // (e.g. the distance between a cluster center and the query for L2). + size_t bdf_mem = sizeof(float) * precomp_data_count; + // Shared memory for the fused top-k component; it may overlap with the other uses of shared + // memory and depends on the number of threads. + struct ltk_mem_t { + uint32_t subwarp_size; + uint32_t topk; + bool manage_local_topk; + ltk_mem_t(bool manage_local_topk, uint32_t topk) + : manage_local_topk(manage_local_topk), topk(topk) + { + subwarp_size = WarpSize; + while (topk * 2 <= subwarp_size) { + subwarp_size /= 2; + } + } + + [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t + { + return manage_local_topk + ? matrix::detail::select::warpsort::template calc_smem_size_for_block_wide( + n_threads / subwarp_size, topk) + : 0; + } + } ltk_mem{manage_local_topk, topk}; + + // Total amount of work; should be enough to occupy the GPU. + uint32_t n_blocks = n_queries * n_probes; + + // The minimum block size we may want: + // 1. It's a power-of-two for efficient L1 caching of pq_centers values + // (multiples of `1 << pq_bits`). + // 2. It should be large enough to fully utilize an SM. + uint32_t n_threads_min = WarpSize; + while (dev_props.maxBlocksPerMultiProcessor * int(n_threads_min) < + dev_props.maxThreadsPerMultiProcessor) { + n_threads_min *= 2; + } + // Further increase the minimum block size to make sure full device occupancy + // (NB: this may lead to `n_threads_min` being larger than the kernel's maximum) + while (int(n_blocks * n_threads_min) < + dev_props.multiProcessorCount * dev_props.maxThreadsPerMultiProcessor && + int(n_threads_min) < dev_props.maxThreadsPerBlock) { + n_threads_min *= 2; + } + // Even further, increase it to allow less blocks per SM if there not enough queries. + // With this, we reduce the chance of different clusters being processed by two blocks + // on the same SM and thus improve the data locality for L1 caching. + while (int(n_queries * n_threads_min) < dev_props.maxThreadsPerMultiProcessor && + int(n_threads_min) < dev_props.maxThreadsPerBlock) { + n_threads_min *= 2; + } + + // Granularity of changing the number of threads when computing the maximum block size. + // It's good to have it multiple of the PQ book width. + uint32_t n_threads_gty = round_up_safe(1u << pq_bits, WarpSize); + + /* + Shared memory / L1 cache balance is the main limiter of this kernel. + The more blocks per SM we launch, the more shared memory we need. Besides that, we have + three versions of the kernel varying in performance and shmem usage. + + We try the most demanding and the fastest kernel first, trying to maximize occupancy with + the minimum number of blocks (just one, really). Then, we tweak the `n_threads` to further + optimize occupancy and data locality for the L1 cache. + */ + auto conf_fast = get_compute_similarity_kernel; + auto conf_no_basediff = get_compute_similarity_kernel; + auto conf_no_smem_lut = get_compute_similarity_kernel; + auto topk_or_zero = manage_local_topk ? topk : 0u; + std::array candidates{std::make_tuple(conf_fast(pq_bits, topk_or_zero), lut_mem + bdf_mem, true), + std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), lut_mem, true), + std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero), bdf_mem, false)}; + + // we may allow slightly lower than 100% occupancy; + constexpr double kTargetOccupancy = 0.75; + // This struct is used to select the better candidate + occupancy_t selected_perf{}; + selected selected_config; + for (auto [kernel, smem_size_const, lut_is_in_shmem] : candidates) { + if (smem_size_const > dev_props.sharedMemPerBlockOptin) { + // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate. + continue; + } + + // First, we set the carveout hint to the preferred value. The driver will increase this if + // needed to run at least one block per SM. At the same time, if more blocks fit into one SM, + // this carveout value will limit the calculated occupancy. When we're done selecting the best + // launch configuration, we will tighten the carveout once more, based on the final memory + // usage and occupancy. + const int max_carveout = + estimate_carveout(preferred_shmem_carveout, smem_size_const, dev_props); + RAFT_CUDA_TRY( + cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, max_carveout)); + + // Get the theoretical maximum possible number of threads per block + cudaFuncAttributes kernel_attrs; + RAFT_CUDA_TRY(cudaFuncGetAttributes(&kernel_attrs, kernel)); + uint32_t n_threads = round_down_safe(kernel_attrs.maxThreadsPerBlock, n_threads_gty); + + // Actual required shmem depens on the number of threads + size_t smem_size = max(smem_size_const, ltk_mem(n_threads)); + + // Make sure the kernel can get enough shmem. + cudaError_t cuda_status = + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (cuda_status != cudaSuccess) { + RAFT_EXPECTS( + cuda_status == cudaGetLastError(), + "Tried to reset the expected cuda error code, but it didn't match the expectation"); + // Failed to request enough shmem for the kernel. Skip the candidate. + continue; + } + + occupancy_t cur(smem_size, n_threads, kernel, dev_props); + if (cur.blocks_per_sm <= 0) { + // For some reason, we still cannot make this kernel run. Skip the candidate. + continue; + } + + { + // Try to reduce the number of threads to increase occupancy and data locality + auto n_threads_tmp = n_threads_min; + while (n_threads_tmp * 2 < n_threads) { + n_threads_tmp *= 2; + } + if (n_threads_tmp < n_threads) { + while (n_threads_tmp >= n_threads_min) { + auto smem_size_tmp = max(smem_size_const, ltk_mem(n_threads_tmp)); + occupancy_t tmp(smem_size_tmp, n_threads_tmp, kernel, dev_props); + bool select_it = false; + if (lut_is_in_shmem && locality_hint >= tmp.blocks_per_sm) { + // Normally, the smaller the block the better for L1 cache hit rate. + // Hence, the occupancy should be "just good enough" + select_it = tmp.occupancy >= min(kTargetOccupancy, cur.occupancy); + } else if (lut_is_in_shmem) { + // If we don't have enough repeating probes (locality_hint < tmp.blocks_per_sm), + // the locality is not going to improve with increasing the number of blocks per SM. + // Hence, the only metric here is the occupancy. + bool improves_occupancy = tmp.occupancy > cur.occupancy; + // Otherwise, the performance still improves with a smaller block size, + // given there is enough work to do + bool improves_parallelism = + tmp.occupancy == cur.occupancy && + 7u * tmp.blocks_per_sm * dev_props.multiProcessorCount <= n_blocks; + select_it = improves_occupancy || improves_parallelism; + } else { + // If we don't use shared memory for the lookup table, increasing the number of blocks + // is very taxing on the global memory usage. + // In this case, the occupancy must increase a lot to make it worth the cost. + select_it = tmp.occupancy >= min(1.0, cur.occupancy / kTargetOccupancy); + } + if (select_it) { + n_threads = n_threads_tmp; + smem_size = smem_size_tmp; + cur = tmp; + } + n_threads_tmp /= 2; + } + } + } + + { + if (selected_perf.occupancy <= 0.0 // no candidate yet + || (selected_perf.occupancy < cur.occupancy * kTargetOccupancy && + selected_perf.shmem_use >= cur.shmem_use) // much improved occupancy + ) { + selected_perf = cur; + if (lut_is_in_shmem) { + selected_config = { + kernel, dim3(n_blocks, 1, 1), dim3(n_threads, 1, 1), smem_size, size_t(0)}; + } else { + // When the global memory is used for the lookup table, we need to minimize the grid + // size; otherwise, the kernel may quickly run out of memory. + auto n_blocks_min = + std::min(n_blocks, cur.blocks_per_sm * dev_props.multiProcessorCount); + selected_config = {kernel, + dim3(n_blocks_min, 1, 1), + dim3(n_threads, 1, 1), + smem_size, + size_t(n_blocks_min) * size_t(pq_dim << pq_bits)}; + } + // Actual shmem/L1 split wildly rounds up the specified preferred carveout, so we set here + // a rather conservative bar; most likely, the kernel gets more shared memory than this, + // and the occupancy doesn't get hurt. + auto carveout = std::min(max_carveout, std::ceil(100.0 * cur.shmem_use)); + RAFT_CUDA_TRY( + cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, carveout)); + if (cur.occupancy >= kTargetOccupancy) { break; } + } else if (selected_perf.occupancy > 0.0) { + // If we found a reasonable candidate on a previous iteration, and this one is not better, + // then don't try any more candidates because they are much slower anyway. + break; + } + } + } + + RAFT_EXPECTS(selected_perf.occupancy > 0.0, + "Couldn't determine a working kernel launch configuration."); + + return selected_config; +} + +} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh new file mode 100644 index 0000000000..d987c0d4ed --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity.cuh @@ -0,0 +1,25 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#if !defined(RAFT_EXPLICIT_INSTANTIATE_ONLY) +#include "ivf_pq_compute_similarity-inl.cuh" +#endif + +#ifdef RAFT_COMPILED +#include "ivf_pq_compute_similarity-ext.cuh" +#endif diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_dummy_block_sort.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_dummy_block_sort.cuh new file mode 100644 index 0000000000..a00b6a50ff --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_pq_dummy_block_sort.cuh @@ -0,0 +1,39 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include // matrix::detail::select::warpsort::warp_sort_distributed + +/* + * This header file is a bit of an ugly duckling. The type dummy_block_sort is + * needed by both ivf_pq_search.cuh and ivf_pq_compute_similarity.cuh. + * + * I have decided to move it to it's own header file, which is overkill. Perhaps + * there is a nicer solution. + * + */ + +namespace raft::neighbors::ivf_pq::detail { + +template +struct dummy_block_sort_t { + using queue_t = matrix::detail::select::warpsort::warp_sort_distributed; + template + __device__ dummy_block_sort_t(int k, Args...){}; +}; + +} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh new file mode 100644 index 0000000000..87f9bfb622 --- /dev/null +++ b/cpp/include/raft/neighbors/detail/ivf_pq_fp_8bit.cuh @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include + +#include + +namespace raft::neighbors::ivf_pq::detail { + +/** 8-bit floating-point storage type. + * + * This is a custom type for the current IVF-PQ implementation. No arithmetic operations defined + * only conversion to and from fp32. This type is unrelated to the proposed FP8 specification. + */ +template +struct fp_8bit { + static_assert(ExpBits + uint8_t{Signed} <= 8, "The type does not fit in 8 bits."); + constexpr static uint32_t ExpMask = (1u << (ExpBits - 1u)) - 1u; // NOLINT + constexpr static uint32_t ValBits = 8u - ExpBits; // NOLINT + + public: + uint8_t bitstring; + + HDI explicit fp_8bit(uint8_t bs) : bitstring(bs) {} + HDI explicit fp_8bit(float fp) : fp_8bit(float2fp_8bit(fp).bitstring) {} + HDI auto operator=(float fp) -> fp_8bit& + { + bitstring = float2fp_8bit(fp).bitstring; + return *this; + } + HDI explicit operator float() const { return fp_8bit2float(*this); } + HDI explicit operator half() const { return half(fp_8bit2float(*this)); } + + private: + static constexpr float kMin = 1.0f / float(1u << ExpMask); + static constexpr float kMax = float(1u << (ExpMask + 1)) * (2.0f - 1.0f / float(1u << ValBits)); + + static HDI auto float2fp_8bit(float v) -> fp_8bit + { + if constexpr (Signed) { + auto u = fp_8bit(std::abs(v)).bitstring; + u = (u & 0xfeu) | uint8_t{v < 0}; // set the sign bit + return fp_8bit(u); + } else { + // sic! all small and negative numbers are truncated to zero. + if (v < kMin) { return fp_8bit{static_cast(0)}; } + // protect from overflow + if (v >= kMax) { return fp_8bit{static_cast(0xffu)}; } + // the rest of possible float values should be within the normalized range + return fp_8bit{static_cast( + (*reinterpret_cast(&v) + (ExpMask << 23u) - 0x3f800000u) >> (15u + ExpBits))}; + } + } + + static HDI auto fp_8bit2float(const fp_8bit& v) -> float + { + uint32_t u = v.bitstring; + if constexpr (Signed) { + u &= ~1; // zero the sign bit + } + float r; + *reinterpret_cast(&r) = + ((u << (15u + ExpBits)) + (0x3f800000u | (0x00400000u >> ValBits)) - (ExpMask << 23)); + if constexpr (Signed) { // recover the sign bit + if (v.bitstring & 1) { r = -r; } + } + return r; + } +}; + +} // namespace raft::neighbors::ivf_pq::detail diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh index 4b6e6f5e31..53d1fd6290 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_search.cuh @@ -18,6 +18,9 @@ #include +#include +#include +#include #include #include @@ -49,79 +52,8 @@ namespace raft::neighbors::ivf_pq::detail { -/** - * Maximum value of k for the fused calculate & select in ivfpq. - * - * If runtime value of k is larger than this, the main search operation - * is split into two kernels (per batch, first calculate distance, then select top-k). - */ -static constexpr int kMaxCapacity = 128; -static_assert((kMaxCapacity >= 32) && !(kMaxCapacity & (kMaxCapacity - 1)), - "kMaxCapacity must be a power of two, not smaller than the WarpSize."); - using namespace raft::spatial::knn::detail; // NOLINT -/** 8-bit floating-point storage type. - * - * This is a custom type for the current IVF-PQ implementation. No arithmetic operations defined - * only conversion to and from fp32. This type is unrelated to the proposed FP8 specification. - */ -template -struct fp_8bit { - static_assert(ExpBits + uint8_t{Signed} <= 8, "The type does not fit in 8 bits."); - constexpr static uint32_t ExpMask = (1u << (ExpBits - 1u)) - 1u; // NOLINT - constexpr static uint32_t ValBits = 8u - ExpBits; // NOLINT - - public: - uint8_t bitstring; - - HDI explicit fp_8bit(uint8_t bs) : bitstring(bs) {} - HDI explicit fp_8bit(float fp) : fp_8bit(float2fp_8bit(fp).bitstring) {} - HDI auto operator=(float fp) -> fp_8bit& - { - bitstring = float2fp_8bit(fp).bitstring; - return *this; - } - HDI explicit operator float() const { return fp_8bit2float(*this); } - HDI explicit operator half() const { return half(fp_8bit2float(*this)); } - - private: - static constexpr float kMin = 1.0f / float(1u << ExpMask); - static constexpr float kMax = float(1u << (ExpMask + 1)) * (2.0f - 1.0f / float(1u << ValBits)); - - static HDI auto float2fp_8bit(float v) -> fp_8bit - { - if constexpr (Signed) { - auto u = fp_8bit(std::abs(v)).bitstring; - u = (u & 0xfeu) | uint8_t{v < 0}; // set the sign bit - return fp_8bit(u); - } else { - // sic! all small and negative numbers are truncated to zero. - if (v < kMin) { return fp_8bit{static_cast(0)}; } - // protect from overflow - if (v >= kMax) { return fp_8bit{static_cast(0xffu)}; } - // the rest of possible float values should be within the normalized range - return fp_8bit{static_cast( - (*reinterpret_cast(&v) + (ExpMask << 23u) - 0x3f800000u) >> (15u + ExpBits))}; - } - } - - static HDI auto fp_8bit2float(const fp_8bit& v) -> float - { - uint32_t u = v.bitstring; - if constexpr (Signed) { - u &= ~1; // zero the sign bit - } - float r; - *reinterpret_cast(&r) = - ((u << (15u + ExpBits)) + (0x3f800000u | (0x00400000u >> ValBits)) - (ExpMask << 23)); - if constexpr (Signed) { // recover the sign bit - if (v.bitstring & 1) { r = -r; } - } - return r; - } -}; - /** * Select the clusters to probe and, as a side-effect, translate the queries type `T -> float` * @@ -439,464 +371,6 @@ void postprocess_distances(float* out, // [n_queries, topk] } } -template -struct dummy_block_sort_t { - using queue_t = matrix::detail::select::warpsort::warp_sort_distributed; - template - __device__ dummy_block_sort_t(int k, Args...){}; -}; - -template -struct pq_block_sort { - using type = matrix::detail::select::warpsort:: - block_sort; -}; - -template -struct pq_block_sort<0, T, IdxT> : dummy_block_sort_t { - using type = dummy_block_sort_t; -}; - -template -using block_sort_t = typename pq_block_sort::type; - -/* Manually unrolled loop over a chunk of pq_dataset that fits into one VecT. */ -template -__device__ __forceinline__ void ivfpq_compute_chunk(OutT& score /* NOLINT */, - typename VecT::math_t& pq_code, - const VecT& pq_codes, - const LutT*& lut_head, - const LutT*& lut_end) -{ - if constexpr (CheckBounds) { - if (lut_head >= lut_end) { return; } - } - constexpr uint32_t kTotalBits = 8 * sizeof(typename VecT::math_t); - constexpr uint32_t kPqShift = 1u << PqBits; - constexpr uint32_t kPqMask = kPqShift - 1u; - if constexpr (BitsLeft >= PqBits) { - uint8_t code = pq_code & kPqMask; - pq_code >>= PqBits; - score += OutT(lut_head[code]); - lut_head += kPqShift; - return ivfpq_compute_chunk( - score, pq_code, pq_codes, lut_head, lut_end); - } else if constexpr (Ix < VecT::Ratio) { - uint8_t code = pq_code; - pq_code = pq_codes.val.data[Ix]; - constexpr uint32_t kRemBits = PqBits - BitsLeft; - constexpr uint32_t kRemMask = (1u << kRemBits) - 1u; - code |= (pq_code & kRemMask) << BitsLeft; - pq_code >>= kRemBits; - score += OutT(lut_head[code]); - lut_head += kPqShift; - return ivfpq_compute_chunk(score, pq_code, pq_codes, lut_head, lut_end); - } -} - -/* Compute the similarity for one vector in the pq_dataset */ -template -__device__ auto ivfpq_compute_score(uint32_t pq_dim, - const typename VecT::io_t* pq_head, - const LutT* lut_scores, - OutT early_stop_limit) -> OutT -{ - constexpr uint32_t kChunkSize = sizeof(VecT) * 8u / PqBits; - auto lut_head = lut_scores; - auto lut_end = lut_scores + (pq_dim << PqBits); - VecT pq_codes; - OutT score{0}; - for (; pq_dim >= kChunkSize; pq_dim -= kChunkSize) { - *pq_codes.vectorized_data() = *pq_head; - pq_head += kIndexGroupSize; - typename VecT::math_t pq_code = 0; - ivfpq_compute_chunk( - score, pq_code, pq_codes, lut_head, lut_end); - // Early stop when it makes sense (otherwise early_stop_limit is kDummy/infinity). - if (score >= early_stop_limit) { return score; } - } - if (pq_dim > 0) { - *pq_codes.vectorized_data() = *pq_head; - typename VecT::math_t pq_code = 0; - ivfpq_compute_chunk( - score, pq_code, pq_codes, lut_head, lut_end); - } - return score; -} - -/** - * The main kernel that computes similarity scores across multiple queries and probes. - * When `Capacity > 0`, it also selects top K candidates for each query and probe - * (which need to be merged across probes afterwards). - * - * Each block processes a (query, probe) pair: it calculates the distance between the single query - * vector and all the dataset vector in the cluster that we are probing. - * - * @tparam OutT - * The output type - distances. - * @tparam LutT - * The lookup table element type (lut_scores). - * @tparam PqBits - * The bit length of an encoded vector element after compression by PQ - * (NB: pq_book_size = 1 << PqBits). - * @tparam Capacity - * Power-of-two; the maximum possible `k` in top-k. Value zero disables fused top-k search. - * @tparam PrecompBaseDiff - * Defines whether we should precompute part of the distance and keep it in shared memory - * before the main part (score calculation) to increase memory usage efficiency in the latter. - * For L2, this is the distance between the query and the cluster center. - * @tparam EnableSMemLut - * Defines whether to use the shared memory for the lookup table (`lut_scores`). - * Setting this to `false` allows to reduce the shared memory usage (and maximum data dim) - * at the cost of reducing global memory reading throughput. - * - * @param n_rows the number of records in the dataset - * @param dim the dimensionality of the data (NB: after rotation transform, i.e. `index.rot_dim()`). - * @param n_probes the number of clusters to search for each query - * @param pq_dim - * The dimensionality of an encoded vector after compression by PQ. - * @param n_queries the number of queries. - * @param metric the distance type. - * @param codebook_kind Defines the way PQ codebooks have been trained. - * @param topk the `k` in the select top-k. - * @param max_samples the size of the output for a single query. - * @param cluster_centers - * The device pointer to the cluster centers in the original space (NB: after rotation) - * [n_clusters, dim]. - * @param pq_centers - * The device pointer to the cluster centers in the PQ space - * [pq_dim, pq_book_size, pq_len] or [n_clusters, pq_book_size, pq_len,]. - * @param pq_dataset - * The device pointer to the PQ index (data) [n_rows, ...]. - * @param cluster_labels - * The device pointer to the labels (clusters) for each query and probe [n_queries, n_probes]. - * @param _chunk_indices - * The device pointer to the data offsets for each query and probe [n_queries, n_probes]. - * @param queries - * The device pointer to the queries (NB: after rotation) [n_queries, dim]. - * @param index_list - * An optional device pointer to the enforced order of search [n_queries, n_probes]. - * One can pass reordered indices here to try to improve data reading locality. - * @param lut_scores - * The device pointer for storing the lookup table globally [gridDim.x, pq_dim << PqBits]. - * Ignored when `EnableSMemLut == true`. - * @param _out_scores - * The device pointer to the output scores - * [n_queries, max_samples] or [n_queries, n_probes, topk]. - * @param _out_indices - * The device pointer to the output indices [n_queries, n_probes, topk]. - * These are the indices of the records as they appear in the database view formed by the probed - * clusters / defined by the `_chunk_indices`. - * The indices can have values within the range [0, max_samples). - * Ignored when `Capacity == 0`. - */ -template -__global__ void compute_similarity_kernel(uint32_t n_rows, - uint32_t dim, - uint32_t n_probes, - uint32_t pq_dim, - uint32_t n_queries, - distance::DistanceType metric, - codebook_gen codebook_kind, - uint32_t topk, - uint32_t max_samples, - const float* cluster_centers, - const float* pq_centers, - const uint8_t* const* pq_dataset, - const uint32_t* cluster_labels, - const uint32_t* _chunk_indices, - const float* queries, - const uint32_t* index_list, - float* query_kths, - LutT* lut_scores, - OutT* _out_scores, - uint32_t* _out_indices) -{ - /* Shared memory: - - * lut_scores: lookup table (LUT) of size = `pq_dim << PqBits` (when EnableSMemLut) - * base_diff: size = dim (which is equal to `pq_dim * pq_len`) or dim*2 - * topk::block_sort: some amount of shared memory, but overlaps with the rest: - block_sort only needs shared memory for `.done()` operation, which can come very last. - */ - extern __shared__ __align__(256) uint8_t smem_buf[]; // NOLINT - constexpr bool kManageLocalTopK = Capacity > 0; - - constexpr uint32_t PqShift = 1u << PqBits; // NOLINT - constexpr uint32_t PqMask = PqShift - 1u; // NOLINT - - const uint32_t pq_len = dim / pq_dim; - const uint32_t lut_size = pq_dim * PqShift; - - if constexpr (EnableSMemLut) { - lut_scores = reinterpret_cast(smem_buf); - } else { - lut_scores += lut_size * blockIdx.x; - } - - float* base_diff = nullptr; - if constexpr (PrecompBaseDiff) { - if constexpr (EnableSMemLut) { - base_diff = reinterpret_cast(lut_scores + lut_size); - } else { - base_diff = reinterpret_cast(smem_buf); - } - } - - for (int ib = blockIdx.x; ib < n_queries * n_probes; ib += gridDim.x) { - if (ib >= gridDim.x) { - // sync shared memory accesses on the second and further iterations - __syncthreads(); - } - uint32_t query_ix; - uint32_t probe_ix; - if (index_list == nullptr) { - query_ix = ib % n_queries; - probe_ix = ib / n_queries; - } else { - auto ordered_ix = index_list[ib]; - query_ix = ordered_ix / n_probes; - probe_ix = ordered_ix % n_probes; - } - - const uint32_t* chunk_indices = _chunk_indices + (n_probes * query_ix); - const float* query = queries + (dim * query_ix); - OutT* out_scores; - uint32_t* out_indices = nullptr; - if constexpr (kManageLocalTopK) { - // Store topk calculated distances to out_scores (and its indices to out_indices) - out_scores = _out_scores + topk * (probe_ix + (n_probes * query_ix)); - out_indices = _out_indices + topk * (probe_ix + (n_probes * query_ix)); - } else { - // Store all calculated distances to out_scores - out_scores = _out_scores + max_samples * query_ix; - } - uint32_t label = cluster_labels[n_probes * query_ix + probe_ix]; - const float* cluster_center = cluster_centers + (dim * label); - const float* pq_center; - if (codebook_kind == codebook_gen::PER_SUBSPACE) { - pq_center = pq_centers; - } else { - pq_center = pq_centers + (pq_len << PqBits) * label; - } - - if constexpr (PrecompBaseDiff) { - // Reduce number of memory reads later by pre-computing parts of the score - switch (metric) { - case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2Expanded: { - for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { - base_diff[i] = query[i] - cluster_center[i]; - } - } break; - case distance::DistanceType::InnerProduct: { - float2 pvals; - for (uint32_t i = threadIdx.x; i < dim; i += blockDim.x) { - pvals.x = query[i]; - pvals.y = cluster_center[i] * pvals.x; - reinterpret_cast(base_diff)[i] = pvals; - } - } break; - default: __builtin_unreachable(); - } - __syncthreads(); - } - - { - // Create a lookup table - // For each subspace, the lookup table stores the distance between the actual query vector - // (projected into the subspace) and all possible pq vectors in that subspace. - for (uint32_t i = threadIdx.x; i < lut_size; i += blockDim.x) { - const uint32_t i_pq = i >> PqBits; - uint32_t j = i_pq * pq_len; - const uint32_t j_end = pq_len + j; - auto cur_pq_center = pq_center + (i & PqMask) + - (codebook_kind == codebook_gen::PER_SUBSPACE ? j * PqShift : 0u); - float score = 0.0; - do { - float pq_c = *cur_pq_center; - cur_pq_center += PqShift; - switch (metric) { - case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2Expanded: { - float diff; - if constexpr (PrecompBaseDiff) { - diff = base_diff[j]; - } else { - diff = query[j] - cluster_center[j]; - } - diff -= pq_c; - score += diff * diff; - } break; - case distance::DistanceType::InnerProduct: { - // NB: we negate the scores as we hardcoded select-topk to always compute the minimum - float q; - if constexpr (PrecompBaseDiff) { - float2 pvals = reinterpret_cast(base_diff)[j]; - q = pvals.x; - score -= pvals.y; - } else { - q = query[j]; - score -= q * cluster_center[j]; - } - score -= q * pq_c; - } break; - default: __builtin_unreachable(); - } - } while (++j < j_end); - lut_scores[i] = LutT(score); - } - } - - // Define helper types for efficient access to the pq_dataset, which is stored in an interleaved - // format. The chunks of PQ data are stored in kIndexGroupVecLen-bytes-long chunks, interleaved - // in groups of kIndexGroupSize elems (which is normally equal to the warp size) for the fastest - // possible access by thread warps. - // - // Consider one record in the pq_dataset is `pq_dim * pq_bits`-bit-long. - // Assuming `kIndexGroupVecLen = 16`, one chunk of data read by a thread at once is 128-bits. - // Then, such a chunk contains `chunk_size = 128 / pq_bits` record elements, and the record - // consists of `ceildiv(pq_dim, chunk_size)` chunks. The chunks are interleaved in groups of 32, - // so that the warp can achieve the best coalesced read throughput. - using group_align = Pow2; - using vec_align = Pow2; - using local_topk_t = block_sort_t; - using op_t = uint32_t; - using vec_t = TxN_t; - - uint32_t sample_offset = 0; - if (probe_ix > 0) { sample_offset = chunk_indices[probe_ix - 1]; } - uint32_t n_samples = chunk_indices[probe_ix] - sample_offset; - uint32_t n_samples_aligned = group_align::roundUp(n_samples); - constexpr uint32_t kChunkSize = (kIndexGroupVecLen * 8u) / PqBits; - uint32_t pq_line_width = div_rounding_up_unsafe(pq_dim, kChunkSize) * kIndexGroupVecLen; - auto pq_thread_data = pq_dataset[label] + group_align::roundDown(threadIdx.x) * pq_line_width + - group_align::mod(threadIdx.x) * vec_align::Value; - pq_line_width *= blockDim.x; - - constexpr OutT kDummy = upper_bound(); - OutT query_kth = kDummy; - if constexpr (kManageLocalTopK) { query_kth = OutT(query_kths[query_ix]); } - local_topk_t block_topk(topk, nullptr, query_kth); - OutT early_stop_limit = kDummy; - switch (metric) { - // If the metric is non-negative, we can use the query_kth approximation as an early stop - // threshold to skip some iterations when computing the score. Add such metrics here. - case distance::DistanceType::L2SqrtExpanded: - case distance::DistanceType::L2Expanded: { - early_stop_limit = query_kth; - } break; - default: break; - } - - // Ensure lut_scores is written by all threads before using it in ivfpq-compute-score - __threadfence_block(); - __syncthreads(); - - // Compute a distance for each sample - for (uint32_t i = threadIdx.x; i < n_samples_aligned; - i += blockDim.x, pq_thread_data += pq_line_width) { - OutT score = kDummy; - bool valid = i < n_samples; - if (valid) { - score = ivfpq_compute_score( - pq_dim, - reinterpret_cast(pq_thread_data), - lut_scores, - early_stop_limit); - } - if constexpr (kManageLocalTopK) { - block_topk.add(score, sample_offset + i); - } else { - if (valid) { out_scores[sample_offset + i] = score; } - } - } - if constexpr (kManageLocalTopK) { - // sync threads before the topk merging operation, because we reuse smem_buf - __syncthreads(); - block_topk.done(smem_buf); - block_topk.store(out_scores, out_indices); - if (threadIdx.x == 0) { atomicMin(query_kths + query_ix, float(out_scores[topk - 1])); } - } else { - // fill in the rest of the out_scores with dummy values - if (probe_ix + 1 == n_probes) { - for (uint32_t i = threadIdx.x + sample_offset + n_samples; i < max_samples; - i += blockDim.x) { - out_scores[i] = kDummy; - } - } - } - } -} - -// The signature of the kernel defined by a minimal set of template parameters -template -using compute_similarity_kernel_t = - decltype(&compute_similarity_kernel); - -// The config struct lifts the runtime parameters to the template parameters -template -struct compute_similarity_kernel_config { - public: - static auto get(uint32_t pq_bits, uint32_t k_max) -> compute_similarity_kernel_t - { - return kernel_choose_bits(pq_bits, k_max); - } - - private: - static auto kernel_choose_bits(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t - { - switch (pq_bits) { - case 4: return kernel_try_capacity<4, kMaxCapacity>(k_max); - case 5: return kernel_try_capacity<5, kMaxCapacity>(k_max); - case 6: return kernel_try_capacity<6, kMaxCapacity>(k_max); - case 7: return kernel_try_capacity<7, kMaxCapacity>(k_max); - case 8: return kernel_try_capacity<8, kMaxCapacity>(k_max); - default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits); - } - } - - template - static auto kernel_try_capacity(uint32_t k_max) -> compute_similarity_kernel_t - { - if constexpr (Capacity > 0) { - if (k_max == 0 || k_max > Capacity) { return kernel_try_capacity(k_max); } - } - if constexpr (Capacity > 1) { - if (k_max * 2 <= Capacity) { return kernel_try_capacity(k_max); } - } - return compute_similarity_kernel; - } -}; - -// A standalone accessor function is necessary to make sure template specializations work correctly -// (we "extern template" this function) -template -auto get_compute_similarity_kernel(uint32_t pq_bits, uint32_t k_max) - -> compute_similarity_kernel_t -{ - return compute_similarity_kernel_config::get(pq_bits, - k_max); -} - /** * An approximation to the number of times each cluster appears in a batched sample. * @@ -930,318 +404,6 @@ constexpr inline auto expected_probe_coresidency(uint32_t n_clusters, return 1 + (n_queries - 1) * n_probes / (2 * n_clusters); } -/** - * Estimate a carveout value as expected by `cudaFuncAttributePreferredSharedMemoryCarveout` - * (which does not take into account `reservedSharedMemPerBlock`), - * given by a desired schmem-L1 split and a per-block memory requirement in bytes. - * - * NB: As per the programming guide, the memory carveout setting is just a hint for the driver; it's - * free to choose any shmem-L1 configuration it deems appropriate. For example, if you set the - * carveout to zero, it will choose a non-zero config that will allow to run at least one active - * block per SM. - * - * @param shmem_fraction - * a fraction representing a desired split (shmem / (shmem + L1)) [0, 1]. - * @param shmem_per_block - * a shared memory usage per block (dynamic + static shared memory sizes), in bytes. - * @param dev_props - * device properties. - * @return - * a carveout value in percents [0, 100]. - */ -constexpr inline auto estimate_carveout(double shmem_fraction, - size_t shmem_per_block, - const cudaDeviceProp& dev_props) -> int -{ - using shmem_unit = Pow2<128>; - size_t m = shmem_unit::roundUp(shmem_per_block); - size_t r = dev_props.reservedSharedMemPerBlock; - size_t s = dev_props.sharedMemPerMultiprocessor; - return (size_t(100 * s * m * shmem_fraction) - (m - 1) * r) / (s * (m + r)); -} - -/** Select an appropriate kernel instance and launch parameters. */ -template -struct compute_similarity { - /** Estimate the occupancy for the given kernel on the given device. */ - struct occupancy_t { - using shmem_unit = Pow2<128>; - - int blocks_per_sm = 0; - double occupancy = 0.0; - double shmem_use = 1.0; - - inline occupancy_t() = default; - inline occupancy_t(size_t smem, - uint32_t n_threads, - compute_similarity_kernel_t kernel, - const cudaDeviceProp& dev_props) - { - RAFT_CUDA_TRY( - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&blocks_per_sm, kernel, n_threads, smem)); - occupancy = double(blocks_per_sm * n_threads) / double(dev_props.maxThreadsPerMultiProcessor); - shmem_use = double(shmem_unit::roundUp(smem) * blocks_per_sm) / - double(dev_props.sharedMemPerMultiprocessor); - } - }; - - struct selected { - compute_similarity_kernel_t kernel; - dim3 grid_dim; - dim3 block_dim; - size_t smem_size; - size_t device_lut_size; - - template - void operator()(rmm::cuda_stream_view stream, Args... args) - { - kernel<<>>(args...); - RAFT_CHECK_CUDA(stream); - } - }; - - /** - * Use heuristics to choose an optimal instance of the search kernel. - * It selects among a few kernel variants (with/out using shared mem for - * lookup tables / precomputed distances) and tries to choose the block size - * to maximize kernel occupancy. - * - * @param manage_local_topk - * whether use the fused calculate+select or just calculate the distances for each - * query and probed cluster. - * - * @param locality_hint - * beyond this limit do not consider increasing the number of active blocks per SM - * would improve locality anymore. - */ - static inline auto select(const cudaDeviceProp& dev_props, - bool manage_local_topk, - int locality_hint, - double preferred_shmem_carveout, - uint32_t pq_bits, - uint32_t pq_dim, - uint32_t precomp_data_count, - uint32_t n_queries, - uint32_t n_probes, - uint32_t topk) -> selected - { - // Shared memory for storing the lookup table - size_t lut_mem = sizeof(LutT) * (pq_dim << pq_bits); - // Shared memory for storing pre-computed pieces to speedup the lookup table construction - // (e.g. the distance between a cluster center and the query for L2). - size_t bdf_mem = sizeof(float) * precomp_data_count; - // Shared memory for the fused top-k component; it may overlap with the other uses of shared - // memory and depends on the number of threads. - struct ltk_mem_t { - uint32_t subwarp_size; - uint32_t topk; - bool manage_local_topk; - ltk_mem_t(bool manage_local_topk, uint32_t topk) - : manage_local_topk(manage_local_topk), topk(topk) - { - subwarp_size = WarpSize; - while (topk * 2 <= subwarp_size) { - subwarp_size /= 2; - } - } - - [[nodiscard]] auto operator()(uint32_t n_threads) const -> size_t - { - return manage_local_topk ? matrix::detail::select::warpsort:: - template calc_smem_size_for_block_wide( - n_threads / subwarp_size, topk) - : 0; - } - } ltk_mem{manage_local_topk, topk}; - - // Total amount of work; should be enough to occupy the GPU. - uint32_t n_blocks = n_queries * n_probes; - - // The minimum block size we may want: - // 1. It's a power-of-two for efficient L1 caching of pq_centers values - // (multiples of `1 << pq_bits`). - // 2. It should be large enough to fully utilize an SM. - uint32_t n_threads_min = WarpSize; - while (dev_props.maxBlocksPerMultiProcessor * int(n_threads_min) < - dev_props.maxThreadsPerMultiProcessor) { - n_threads_min *= 2; - } - // Further increase the minimum block size to make sure full device occupancy - // (NB: this may lead to `n_threads_min` being larger than the kernel's maximum) - while (int(n_blocks * n_threads_min) < - dev_props.multiProcessorCount * dev_props.maxThreadsPerMultiProcessor && - int(n_threads_min) < dev_props.maxThreadsPerBlock) { - n_threads_min *= 2; - } - // Even further, increase it to allow less blocks per SM if there not enough queries. - // With this, we reduce the chance of different clusters being processed by two blocks - // on the same SM and thus improve the data locality for L1 caching. - while (int(n_queries * n_threads_min) < dev_props.maxThreadsPerMultiProcessor && - int(n_threads_min) < dev_props.maxThreadsPerBlock) { - n_threads_min *= 2; - } - - // Granularity of changing the number of threads when computing the maximum block size. - // It's good to have it multiple of the PQ book width. - uint32_t n_threads_gty = round_up_safe(1u << pq_bits, WarpSize); - - /* - Shared memory / L1 cache balance is the main limiter of this kernel. - The more blocks per SM we launch, the more shared memory we need. Besides that, we have - three versions of the kernel varying in performance and shmem usage. - - We try the most demanding and the fastest kernel first, trying to maximize occupancy with - the minimum number of blocks (just one, really). Then, we tweak the `n_threads` to further - optimize occupancy and data locality for the L1 cache. - */ - auto conf_fast = get_compute_similarity_kernel; - auto conf_no_basediff = get_compute_similarity_kernel; - auto conf_no_smem_lut = get_compute_similarity_kernel; - auto topk_or_zero = manage_local_topk ? topk : 0u; - std::array candidates{ - std::make_tuple(conf_fast(pq_bits, topk_or_zero), lut_mem + bdf_mem, true), - std::make_tuple(conf_no_basediff(pq_bits, topk_or_zero), lut_mem, true), - std::make_tuple(conf_no_smem_lut(pq_bits, topk_or_zero), bdf_mem, false)}; - - // we may allow slightly lower than 100% occupancy; - constexpr double kTargetOccupancy = 0.75; - // This struct is used to select the better candidate - occupancy_t selected_perf{}; - selected selected_config; - for (auto [kernel, smem_size_const, lut_is_in_shmem] : candidates) { - if (smem_size_const > dev_props.sharedMemPerBlockOptin) { - // Even a single block cannot fit into an SM due to shmem requirements. Skip the candidate. - continue; - } - - // First, we set the carveout hint to the preferred value. The driver will increase this if - // needed to run at least one block per SM. At the same time, if more blocks fit into one SM, - // this carveout value will limit the calculated occupancy. When we're done selecting the best - // launch configuration, we will tighten the carveout once more, based on the final memory - // usage and occupancy. - const int max_carveout = - estimate_carveout(preferred_shmem_carveout, smem_size_const, dev_props); - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, max_carveout)); - - // Get the theoretical maximum possible number of threads per block - cudaFuncAttributes kernel_attrs; - RAFT_CUDA_TRY(cudaFuncGetAttributes(&kernel_attrs, kernel)); - uint32_t n_threads = - round_down_safe(kernel_attrs.maxThreadsPerBlock, n_threads_gty); - - // Actual required shmem depens on the number of threads - size_t smem_size = max(smem_size_const, ltk_mem(n_threads)); - - // Make sure the kernel can get enough shmem. - cudaError_t cuda_status = - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); - if (cuda_status != cudaSuccess) { - RAFT_EXPECTS( - cuda_status == cudaGetLastError(), - "Tried to reset the expected cuda error code, but it didn't match the expectation"); - // Failed to request enough shmem for the kernel. Skip the candidate. - continue; - } - - occupancy_t cur(smem_size, n_threads, kernel, dev_props); - if (cur.blocks_per_sm <= 0) { - // For some reason, we still cannot make this kernel run. Skip the candidate. - continue; - } - - { - // Try to reduce the number of threads to increase occupancy and data locality - auto n_threads_tmp = n_threads_min; - while (n_threads_tmp * 2 < n_threads) { - n_threads_tmp *= 2; - } - if (n_threads_tmp < n_threads) { - while (n_threads_tmp >= n_threads_min) { - auto smem_size_tmp = max(smem_size_const, ltk_mem(n_threads_tmp)); - occupancy_t tmp(smem_size_tmp, n_threads_tmp, kernel, dev_props); - bool select_it = false; - if (lut_is_in_shmem && locality_hint >= tmp.blocks_per_sm) { - // Normally, the smaller the block the better for L1 cache hit rate. - // Hence, the occupancy should be "just good enough" - select_it = tmp.occupancy >= min(kTargetOccupancy, cur.occupancy); - } else if (lut_is_in_shmem) { - // If we don't have enough repeating probes (locality_hint < tmp.blocks_per_sm), - // the locality is not going to improve with increasing the number of blocks per SM. - // Hence, the only metric here is the occupancy. - bool improves_occupancy = tmp.occupancy > cur.occupancy; - // Otherwise, the performance still improves with a smaller block size, - // given there is enough work to do - bool improves_parallelism = - tmp.occupancy == cur.occupancy && - 7u * tmp.blocks_per_sm * dev_props.multiProcessorCount <= n_blocks; - select_it = improves_occupancy || improves_parallelism; - } else { - // If we don't use shared memory for the lookup table, increasing the number of blocks - // is very taxing on the global memory usage. - // In this case, the occupancy must increase a lot to make it worth the cost. - select_it = tmp.occupancy >= min(1.0, cur.occupancy / kTargetOccupancy); - } - if (select_it) { - n_threads = n_threads_tmp; - smem_size = smem_size_tmp; - cur = tmp; - } - n_threads_tmp /= 2; - } - } - } - - { - if (selected_perf.occupancy <= 0.0 // no candidate yet - || (selected_perf.occupancy < cur.occupancy * kTargetOccupancy && - selected_perf.shmem_use >= cur.shmem_use) // much improved occupancy - ) { - selected_perf = cur; - if (lut_is_in_shmem) { - selected_config = { - kernel, dim3(n_blocks, 1, 1), dim3(n_threads, 1, 1), smem_size, size_t(0)}; - } else { - // When the global memory is used for the lookup table, we need to minimize the grid - // size; otherwise, the kernel may quickly run out of memory. - auto n_blocks_min = - std::min(n_blocks, cur.blocks_per_sm * dev_props.multiProcessorCount); - selected_config = {kernel, - dim3(n_blocks_min, 1, 1), - dim3(n_threads, 1, 1), - smem_size, - size_t(n_blocks_min) * size_t(pq_dim << pq_bits)}; - } - // Actual shmem/L1 split wildly rounds up the specified preferred carveout, so we set here - // a rather conservative bar; most likely, the kernel gets more shared memory than this, - // and the occupancy doesn't get hurt. - auto carveout = std::min(max_carveout, std::ceil(100.0 * cur.shmem_use)); - RAFT_CUDA_TRY( - cudaFuncSetAttribute(kernel, cudaFuncAttributePreferredSharedMemoryCarveout, carveout)); - if (cur.occupancy >= kTargetOccupancy) { break; } - } else if (selected_perf.occupancy > 0.0) { - // If we found a reasonable candidate on a previous iteration, and this one is not better, - // then don't try any more candidates because they are much slower anyway. - break; - } - } - } - - RAFT_EXPECTS(selected_perf.occupancy > 0.0, - "Couldn't determine a working kernel launch configuration."); - - return selected_config; - } -}; - -inline auto is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) -> bool -{ - if (k > kMaxCapacity) { return false; } // warp_sort not possible - if (n_probes <= 16) { return false; } // too few clusters - if (n_queries * n_probes <= 256) { return false; } // overall amount of work is too small - return true; -} - /** * The "main part" of the search, which assumes that outer-level `search` has already: * @@ -1364,16 +526,16 @@ void ivfpq_search_worker(raft::device_resources const& handle, } break; } - auto search_instance = compute_similarity::select(handle.get_device_properties(), - manage_local_topk, - coresidency, - preferred_shmem_carveout, - index.pq_bits(), - index.pq_dim(), - precomp_data_count, - n_queries, - n_probes, - topK); + auto search_instance = compute_similarity_select(handle.get_device_properties(), + manage_local_topk, + coresidency, + preferred_shmem_carveout, + index.pq_bits(), + index.pq_dim(), + precomp_data_count, + n_queries, + n_probes, + topK); rmm::device_uvector device_lut(search_instance.device_lut_size, stream, mr); std::optional> query_kths_buf{std::nullopt}; @@ -1386,27 +548,28 @@ void ivfpq_search_worker(raft::device_resources const& handle, raft::const_op{dummy_block_sort_t::queue_t::kDummy}); query_kths = query_kths_buf->data_handle(); } - search_instance(stream, - index.size(), - index.rot_dim(), - n_probes, - index.pq_dim(), - n_queries, - index.metric(), - index.codebook_kind(), - topK, - max_samples, - index.centers_rot().data_handle(), - index.pq_centers().data_handle(), - index.data_ptrs().data_handle(), - clusters_to_probe, - chunk_index.data(), - query, - index_list_sorted, - query_kths, - device_lut.data(), - distances_buf.data(), - neighbors_ptr); + compute_similarity_run(search_instance, + stream, + index.size(), + index.rot_dim(), + n_probes, + index.pq_dim(), + n_queries, + index.metric(), + index.codebook_kind(), + topK, + max_samples, + index.centers_rot().data_handle(), + index.pq_centers().data_handle(), + index.data_ptrs().data_handle(), + clusters_to_probe, + chunk_index.data(), + query, + index_list_sorted, + query_kths, + device_lut.data(), + distances_buf.data(), + neighbors_ptr); // Select topk vectors for each query rmm::device_uvector topk_dists(n_queries * topK, stream, mr); diff --git a/cpp/include/raft/neighbors/detail/refine.cuh b/cpp/include/raft/neighbors/detail/refine.cuh index b85bbd0e9c..0ff5e4cdbc 100644 --- a/cpp/include/raft/neighbors/detail/refine.cuh +++ b/cpp/include/raft/neighbors/detail/refine.cuh @@ -22,6 +22,7 @@ #include #include #include +#include #include #include diff --git a/cpp/include/raft/neighbors/ivf_flat-inl.cuh b/cpp/include/raft/neighbors/ivf_flat-inl.cuh index f12062f851..4f8d7f596e 100644 --- a/cpp/include/raft/neighbors/ivf_flat-inl.cuh +++ b/cpp/include/raft/neighbors/ivf_flat-inl.cuh @@ -97,8 +97,8 @@ auto build(raft::device_resources const& handle, * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); * @endcode * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset * * @param[in] handle * @param[in] params configure the index building @@ -106,17 +106,16 @@ auto build(raft::device_resources const& handle, * * @return the constructed ivf-flat index */ -template +template auto build(raft::device_resources const& handle, const index_params& params, - raft::device_matrix_view dataset) - -> index + raft::device_matrix_view dataset) -> index { return raft::neighbors::ivf_flat::detail::build(handle, params, dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); } /** @@ -141,8 +140,8 @@ auto build(raft::device_resources const& handle, * ivf_flat::search(handle, search_params, index, queries, out_inds, out_dists); * @endcode * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset * * @param[in] handle * @param[in] params configure the index building @@ -150,17 +149,17 @@ auto build(raft::device_resources const& handle, * @param[out] idx reference to ivf_flat::index * */ -template +template void build(raft::device_resources const& handle, const index_params& params, - raft::device_matrix_view dataset, - raft::neighbors::ivf_flat::index& idx) + raft::device_matrix_view dataset, + raft::neighbors::ivf_flat::index& idx) { idx = raft::neighbors::ivf_flat::detail::build(handle, params, dataset.data_handle(), - static_cast(dataset.extent(0)), - static_cast(dataset.extent(1))); + static_cast(dataset.extent(0)), + static_cast(dataset.extent(1))); } /** @} */ @@ -229,12 +228,12 @@ auto extend(raft::device_resources const& handle, * // train the index from a [N, D] dataset * auto index_empty = ivf_flat::build(handle, dataset, index_params, dataset); * // fill the index with the data - * std::optional> no_op = std::nullopt; + * std::optional> no_op = std::nullopt; * auto index = ivf_flat::extend(handle, index_empty, no_op, dataset); * @endcode * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset * * @param[in] handle * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] @@ -245,18 +244,17 @@ auto extend(raft::device_resources const& handle, * * @return the constructed extended ivf-flat index */ -template +template auto extend(raft::device_resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - const index& orig_index) -> index + raft::device_matrix_view new_vectors, + std::optional> new_indices, + const index& orig_index) -> index { - return extend( - handle, - orig_index, - new_vectors.data_handle(), - new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - new_vectors.extent(0)); + return extend(handle, + orig_index, + new_vectors.data_handle(), + new_indices.has_value() ? new_indices.value().data_handle() : nullptr, + new_vectors.extent(0)); } /** @} */ @@ -314,12 +312,12 @@ void extend(raft::device_resources const& handle, * // train the index from a [N, D] dataset * auto index_empty = ivf_flat::build(handle, index_params, dataset); * // fill the index with the data - * std::optional> no_op = std::nullopt; + * std::optional> no_op = std::nullopt; * ivf_flat::extend(handle, dataset, no_opt, &index_empty); * @endcode * - * @tparam value_t data element type - * @tparam idx_t type of the indices in the source dataset + * @tparam T data element type + * @tparam IdxT type of the indices in the source dataset * * @param[in] handle * @param[in] new_vectors raft::device_matrix_view to a row-major matrix [n_rows, index.dim()] @@ -328,17 +326,17 @@ void extend(raft::device_resources const& handle, * here to imply a continuous range `[0...n_rows)`. * @param[inout] index pointer to index, to be overwritten in-place */ -template +template void extend(raft::device_resources const& handle, - raft::device_matrix_view new_vectors, - std::optional> new_indices, - index* index) + raft::device_matrix_view new_vectors, + std::optional> new_indices, + index* index) { extend(handle, index, new_vectors.data_handle(), new_indices.has_value() ? new_indices.value().data_handle() : nullptr, - static_cast(new_vectors.extent(0))); + static_cast(new_vectors.extent(0))); } /** @} */ @@ -426,8 +424,8 @@ void search(raft::device_resources const& handle, * ... * @endcode * - * @tparam value_t data element type - * @tparam idx_t type of the indices + * @tparam T data element type + * @tparam IdxT type of the indices * * @param[in] handle * @param[in] params configure the search @@ -437,13 +435,13 @@ void search(raft::device_resources const& handle, * [n_queries, k] * @param[out] distances a device pointer to the distances to the selected neighbors [n_queries, k] */ -template +template void search(raft::device_resources const& handle, const search_params& params, - const index& index, - raft::device_matrix_view queries, - raft::device_matrix_view neighbors, - raft::device_matrix_view distances) + const index& index, + raft::device_matrix_view queries, + raft::device_matrix_view neighbors, + raft::device_matrix_view distances) { RAFT_EXPECTS( queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0), diff --git a/cpp/include/raft/neighbors/ivf_flat_types.hpp b/cpp/include/raft/neighbors/ivf_flat_types.hpp index 2a6aa12847..e9d8111f47 100644 --- a/cpp/include/raft/neighbors/ivf_flat_types.hpp +++ b/cpp/include/raft/neighbors/ivf_flat_types.hpp @@ -27,6 +27,7 @@ #include #include +#include // std::max #include #include #include @@ -379,10 +380,11 @@ struct index : ann::index { { // TODO: consider padding the dimensions and fixing veclen to its maximum possible value as a // template parameter (https://github.com/rapidsai/raft/issues/711) - uint32_t veclen = 16 / sizeof(T); - while (dim % veclen != 0) { - veclen = veclen >> 1; - } + + // NOTE: keep this consistent with the select_interleaved_scan_kernel logic + // in detail/ivf_flat_interleaved_scan-inl.cuh. + uint32_t veclen = std::max(1, 16 / sizeof(T)); + if (dim % veclen != 0) { veclen = 1; } return veclen; } }; diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu index 2c34d50a8c..4dfa2a707c 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_float_float_int64_t.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include +#include #define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu index 77aea885fb..2d54248e4d 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_int8_t_int32_t_int64_t.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include +#include #define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ diff --git a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu index 57d09b7d52..75fe52f3c7 100644 --- a/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu +++ b/cpp/src/neighbors/detail/ivf_flat_interleaved_scan_uint8_t_uint32_t_int64_t.cu @@ -14,7 +14,7 @@ * limitations under the License. */ -#include +#include #define instantiate_raft_neighbors_ivf_flat_detail_ivfflat_interleaved_scan(T, AccT, IdxT) \ template void raft::neighbors::ivf_flat::detail::ivfflat_interleaved_scan( \ diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py new file mode 100644 index 0000000000..a740d01bd2 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_00_generate.py @@ -0,0 +1,107 @@ +# Copyright (c) 2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +header = """ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \\ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \\ + const cudaDeviceProp& dev_props, \\ + bool manage_local_topk, \\ + int locality_hint, \\ + double preferred_shmem_carveout, \\ + uint32_t pq_bits, \\ + uint32_t pq_dim, \\ + uint32_t precomp_data_count, \\ + uint32_t n_queries, \\ + uint32_t n_probes, \\ + uint32_t topk) -> raft::neighbors::ivf_pq::detail::selected; \\ +\\ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \\ + raft::neighbors::ivf_pq::detail::selected s, \\ + rmm::cuda_stream_view stream, \\ + uint32_t n_rows, \\ + uint32_t dim, \\ + uint32_t n_probes, \\ + uint32_t pq_dim, \\ + uint32_t n_queries, \\ + raft::distance::DistanceType metric, \\ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \\ + uint32_t topk, \\ + uint32_t max_samples, \\ + const float* cluster_centers, \\ + const float* pq_centers, \\ + const uint8_t* const* pq_dataset, \\ + const uint32_t* cluster_labels, \\ + const uint32_t* _chunk_indices, \\ + const float* queries, \\ + const uint32_t* index_list, \\ + float* query_kths, \\ + LutT* lut_scores, \\ + OutT* _out_scores, \\ + uint32_t* _out_indices); + + +#define COMMA , +""" + +trailer = """ +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select +""" + +types = dict( + half_fp8_false=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>"), + half_fp8_true=("half", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>"), + half_half=("half", "half"), + float_half=("float", "half"), + float_float= ("float", "float"), + float_fp8_false=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>"), + float_fp8_true=("float", "raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>"), +) + +for path_key, (OutT, LutT) in types.items(): + path = f"ivf_pq_compute_similarity_{path_key}.cu" + with open(path, "w") as f: + f.write(header) + f.write(f"instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select({OutT}, {LutT});\n") + f.write(trailer) + print(f"src/neighbors/detail/{path}") diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu new file mode 100644 index 0000000000..956b7010d5 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_float.cu @@ -0,0 +1,73 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, float); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu new file mode 100644 index 0000000000..fba72ad1dd --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_false.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu new file mode 100644 index 0000000000..030f429315 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_fp8_true.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + float, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu new file mode 100644 index 0000000000..31a4d7d503 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_float_half.cu @@ -0,0 +1,73 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(float, half); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu new file mode 100644 index 0000000000..c623c80446 --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_false.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA false>); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu new file mode 100644 index 0000000000..f2aaca20db --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_fp8_true.cu @@ -0,0 +1,74 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select( + half, raft::neighbors::ivf_pq::detail::fp_8bit<5u COMMA true>); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu new file mode 100644 index 0000000000..4420b2534b --- /dev/null +++ b/cpp/src/neighbors/detail/ivf_pq_compute_similarity_half_half.cu @@ -0,0 +1,73 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by ivf_pq_compute_similarity_00_generate.py + * + * Make changes there and run in this directory: + * + * > python ivf_pq_compute_similarity_00_generate.py + * + */ + +#include +#include + +#define instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(OutT, LutT) \ + template auto raft::neighbors::ivf_pq::detail::compute_similarity_select( \ + const cudaDeviceProp& dev_props, \ + bool manage_local_topk, \ + int locality_hint, \ + double preferred_shmem_carveout, \ + uint32_t pq_bits, \ + uint32_t pq_dim, \ + uint32_t precomp_data_count, \ + uint32_t n_queries, \ + uint32_t n_probes, \ + uint32_t topk) \ + ->raft::neighbors::ivf_pq::detail::selected; \ + \ + template void raft::neighbors::ivf_pq::detail::compute_similarity_run( \ + raft::neighbors::ivf_pq::detail::selected s, \ + rmm::cuda_stream_view stream, \ + uint32_t n_rows, \ + uint32_t dim, \ + uint32_t n_probes, \ + uint32_t pq_dim, \ + uint32_t n_queries, \ + raft::distance::DistanceType metric, \ + raft::neighbors::ivf_pq::codebook_gen codebook_kind, \ + uint32_t topk, \ + uint32_t max_samples, \ + const float* cluster_centers, \ + const float* pq_centers, \ + const uint8_t* const* pq_dataset, \ + const uint32_t* cluster_labels, \ + const uint32_t* _chunk_indices, \ + const float* queries, \ + const uint32_t* index_list, \ + float* query_kths, \ + LutT* lut_scores, \ + OutT* _out_scores, \ + uint32_t* _out_indices); + +#define COMMA , +instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select(half, half); + +#undef COMMA + +#undef instantiate_raft_neighbors_ivf_pq_detail_compute_similarity_select diff --git a/cpp/test/ext_headers/00_generate.py b/cpp/test/ext_headers/00_generate.py index 4e719e272c..f6631144f4 100644 --- a/cpp/test/ext_headers/00_generate.py +++ b/cpp/test/ext_headers/00_generate.py @@ -55,6 +55,8 @@ "raft/neighbors/detail/selection_faiss-ext.cuh", "raft/linalg/detail/coalesced_reduction-ext.cuh", "raft/spatial/knn/detail/ball_cover/registers-ext.cuh", + "raft/neighbors/detail/ivf_flat_interleaved_scan-ext.cuh", + "raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh", ] for ext_header in ext_headers: diff --git a/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_interleaved_scan.cu b/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_interleaved_scan.cu new file mode 100644 index 0000000000..5a3a0b3f76 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_detail_ivf_flat_interleaved_scan.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_neighbors_detail_ivf_pq_compute_similarity.cu b/cpp/test/ext_headers/raft_neighbors_detail_ivf_pq_compute_similarity.cu new file mode 100644 index 0000000000..fd5ad62204 --- /dev/null +++ b/cpp/test/ext_headers/raft_neighbors_detail_ivf_pq_compute_similarity.cu @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include From 6109c12b5d92e7db4e9d3e5566622d8f02ec3bb9 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 20 Apr 2023 20:47:34 +0200 Subject: [PATCH 28/32] Isolate logger and memory pool These two include files are likely to transitively include spdlog, which increases compilation times. --- cpp/CMakeLists.txt | 2 + cpp/include/raft/core/logger-ext.hpp | 128 ++++++++ cpp/include/raft/core/logger-inl.hpp | 156 +++++++++ cpp/include/raft/core/logger-macros.hpp | 106 ++++++ cpp/include/raft/core/logger.hpp | 310 +----------------- cpp/include/raft/util/cudart_utils.hpp | 52 +-- cpp/include/raft/util/inline.hpp | 23 ++ cpp/include/raft/util/memory_pool-ext.hpp | 27 ++ cpp/include/raft/util/memory_pool-inl.hpp | 76 +++++ cpp/include/raft/util/memory_pool.hpp | 23 ++ cpp/src/core/logger.cpp | 16 + cpp/src/util/memory_pool.cpp | 17 + cpp/test/CMakeLists.txt | 4 + cpp/test/ext_headers/00_generate.py | 2 + cpp/test/ext_headers/raft_core_logger.cpp | 27 ++ .../ext_headers/raft_util_memory_pool.cpp | 27 ++ 16 files changed, 641 insertions(+), 355 deletions(-) create mode 100644 cpp/include/raft/core/logger-ext.hpp create mode 100644 cpp/include/raft/core/logger-inl.hpp create mode 100644 cpp/include/raft/core/logger-macros.hpp create mode 100644 cpp/include/raft/util/inline.hpp create mode 100644 cpp/include/raft/util/memory_pool-ext.hpp create mode 100644 cpp/include/raft/util/memory_pool-inl.hpp create mode 100644 cpp/include/raft/util/memory_pool.hpp create mode 100644 cpp/src/core/logger.cpp create mode 100644 cpp/src/util/memory_pool.cpp create mode 100644 cpp/test/ext_headers/raft_core_logger.cpp create mode 100644 cpp/test/ext_headers/raft_util_memory_pool.cpp diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index dfed15f6a4..cddfa4b38d 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -263,6 +263,7 @@ set_target_properties(raft_compiled PROPERTIES EXPORT_NAME compiled) if(RAFT_COMPILE_LIBRARY) add_library( raft_lib + src/core/logger.cpp src/distance/detail/pairwise_matrix/dispatch_canberra_double_double_double_int.cu src/distance/detail/pairwise_matrix/dispatch_canberra_float_float_float_int.cu src/distance/detail/pairwise_matrix/dispatch_correlation_double_double_double_int.cu @@ -390,6 +391,7 @@ if(RAFT_COMPILE_LIBRARY) src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu + src/util/memory_pool.cpp ) set_target_properties( raft_lib diff --git a/cpp/include/raft/core/logger-ext.hpp b/cpp/include/raft/core/logger-ext.hpp new file mode 100644 index 0000000000..69688560c7 --- /dev/null +++ b/cpp/include/raft/core/logger-ext.hpp @@ -0,0 +1,128 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include // std::unique_ptr +#include // std::string +#include // std::unordered_map + +namespace raft { + +static const std::string RAFT_NAME = "raft"; +static const std::string default_log_pattern("[%L] [%H:%M:%S.%f] %v"); + +/** + * @brief The main Logging class for raft library. + * + * This class acts as a thin wrapper over the underlying `spdlog` interface. The + * design is done in this way in order to avoid us having to also ship `spdlog` + * header files in our installation. + * + * @todo This currently only supports logging to stdout. Need to add support in + * future to add custom loggers as well [Issue #2046] + */ +class logger { + public: + // @todo setting the logger once per process with + logger(std::string const& name_ = ""); + /** + * @brief Singleton method to get the underlying logger object + * + * @return the singleton logger object + */ + static logger& get(std::string const& name = ""); + + /** + * @brief Set the logging level. + * + * Only messages with level equal or above this will be printed + * + * @param[in] level logging level + * + * @note The log level will actually be set only if the input is within the + * range [RAFT_LEVEL_TRACE, RAFT_LEVEL_OFF]. If it is not, then it'll + * be ignored. See documentation of decisiontree for how this gets used + */ + void set_level(int level); + + /** + * @brief Set the logging pattern + * + * @param[in] pattern the pattern to be set. Refer this link + * https://github.com/gabime/spdlog/wiki/3.-Custom-formatting + * to know the right syntax of this pattern + */ + void set_pattern(const std::string& pattern); + + /** + * @brief Register a callback function to be run in place of usual log call + * + * @param[in] callback the function to be run on all logged messages + */ + void set_callback(void (*callback)(int lvl, const char* msg)); + + /** + * @brief Register a flush function compatible with the registered callback + * + * @param[in] flush the function to use when flushing logs + */ + void set_flush(void (*flush)()); + + /** + * @brief Tells whether messages will be logged for the given log level + * + * @param[in] level log level to be checked for + * @return true if messages will be logged for this level, else false + */ + bool should_log_for(int level) const; + /** + * @brief Query for the current log level + * + * @return the current log level + */ + int get_level() const; + + /** + * @brief Get the current logging pattern + * @return the pattern + */ + std::string get_pattern() const; + + /** + * @brief Main logging method + * + * @param[in] level logging level of this message + * @param[in] fmt C-like format string, followed by respective params + */ + void log(int level, const char* fmt, ...); + + /** + * @brief Flush logs by calling flush on underlying logger + */ + void flush(); + + ~logger(); + + private: + logger(); + // pimpl pattern: + // https://learn.microsoft.com/en-us/cpp/cpp/pimpl-for-compile-time-encapsulation-modern-cpp?view=msvc-170 + class impl; + std::unique_ptr pimpl; + static inline std::unordered_map> log_map; +}; // class logger + +}; // namespace raft diff --git a/cpp/include/raft/core/logger-inl.hpp b/cpp/include/raft/core/logger-inl.hpp new file mode 100644 index 0000000000..85ea4baea5 --- /dev/null +++ b/cpp/include/raft/core/logger-inl.hpp @@ -0,0 +1,156 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include + +#include + +#include +#include +#include +#include +#include + +#include + +#include "logger-macros.hpp" +// The logger-ext.hpp file contains the class declaration of the logger class. +// In this case, it is okay to include the logger-ext.hpp file because it +// contains no RAFT_EXPLICIT template instantiations. +#include "logger-ext.hpp" + +#define SPDLOG_HEADER_ONLY +#include +#include +#include // RAFT_INLINE_CONDITIONAL +#include // NOLINT +#include // NOLINT + +namespace raft { + +namespace detail { + +inline std::string format(const char* fmt, va_list& vl) +{ + va_list vl_copy; + va_copy(vl_copy, vl); + int length = std::vsnprintf(nullptr, 0, fmt, vl_copy); + assert(length >= 0); + std::vector buf(length + 1); + std::vsnprintf(buf.data(), length + 1, fmt, vl); + return std::string(buf.data()); +} + +inline std::string format(const char* fmt, ...) +{ + va_list vl; + va_start(vl, fmt); + std::string str = format(fmt, vl); + va_end(vl); + return str; +} + +inline int convert_level_to_spdlog(int level) +{ + level = std::max(RAFT_LEVEL_OFF, std::min(RAFT_LEVEL_TRACE, level)); + return RAFT_LEVEL_TRACE - level; +} + +} // namespace detail + +class logger::impl { // defined privately here + // ... all private data and functions: all of these + // can now change without recompiling callers ... + public: + std::shared_ptr sink; + std::shared_ptr spdlogger; + std::string cur_pattern; + int cur_level; + + impl(std::string const& name_ = "") + : sink{std::make_shared()}, + spdlogger{std::make_shared(name_, sink)}, + cur_pattern() + { + } +}; // class logger::impl + +RAFT_INLINE_CONDITIONAL logger::logger(std::string const& name_) : pimpl(new impl(name_)) +{ + set_pattern(default_log_pattern); + set_level(RAFT_ACTIVE_LEVEL); +} + +RAFT_INLINE_CONDITIONAL logger& logger::get(std::string const& name) +{ + if (log_map.find(name) == log_map.end()) { log_map[name] = std::make_shared(name); } + return *log_map[name]; +} + +RAFT_INLINE_CONDITIONAL void logger::set_level(int level) +{ + level = raft::detail::convert_level_to_spdlog(level); + pimpl->spdlogger->set_level(static_cast(level)); +} + +RAFT_INLINE_CONDITIONAL void logger::set_pattern(const std::string& pattern) +{ + pimpl->cur_pattern = pattern; + pimpl->spdlogger->set_pattern(pattern); +} + +RAFT_INLINE_CONDITIONAL void logger::set_callback(void (*callback)(int lvl, const char* msg)) +{ + pimpl->sink->set_callback(callback); +} + +RAFT_INLINE_CONDITIONAL void logger::set_flush(void (*flush)()) { pimpl->sink->set_flush(flush); } + +RAFT_INLINE_CONDITIONAL bool logger::should_log_for(int level) const +{ + level = raft::detail::convert_level_to_spdlog(level); + auto level_e = static_cast(level); + return pimpl->spdlogger->should_log(level_e); +} + +RAFT_INLINE_CONDITIONAL int logger::get_level() const +{ + auto level_e = pimpl->spdlogger->level(); + return RAFT_LEVEL_TRACE - static_cast(level_e); +} + +RAFT_INLINE_CONDITIONAL std::string logger::get_pattern() const { return pimpl->cur_pattern; } + +RAFT_INLINE_CONDITIONAL void logger::log(int level, const char* fmt, ...) +{ + level = raft::detail::convert_level_to_spdlog(level); + auto level_e = static_cast(level); + // explicit check to make sure that we only expand messages when required + if (pimpl->spdlogger->should_log(level_e)) { + va_list vl; + va_start(vl, fmt); + auto msg = raft::detail::format(fmt, vl); + va_end(vl); + pimpl->spdlogger->log(level_e, msg); + } +} + +RAFT_INLINE_CONDITIONAL void logger::flush() { pimpl->spdlogger->flush(); } + +RAFT_INLINE_CONDITIONAL logger::~logger() {} + +}; // namespace raft diff --git a/cpp/include/raft/core/logger-macros.hpp b/cpp/include/raft/core/logger-macros.hpp new file mode 100644 index 0000000000..5ddb072067 --- /dev/null +++ b/cpp/include/raft/core/logger-macros.hpp @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +/** + * @defgroup logging levels used in raft + * + * @note exactly match the corresponding ones (but reverse in terms of value) + * in spdlog for wrapping purposes + * + * @{ + */ +#define RAFT_LEVEL_TRACE 6 +#define RAFT_LEVEL_DEBUG 5 +#define RAFT_LEVEL_INFO 4 +#define RAFT_LEVEL_WARN 3 +#define RAFT_LEVEL_ERROR 2 +#define RAFT_LEVEL_CRITICAL 1 +#define RAFT_LEVEL_OFF 0 +/** @} */ + +#if !defined(RAFT_ACTIVE_LEVEL) +#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_INFO +#endif + +/** + * @defgroup loggerMacros Helper macros for dealing with logging + * @{ + */ +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) +#define RAFT_LOG_TRACE(fmt, ...) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + ss << raft::detail::format(fmt, ##__VA_ARGS__); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_TRACE(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) +#define RAFT_LOG_TRACE_VEC(ptr, len) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + print_vector(#ptr, ptr, len, ss); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_TRACE_VEC(ptr, len) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) +#define RAFT_LOG_DEBUG(fmt, ...) \ + do { \ + std::stringstream ss; \ + ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ + ss << raft::detail::format(fmt, ##__VA_ARGS__); \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_DEBUG, ss.str().c_str()); \ + } while (0) +#else +#define RAFT_LOG_DEBUG(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_INFO) +#define RAFT_LOG_INFO(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_INFO, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_INFO(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_WARN) +#define RAFT_LOG_WARN(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_WARN, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_WARN(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_ERROR) +#define RAFT_LOG_ERROR(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_ERROR, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_ERROR(fmt, ...) void(0) +#endif + +#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_CRITICAL) +#define RAFT_LOG_CRITICAL(fmt, ...) \ + raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_CRITICAL, fmt, ##__VA_ARGS__) +#else +#define RAFT_LOG_CRITICAL(fmt, ...) void(0) +#endif +/** @} */ diff --git a/cpp/include/raft/core/logger.hpp b/cpp/include/raft/core/logger.hpp index 3984ec042a..59968ff5e5 100644 --- a/cpp/include/raft/core/logger.hpp +++ b/cpp/include/raft/core/logger.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -15,310 +15,10 @@ */ #pragma once -#ifndef __RAFT_RT_LOGGER -#define __RAFT_RT_LOGGER +#include "logger-macros.hpp" -#include +#include "logger-ext.hpp" -#include - -#include -#include -#include -#include -#include - -#include - -#define SPDLOG_HEADER_ONLY -#include -#include -#include // NOLINT -#include // NOLINT - -/** - * @defgroup logging levels used in raft - * - * @note exactly match the corresponding ones (but reverse in terms of value) - * in spdlog for wrapping purposes - * - * @{ - */ -#define RAFT_LEVEL_TRACE 6 -#define RAFT_LEVEL_DEBUG 5 -#define RAFT_LEVEL_INFO 4 -#define RAFT_LEVEL_WARN 3 -#define RAFT_LEVEL_ERROR 2 -#define RAFT_LEVEL_CRITICAL 1 -#define RAFT_LEVEL_OFF 0 -/** @} */ - -#if !defined(RAFT_ACTIVE_LEVEL) -#define RAFT_ACTIVE_LEVEL RAFT_LEVEL_INFO +#if !defined(RAFT_COMPILED) +#include "logger-inl.hpp" #endif - -namespace raft { - -static const std::string RAFT_NAME = "raft"; -static const std::string default_log_pattern("[%L] [%H:%M:%S.%f] %v"); - -namespace detail { - -/** - * @defgroup CStringFormat Expand a C-style format string - * - * @brief Expands C-style formatted string into std::string - * - * @param[in] fmt format string - * @param[in] vl respective values for each of format modifiers in the string - * - * @return the expanded `std::string` - * - * @{ - */ -inline std::string format(const char* fmt, va_list& vl) -{ - va_list vl_copy; - va_copy(vl_copy, vl); - int length = std::vsnprintf(nullptr, 0, fmt, vl_copy); - assert(length >= 0); - std::vector buf(length + 1); - std::vsnprintf(buf.data(), length + 1, fmt, vl); - return std::string(buf.data()); -} - -inline std::string format(const char* fmt, ...) -{ - va_list vl; - va_start(vl, fmt); - std::string str = format(fmt, vl); - va_end(vl); - return str; -} -/** @} */ - -inline int convert_level_to_spdlog(int level) -{ - level = std::max(RAFT_LEVEL_OFF, std::min(RAFT_LEVEL_TRACE, level)); - return RAFT_LEVEL_TRACE - level; -} - -} // namespace detail - -/** - * @brief The main Logging class for raft library. - * - * This class acts as a thin wrapper over the underlying `spdlog` interface. The - * design is done in this way in order to avoid us having to also ship `spdlog` - * header files in our installation. - * - * @todo This currently only supports logging to stdout. Need to add support in - * future to add custom loggers as well [Issue #2046] - */ -class logger { - public: - // @todo setting the logger once per process with - logger(std::string const& name_ = "") - : sink{std::make_shared()}, - spdlogger{std::make_shared(name_, sink)}, - cur_pattern() - { - set_pattern(default_log_pattern); - set_level(RAFT_ACTIVE_LEVEL); - } - /** - * @brief Singleton method to get the underlying logger object - * - * @return the singleton logger object - */ - static logger& get(std::string const& name = "") - { - if (log_map.find(name) == log_map.end()) { - log_map[name] = std::make_shared(name); - } - return *log_map[name]; - } - - /** - * @brief Set the logging level. - * - * Only messages with level equal or above this will be printed - * - * @param[in] level logging level - * - * @note The log level will actually be set only if the input is within the - * range [RAFT_LEVEL_TRACE, RAFT_LEVEL_OFF]. If it is not, then it'll - * be ignored. See documentation of decisiontree for how this gets used - */ - void set_level(int level) - { - level = raft::detail::convert_level_to_spdlog(level); - spdlogger->set_level(static_cast(level)); - } - - /** - * @brief Set the logging pattern - * - * @param[in] pattern the pattern to be set. Refer this link - * https://github.com/gabime/spdlog/wiki/3.-Custom-formatting - * to know the right syntax of this pattern - */ - void set_pattern(const std::string& pattern) - { - cur_pattern = pattern; - spdlogger->set_pattern(pattern); - } - - /** - * @brief Register a callback function to be run in place of usual log call - * - * @param[in] callback the function to be run on all logged messages - */ - void set_callback(void (*callback)(int lvl, const char* msg)) { sink->set_callback(callback); } - - /** - * @brief Register a flush function compatible with the registered callback - * - * @param[in] flush the function to use when flushing logs - */ - void set_flush(void (*flush)()) { sink->set_flush(flush); } - - /** - * @brief Tells whether messages will be logged for the given log level - * - * @param[in] level log level to be checked for - * @return true if messages will be logged for this level, else false - */ - bool should_log_for(int level) const - { - level = raft::detail::convert_level_to_spdlog(level); - auto level_e = static_cast(level); - return spdlogger->should_log(level_e); - } - - /** - * @brief Query for the current log level - * - * @return the current log level - */ - int get_level() const - { - auto level_e = spdlogger->level(); - return RAFT_LEVEL_TRACE - static_cast(level_e); - } - - /** - * @brief Get the current logging pattern - * @return the pattern - */ - std::string get_pattern() const { return cur_pattern; } - - /** - * @brief Main logging method - * - * @param[in] level logging level of this message - * @param[in] fmt C-like format string, followed by respective params - */ - void log(int level, const char* fmt, ...) - { - level = raft::detail::convert_level_to_spdlog(level); - auto level_e = static_cast(level); - // explicit check to make sure that we only expand messages when required - if (spdlogger->should_log(level_e)) { - va_list vl; - va_start(vl, fmt); - auto msg = raft::detail::format(fmt, vl); - va_end(vl); - spdlogger->log(level_e, msg); - } - } - - /** - * @brief Flush logs by calling flush on underlying logger - */ - void flush() { spdlogger->flush(); } - - ~logger() {} - - private: - logger(); - - static inline std::unordered_map> log_map; - std::shared_ptr sink; - std::shared_ptr spdlogger; - std::string cur_pattern; - int cur_level; -}; // class logger - -}; // namespace raft - -/** - * @defgroup loggerMacros Helper macros for dealing with logging - * @{ - */ -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) -#define RAFT_LOG_TRACE(fmt, ...) \ - do { \ - std::stringstream ss; \ - ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ - ss << raft::detail::format(fmt, ##__VA_ARGS__); \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ - } while (0) -#else -#define RAFT_LOG_TRACE(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_TRACE) -#define RAFT_LOG_TRACE_VEC(ptr, len) \ - do { \ - std::stringstream ss; \ - ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ - print_vector(#ptr, ptr, len, ss); \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_TRACE, ss.str().c_str()); \ - } while (0) -#else -#define RAFT_LOG_TRACE_VEC(ptr, len) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_DEBUG) -#define RAFT_LOG_DEBUG(fmt, ...) \ - do { \ - std::stringstream ss; \ - ss << raft::detail::format("%s:%d ", __FILE__, __LINE__); \ - ss << raft::detail::format(fmt, ##__VA_ARGS__); \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_DEBUG, ss.str().c_str()); \ - } while (0) -#else -#define RAFT_LOG_DEBUG(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_INFO) -#define RAFT_LOG_INFO(fmt, ...) \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_INFO, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_INFO(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_WARN) -#define RAFT_LOG_WARN(fmt, ...) \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_WARN, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_WARN(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_ERROR) -#define RAFT_LOG_ERROR(fmt, ...) \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_ERROR, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_ERROR(fmt, ...) void(0) -#endif - -#if (RAFT_ACTIVE_LEVEL >= RAFT_LEVEL_CRITICAL) -#define RAFT_LOG_CRITICAL(fmt, ...) \ - raft::logger::get(RAFT_NAME).log(RAFT_LEVEL_CRITICAL, fmt, ##__VA_ARGS__) -#else -#define RAFT_LOG_CRITICAL(fmt, ...) void(0) -#endif -/** @} */ - -#endif \ No newline at end of file diff --git a/cpp/include/raft/util/cudart_utils.hpp b/cpp/include/raft/util/cudart_utils.hpp index 1134513587..f3b083ac4a 100644 --- a/cpp/include/raft/util/cudart_utils.hpp +++ b/cpp/include/raft/util/cudart_utils.hpp @@ -18,10 +18,9 @@ #include #include +#include + #include -#include -#include -#include #include #include @@ -451,51 +450,4 @@ constexpr inline auto upper_bound() -> half return static_cast(__half_constexpr{0x7c00u}); } -/** - * @brief Get a pointer to a pooled memory resource within the scope of the lifetime of the returned - * unique pointer. - * - * This function is useful in the code where multiple repeated allocations/deallocations are - * expected. - * Use case example: - * @code{.cpp} - * void my_func(..., size_t n, rmm::mr::device_memory_resource* mr = nullptr) { - * auto pool_guard = raft::get_pool_memory_resource(mr, 2 * n * sizeof(float)); - * if (pool_guard){ - * RAFT_LOG_INFO("Created a pool %zu bytes", pool_guard->pool_size()); - * } else { - * RAFT_LOG_INFO("Using the current default or explicitly passed device memory resource"); - * } - * rmm::device_uvector x(n, stream, mr); - * rmm::device_uvector y(n, stream, mr); - * ... - * } - * @endcode - * Here, the new memory resource would be created within the function scope if the passed `mr` is - * null and the default resource is not a pool. After the call, `mr` contains a valid memory - * resource in any case. - * - * @param[inout] mr if not null do nothing; otherwise get the current device resource and wrap it - * into a `pool_memory_resource` if necessary and return the pointer to the result. - * @param initial_size if a new memory pool is created, this would be its initial size (rounded up - * to 256 bytes). - * - * @return if a new memory pool is created, it returns a unique_ptr to it; - * this managed pointer controls the lifetime of the created memory resource. - */ -inline auto get_pool_memory_resource(rmm::mr::device_memory_resource*& mr, size_t initial_size) -{ - using pool_res_t = rmm::mr::pool_memory_resource; - std::unique_ptr pool_res{}; - if (mr) return pool_res; - mr = rmm::mr::get_current_device_resource(); - if (!dynamic_cast(mr) && - !dynamic_cast*>(mr) && - !dynamic_cast*>(mr)) { - pool_res = std::make_unique(mr, (initial_size + 255) & (~255)); - mr = pool_res.get(); - } - return pool_res; -} - } // namespace raft diff --git a/cpp/include/raft/util/inline.hpp b/cpp/include/raft/util/inline.hpp new file mode 100644 index 0000000000..1b625a8a72 --- /dev/null +++ b/cpp/include/raft/util/inline.hpp @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef RAFT_COMPILED +#define RAFT_INLINE_CONDITIONAL +#else +#define RAFT_INLINE_CONDITIONAL inline +#endif // RAFT_COMPILED diff --git a/cpp/include/raft/util/memory_pool-ext.hpp b/cpp/include/raft/util/memory_pool-ext.hpp new file mode 100644 index 0000000000..a02908346b --- /dev/null +++ b/cpp/include/raft/util/memory_pool-ext.hpp @@ -0,0 +1,27 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include // size_t +#include // std::unique_ptr +#include // rmm::mr::device_memory_resource + +namespace raft { + +std::unique_ptr get_pool_memory_resource( + rmm::mr::device_memory_resource*& mr, size_t initial_size); + +} // namespace raft diff --git a/cpp/include/raft/util/memory_pool-inl.hpp b/cpp/include/raft/util/memory_pool-inl.hpp new file mode 100644 index 0000000000..7968779e3d --- /dev/null +++ b/cpp/include/raft/util/memory_pool-inl.hpp @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once +#include +#include + +#include // RAFT_INLINE_CONDITIONAL +#include +#include +#include + +namespace raft { + +/** + * @brief Get a pointer to a pooled memory resource within the scope of the lifetime of the returned + * unique pointer. + * + * This function is useful in the code where multiple repeated allocations/deallocations are + * expected. + * Use case example: + * @code{.cpp} + * void my_func(..., size_t n, rmm::mr::device_memory_resource* mr = nullptr) { + * auto pool_guard = raft::get_pool_memory_resource(mr, 2 * n * sizeof(float)); + * if (pool_guard){ + * RAFT_LOG_INFO("Created a pool %zu bytes", pool_guard->pool_size()); + * } else { + * RAFT_LOG_INFO("Using the current default or explicitly passed device memory resource"); + * } + * rmm::device_uvector x(n, stream, mr); + * rmm::device_uvector y(n, stream, mr); + * ... + * } + * @endcode + * Here, the new memory resource would be created within the function scope if the passed `mr` is + * null and the default resource is not a pool. After the call, `mr` contains a valid memory + * resource in any case. + * + * @param[inout] mr if not null do nothing; otherwise get the current device resource and wrap it + * into a `pool_memory_resource` if necessary and return the pointer to the result. + * @param initial_size if a new memory pool is created, this would be its initial size (rounded up + * to 256 bytes). + * + * @return if a new memory pool is created, it returns a unique_ptr to it; + * this managed pointer controls the lifetime of the created memory resource. + */ +RAFT_INLINE_CONDITIONAL std::unique_ptr get_pool_memory_resource( + rmm::mr::device_memory_resource*& mr, size_t initial_size) +{ + using pool_res_t = rmm::mr::pool_memory_resource; + std::unique_ptr pool_res{}; + if (mr) return pool_res; + mr = rmm::mr::get_current_device_resource(); + if (!dynamic_cast(mr) && + !dynamic_cast*>(mr) && + !dynamic_cast*>(mr)) { + pool_res = std::make_unique(mr, (initial_size + 255) & (~255)); + mr = pool_res.get(); + } + return pool_res; +} + +} // namespace raft diff --git a/cpp/include/raft/util/memory_pool.hpp b/cpp/include/raft/util/memory_pool.hpp new file mode 100644 index 0000000000..c9d25ecb1f --- /dev/null +++ b/cpp/include/raft/util/memory_pool.hpp @@ -0,0 +1,23 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "memory_pool-ext.hpp" + +#if !defined(RAFT_COMPILED) +#include "memory_pool-inl.hpp" +#endif // RAFT_COMPILED diff --git a/cpp/src/core/logger.cpp b/cpp/src/core/logger.cpp new file mode 100644 index 0000000000..8f81cf2926 --- /dev/null +++ b/cpp/src/core/logger.cpp @@ -0,0 +1,16 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include diff --git a/cpp/src/util/memory_pool.cpp b/cpp/src/util/memory_pool.cpp new file mode 100644 index 0000000000..837e870043 --- /dev/null +++ b/cpp/src/util/memory_pool.cpp @@ -0,0 +1,17 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 0772640d3f..855a0cce84 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -163,12 +163,16 @@ if(BUILD_TESTS) test/ext_headers/raft_spatial_knn_detail_fused_l2_knn.cu test/ext_headers/raft_distance_fused_l2_nn.cu test/ext_headers/raft_neighbors_ivf_pq.cu + test/ext_headers/raft_util_memory_pool.cpp test/ext_headers/raft_neighbors_ivf_flat.cu + test/ext_headers/raft_core_logger.cpp test/ext_headers/raft_neighbors_refine.cu test/ext_headers/raft_neighbors_detail_ivf_flat_search.cu test/ext_headers/raft_neighbors_detail_selection_faiss.cu test/ext_headers/raft_linalg_detail_coalesced_reduction.cu test/ext_headers/raft_spatial_knn_detail_ball_cover_registers.cu + test/ext_headers/raft_neighbors_detail_ivf_flat_interleaved_scan.cu + test/ext_headers/raft_neighbors_detail_ivf_pq_compute_similarity.cu ) # Test that the split headers compile in isolation with: diff --git a/cpp/test/ext_headers/00_generate.py b/cpp/test/ext_headers/00_generate.py index f6631144f4..15f90e1cc5 100644 --- a/cpp/test/ext_headers/00_generate.py +++ b/cpp/test/ext_headers/00_generate.py @@ -49,7 +49,9 @@ "raft/spatial/knn/detail/fused_l2_knn-ext.cuh", "raft/distance/fused_l2_nn-ext.cuh", "raft/neighbors/ivf_pq-ext.cuh", + "raft/util/memory_pool-ext.hpp", "raft/neighbors/ivf_flat-ext.cuh", + "raft/core/logger-ext.hpp", "raft/neighbors/refine-ext.cuh", "raft/neighbors/detail/ivf_flat_search-ext.cuh", "raft/neighbors/detail/selection_faiss-ext.cuh", diff --git a/cpp/test/ext_headers/raft_core_logger.cpp b/cpp/test/ext_headers/raft_core_logger.cpp new file mode 100644 index 0000000000..18ba9ef48d --- /dev/null +++ b/cpp/test/ext_headers/raft_core_logger.cpp @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include diff --git a/cpp/test/ext_headers/raft_util_memory_pool.cpp b/cpp/test/ext_headers/raft_util_memory_pool.cpp new file mode 100644 index 0000000000..11a024b958 --- /dev/null +++ b/cpp/test/ext_headers/raft_util_memory_pool.cpp @@ -0,0 +1,27 @@ + +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * NOTE: this file is generated by 00_generate.py + * + * Make changes there and run in this directory: + * + * > python 00_generate.py + * + */ + +#include From 273e3717154f49140ce0a560525cbcc956f7b740 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Wed, 26 Apr 2023 11:14:45 +0200 Subject: [PATCH 29/32] Add documentation to RAFT_INLINE_CONDITIONAL --- cpp/include/raft/util/inline.hpp | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/cpp/include/raft/util/inline.hpp b/cpp/include/raft/util/inline.hpp index 1b625a8a72..0b1e3090f2 100644 --- a/cpp/include/raft/util/inline.hpp +++ b/cpp/include/raft/util/inline.hpp @@ -16,6 +16,20 @@ #pragma once +// The RAFT_INLINE_CONDITIONAL is a conditional inline specifier that removes +// the inline specification when RAFT_COMPILED is defined. +// +// When RAFT_COMPILED is not defined, functions may be defined in multiple +// translation units and we do not want that to lead to linker errors. +// +// When RAFT_COMPILED is defined, this serves two purposes: +// +// 1. It triggers a multiple definition error message when memory_pool-inl.hpp +// (for instance) is accidentally included in multiple translation units. +// +// 2. We function definitions to be non-inline, because non-inline functions +// symbols are always exported in the object symbol table. For inline functions, +// the compiler may elide the external symbol, which results in linker errors. #ifdef RAFT_COMPILED #define RAFT_INLINE_CONDITIONAL #else From e2b5c18bbe9ee870d8654b7e98d6c2fed53fddaa Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 27 Apr 2023 11:02:54 +0200 Subject: [PATCH 30/32] logger: Do not include cudart_utils.hpp --- cpp/include/raft/core/logger-inl.hpp | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/core/logger-inl.hpp b/cpp/include/raft/core/logger-inl.hpp index 85ea4baea5..e23de66bec 100644 --- a/cpp/include/raft/core/logger-inl.hpp +++ b/cpp/include/raft/core/logger-inl.hpp @@ -35,7 +35,6 @@ #define SPDLOG_HEADER_ONLY #include -#include #include // RAFT_INLINE_CONDITIONAL #include // NOLINT #include // NOLINT From 22bade9ccb50992f6badc8f7cbc5d06e3b36a030 Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 27 Apr 2023 12:04:20 +0200 Subject: [PATCH 31/32] Move RAFT_INLINE_CONDITIONAL to macros.hpp --- cpp/include/raft/core/detail/macros.hpp | 22 +++++++++++++- cpp/include/raft/core/logger-inl.hpp | 2 +- cpp/include/raft/util/inline.hpp | 37 ----------------------- cpp/include/raft/util/memory_pool-inl.hpp | 2 +- 4 files changed, 23 insertions(+), 40 deletions(-) delete mode 100644 cpp/include/raft/util/inline.hpp diff --git a/cpp/include/raft/core/detail/macros.hpp b/cpp/include/raft/core/detail/macros.hpp index bfb47437ad..fb0d597850 100644 --- a/cpp/include/raft/core/detail/macros.hpp +++ b/cpp/include/raft/core/detail/macros.hpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2022, NVIDIA CORPORATION. + * Copyright (c) 2022-2023, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -40,6 +40,26 @@ #define RAFT_INLINE_FUNCTION _RAFT_HOST_DEVICE _RAFT_FORCEINLINE #endif +// The RAFT_INLINE_CONDITIONAL is a conditional inline specifier that removes +// the inline specification when RAFT_COMPILED is defined. +// +// When RAFT_COMPILED is not defined, functions may be defined in multiple +// translation units and we do not want that to lead to linker errors. +// +// When RAFT_COMPILED is defined, this serves two purposes: +// +// 1. It triggers a multiple definition error message when memory_pool-inl.hpp +// (for instance) is accidentally included in multiple translation units. +// +// 2. We function definitions to be non-inline, because non-inline functions +// symbols are always exported in the object symbol table. For inline functions, +// the compiler may elide the external symbol, which results in linker errors. +#ifdef RAFT_COMPILED +#define RAFT_INLINE_CONDITIONAL +#else +#define RAFT_INLINE_CONDITIONAL inline +#endif // RAFT_COMPILED + /** * Some macro magic to remove optional parentheses of a macro argument. * See https://stackoverflow.com/a/62984543 diff --git a/cpp/include/raft/core/logger-inl.hpp b/cpp/include/raft/core/logger-inl.hpp index e23de66bec..a90023d01f 100644 --- a/cpp/include/raft/core/logger-inl.hpp +++ b/cpp/include/raft/core/logger-inl.hpp @@ -35,7 +35,7 @@ #define SPDLOG_HEADER_ONLY #include -#include // RAFT_INLINE_CONDITIONAL +#include // RAFT_INLINE_CONDITIONAL #include // NOLINT #include // NOLINT diff --git a/cpp/include/raft/util/inline.hpp b/cpp/include/raft/util/inline.hpp deleted file mode 100644 index 0b1e3090f2..0000000000 --- a/cpp/include/raft/util/inline.hpp +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#pragma once - -// The RAFT_INLINE_CONDITIONAL is a conditional inline specifier that removes -// the inline specification when RAFT_COMPILED is defined. -// -// When RAFT_COMPILED is not defined, functions may be defined in multiple -// translation units and we do not want that to lead to linker errors. -// -// When RAFT_COMPILED is defined, this serves two purposes: -// -// 1. It triggers a multiple definition error message when memory_pool-inl.hpp -// (for instance) is accidentally included in multiple translation units. -// -// 2. We function definitions to be non-inline, because non-inline functions -// symbols are always exported in the object symbol table. For inline functions, -// the compiler may elide the external symbol, which results in linker errors. -#ifdef RAFT_COMPILED -#define RAFT_INLINE_CONDITIONAL -#else -#define RAFT_INLINE_CONDITIONAL inline -#endif // RAFT_COMPILED diff --git a/cpp/include/raft/util/memory_pool-inl.hpp b/cpp/include/raft/util/memory_pool-inl.hpp index 7968779e3d..a227b6e53f 100644 --- a/cpp/include/raft/util/memory_pool-inl.hpp +++ b/cpp/include/raft/util/memory_pool-inl.hpp @@ -18,7 +18,7 @@ #include #include -#include // RAFT_INLINE_CONDITIONAL +#include // RAFT_INLINE_CONDITIONAL #include #include #include From 6a0be4c3f1eaf446699f3ee033973e059b74160f Mon Sep 17 00:00:00 2001 From: Allard Hendriksen Date: Thu, 27 Apr 2023 12:05:33 +0200 Subject: [PATCH 32/32] Define RAFT_WEAK_FUNCTION --- cpp/include/raft/core/detail/macros.hpp | 14 ++++++++++++++ .../detail/ivf_pq_compute_similarity-ext.cuh | 4 +++- .../detail/ivf_pq_compute_similarity-inl.cuh | 2 +- 3 files changed, 18 insertions(+), 2 deletions(-) diff --git a/cpp/include/raft/core/detail/macros.hpp b/cpp/include/raft/core/detail/macros.hpp index fb0d597850..390acea697 100644 --- a/cpp/include/raft/core/detail/macros.hpp +++ b/cpp/include/raft/core/detail/macros.hpp @@ -60,6 +60,20 @@ #define RAFT_INLINE_CONDITIONAL inline #endif // RAFT_COMPILED +// The RAFT_WEAK_FUNCTION specificies that: +// +// 1. A function may be defined in multiple translation units (like inline) +// +// 2. Must still emit an external symbol (unlike inline). This enables declaring +// a function signature in an `-ext` header and defining it in a source file. +// +// From +// https://gcc.gnu.org/onlinedocs/gcc/Common-Function-Attributes.html#Common-Function-Attributes: +// +// "The weak attribute causes a declaration of an external symbol to be emitted +// as a weak symbol rather than a global." +#define RAFT_WEAK_FUNCTION __attribute__((weak)) + /** * Some macro magic to remove optional parentheses of a macro argument. * See https://stackoverflow.com/a/62984543 diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh index 0d5ca90297..41e9fda701 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_compute_similarity-ext.cuh @@ -17,6 +17,7 @@ #pragma once #include // __half +#include // RAFT_WEAK_FUNCTION #include // raft::distance::DistanceType #include // raft::neighbors::ivf_pq::detail::fp_8bit #include // raft::neighbors::ivf_pq::codebook_gen @@ -30,7 +31,8 @@ namespace raft::neighbors::ivf_pq::detail { // is_local_topk_feasible is not inline here, because we would have to define it // here as well. That would run the risk of the definitions here and in the // -inl.cuh header diverging. -auto is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) -> bool; +auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) + -> bool; template = 32) && !(kMaxCapacity & (kMaxCapacity - 1)), "kMaxCapacity must be a power of two, not smaller than the WarpSize."); // using weak attribute here, because it may be compiled multiple times. -auto __attribute__((weak)) is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) +auto RAFT_WEAK_FUNCTION is_local_topk_feasible(uint32_t k, uint32_t n_probes, uint32_t n_queries) -> bool { if (k > kMaxCapacity) { return false; } // warp_sort not possible