From 6b30de94f54f7b2bb0cf523d767850b6a3472e2e Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 9 Mar 2023 20:43:28 -0500 Subject: [PATCH] Using int64_t specializations for `ivf_pq` and `refine` (#1325) Since FAISS and Milvus are both using `int64_t` everywhere and we haven't explicitly encountered places where `uint64_t` is being used for ann indices, we're opting to just use `int64_t` instead. Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/1325 --- cpp/CMakeLists.txt | 48 +++---- cpp/bench/CMakeLists.txt | 10 +- ..._t_uint64_t.cu => ivf_pq_float_int64_t.cu} | 2 +- ...t_uint64_t.cu => ivf_pq_int8_t_int64_t.cu} | 2 +- ..._uint64_t.cu => ivf_pq_uint8_t_int64_t.cu} | 2 +- ...at_uint64_t.cu => refine_float_int64_t.cu} | 4 +- ..._uint64_t.cu => refine_uint8_t_int64_t.cu} | 4 +- cpp/include/raft/neighbors/ivf_pq.cuh | 1 + .../raft/neighbors/specializations/ivf_pq.cuh | 6 +- .../raft/neighbors/specializations/refine.cuh | 40 +++--- .../raft/spatial/knn/detail/ann_quantized.cuh | 2 + cpp/include/raft_runtime/neighbors/ivf_pq.hpp | 16 +-- cpp/include/raft_runtime/neighbors/refine.hpp | 36 +++--- cpp/src/distance/neighbors/ivfpq_build.cu | 6 +- .../distance/neighbors/ivfpq_deserialize.cu | 4 +- ...t64_t.cu => ivfpq_search_float_int64_t.cu} | 2 +- ...64_t.cu => ivfpq_search_int8_t_int64_t.cu} | 2 +- ...4_t.cu => ivfpq_search_uint8_t_int64_t.cu} | 2 +- cpp/src/distance/neighbors/ivfpq_serialize.cu | 2 +- ..._t_int8_t.cu => refine_d_int64_t_float.cu} | 12 +- ..._uint8_t.cu => refine_d_int64_t_int8_t.cu} | 12 +- ...t_float.cu => refine_d_int64_t_uint8_t.cu} | 13 +- .../neighbors/refine_d_uint64_t_int8_t.cu | 34 ----- .../neighbors/refine_d_uint64_t_uint8_t.cu | 34 ----- ...4_t_float.cu => refine_h_int64_t_float.cu} | 13 +- .../neighbors/refine_h_int64_t_int8_t.cu | 34 +++++ .../neighbors/refine_h_int64_t_uint8_t.cu | 34 +++++ ...nt64_t.cu => ivfpq_build_float_int64_t.cu} | 2 +- ...t64_t.cu => ivfpq_build_int8_t_int64_t.cu} | 2 +- ...64_t.cu => ivfpq_build_uint8_t_int64_t.cu} | 2 +- ...t64_t.cu => ivfpq_extend_float_int64_t.cu} | 2 +- ...64_t.cu => ivfpq_extend_int8_t_int64_t.cu} | 2 +- ...4_t.cu => ivfpq_extend_uint8_t_int64_t.cu} | 2 +- ...t64_t.cu => ivfpq_search_float_int64_t.cu} | 2 +- ...64_t.cu => ivfpq_search_int8_t_int64_t.cu} | 2 +- ...4_t.cu => ivfpq_search_uint8_t_int64_t.cu} | 2 +- ..._t_int8_t.cu => refine_d_int64_t_float.cu} | 13 +- ..._uint8_t.cu => refine_d_int64_t_int8_t.cu} | 12 +- ...t_float.cu => refine_d_int64_t_uint8_t.cu} | 12 +- .../refine_d_uint64_t_int8_t.cu | 30 ----- .../refine_d_uint64_t_uint8_t.cu | 30 ----- ...4_t_float.cu => refine_h_int64_t_float.cu} | 12 +- .../refine_h_int64_t_int8_t.cu | 29 +++++ .../refine_h_int64_t_uint8_t.cu | 30 +++++ cpp/test/CMakeLists.txt | 6 +- .../ann_ivf_pq/test_float_int64_t.cu | 5 +- ...8_t_uint64_t.cu => test_int8_t_int64_t.cu} | 8 +- ...at_uint64_t.cu => test_uint8_t_int64_t.cu} | 8 +- .../ann_ivf_pq/test_uint8_t_uint64_t.cu | 27 ---- cpp/test/neighbors/refine.cu | 20 +-- python/pylibraft/pylibraft/common/mdspan.pxd | 10 +- python/pylibraft/pylibraft/common/mdspan.pyx | 32 ++--- .../neighbors/ivf_pq/cpp/c_ivf_pq.pxd | 67 +++++----- .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 36 +++--- .../pylibraft/pylibraft/neighbors/refine.pyx | 120 +++++++++--------- .../pylibraft/pylibraft/test/test_ivf_pq.py | 12 +- .../pylibraft/pylibraft/test/test_refine.py | 8 +- 57 files changed, 441 insertions(+), 481 deletions(-) rename cpp/bench/neighbors/knn/{ivf_pq_uint8_t_uint64_t.cu => ivf_pq_float_int64_t.cu} (89%) rename cpp/bench/neighbors/knn/{ivf_pq_float_uint64_t.cu => ivf_pq_int8_t_int64_t.cu} (91%) rename cpp/bench/neighbors/knn/{ivf_pq_int8_t_uint64_t.cu => ivf_pq_uint8_t_int64_t.cu} (91%) rename cpp/bench/neighbors/{refine_float_uint64_t.cu => refine_float_int64_t.cu} (88%) rename cpp/bench/neighbors/{refine_uint8_t_uint64_t.cu => refine_uint8_t_int64_t.cu} (88%) rename cpp/src/distance/neighbors/{ivfpq_search_float_uint64_t.cu => ivfpq_search_float_int64_t.cu} (97%) rename cpp/src/distance/neighbors/{ivfpq_search_int8_t_uint64_t.cu => ivfpq_search_int8_t_int64_t.cu} (97%) rename cpp/src/distance/neighbors/{ivfpq_search_uint8_t_uint64_t.cu => ivfpq_search_uint8_t_int64_t.cu} (97%) rename cpp/src/distance/neighbors/{refine_h_uint64_t_int8_t.cu => refine_d_int64_t_float.cu} (72%) rename cpp/src/distance/neighbors/{refine_h_uint64_t_uint8_t.cu => refine_d_int64_t_int8_t.cu} (72%) rename cpp/src/distance/neighbors/{refine_h_uint64_t_float.cu => refine_d_int64_t_uint8_t.cu} (70%) delete mode 100644 cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu delete mode 100644 cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu rename cpp/src/distance/neighbors/{refine_d_uint64_t_float.cu => refine_h_int64_t_float.cu} (67%) create mode 100644 cpp/src/distance/neighbors/refine_h_int64_t_int8_t.cu create mode 100644 cpp/src/distance/neighbors/refine_h_int64_t_uint8_t.cu rename cpp/src/distance/neighbors/specializations/{ivfpq_build_float_uint64_t.cu => ivfpq_build_float_int64_t.cu} (96%) rename cpp/src/distance/neighbors/specializations/{ivfpq_build_int8_t_uint64_t.cu => ivfpq_build_int8_t_int64_t.cu} (96%) rename cpp/src/distance/neighbors/specializations/{ivfpq_build_uint8_t_uint64_t.cu => ivfpq_build_uint8_t_int64_t.cu} (96%) rename cpp/src/distance/neighbors/specializations/{ivfpq_extend_float_uint64_t.cu => ivfpq_extend_float_int64_t.cu} (98%) rename cpp/src/distance/neighbors/specializations/{ivfpq_extend_int8_t_uint64_t.cu => ivfpq_extend_int8_t_int64_t.cu} (97%) rename cpp/src/distance/neighbors/specializations/{ivfpq_extend_uint8_t_uint64_t.cu => ivfpq_extend_uint8_t_int64_t.cu} (97%) rename cpp/src/distance/neighbors/specializations/{ivfpq_search_float_uint64_t.cu => ivfpq_search_float_int64_t.cu} (97%) rename cpp/src/distance/neighbors/specializations/{ivfpq_search_int8_t_uint64_t.cu => ivfpq_search_int8_t_int64_t.cu} (97%) rename cpp/src/distance/neighbors/specializations/{ivfpq_search_uint8_t_uint64_t.cu => ivfpq_search_uint8_t_int64_t.cu} (97%) rename cpp/src/distance/neighbors/specializations/{refine_h_uint64_t_int8_t.cu => refine_d_int64_t_float.cu} (68%) rename cpp/src/distance/neighbors/specializations/{refine_h_uint64_t_uint8_t.cu => refine_d_int64_t_int8_t.cu} (68%) rename cpp/src/distance/neighbors/specializations/{refine_h_uint64_t_float.cu => refine_d_int64_t_uint8_t.cu} (67%) delete mode 100644 cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu delete mode 100644 cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu rename cpp/src/distance/neighbors/specializations/{refine_d_uint64_t_float.cu => refine_h_int64_t_float.cu} (65%) create mode 100644 cpp/src/distance/neighbors/specializations/refine_h_int64_t_int8_t.cu create mode 100644 cpp/src/distance/neighbors/specializations/refine_h_int64_t_uint8_t.cu rename cpp/test/neighbors/ann_ivf_pq/{test_int8_t_uint64_t.cu => test_int8_t_int64_t.cu} (79%) rename cpp/test/neighbors/ann_ivf_pq/{test_float_uint64_t.cu => test_uint8_t_int64_t.cu} (77%) delete mode 100644 cpp/test/neighbors/ann_ivf_pq/test_uint8_t_uint64_t.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 7e5b10b227..9768dc266c 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -299,18 +299,18 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/cluster/update_centroids_double.cu src/distance/cluster/cluster_cost_float.cu src/distance/cluster/cluster_cost_double.cu - src/distance/neighbors/refine_d_uint64_t_float.cu - src/distance/neighbors/refine_d_uint64_t_int8_t.cu - src/distance/neighbors/refine_d_uint64_t_uint8_t.cu - src/distance/neighbors/refine_h_uint64_t_float.cu - src/distance/neighbors/refine_h_uint64_t_int8_t.cu - src/distance/neighbors/refine_h_uint64_t_uint8_t.cu - src/distance/neighbors/specializations/refine_d_uint64_t_float.cu - src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu - src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu - src/distance/neighbors/specializations/refine_h_uint64_t_float.cu - src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu - src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu + src/distance/neighbors/refine_d_int64_t_float.cu + src/distance/neighbors/refine_d_int64_t_int8_t.cu + src/distance/neighbors/refine_d_int64_t_uint8_t.cu + src/distance/neighbors/refine_h_int64_t_float.cu + src/distance/neighbors/refine_h_int64_t_int8_t.cu + src/distance/neighbors/refine_h_int64_t_uint8_t.cu + src/distance/neighbors/specializations/refine_d_int64_t_float.cu + src/distance/neighbors/specializations/refine_d_int64_t_int8_t.cu + src/distance/neighbors/specializations/refine_d_int64_t_uint8_t.cu + src/distance/neighbors/specializations/refine_h_int64_t_float.cu + src/distance/neighbors/specializations/refine_h_int64_t_int8_t.cu + src/distance/neighbors/specializations/refine_h_int64_t_uint8_t.cu src/distance/cluster/kmeans_fit_float.cu src/distance/cluster/kmeans_fit_double.cu src/distance/cluster/kmeans_init_plus_plus_double.cu @@ -367,18 +367,18 @@ if(RAFT_COMPILE_DIST_LIBRARY) src/distance/neighbors/ivfpq_build.cu src/distance/neighbors/ivfpq_deserialize.cu src/distance/neighbors/ivfpq_serialize.cu - src/distance/neighbors/ivfpq_search_float_uint64_t.cu - src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu - src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu - src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu - src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu - src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu - src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu - src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu - src/distance/neighbors/specializations/ivfpq_extend_uint8_t_uint64_t.cu - src/distance/neighbors/specializations/ivfpq_search_float_uint64_t.cu - src/distance/neighbors/specializations/ivfpq_search_int8_t_uint64_t.cu - src/distance/neighbors/specializations/ivfpq_search_uint8_t_uint64_t.cu + src/distance/neighbors/ivfpq_search_float_int64_t.cu + src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu + src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu + src/distance/neighbors/specializations/ivfpq_build_float_int64_t.cu + src/distance/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu + src/distance/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu + src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu + src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu + src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu + src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu + src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu + src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu src/distance/neighbors/specializations/detail/compute_similarity_float_float_fast.cu src/distance/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu src/distance/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu diff --git a/cpp/bench/CMakeLists.txt b/cpp/bench/CMakeLists.txt index f54be94068..68020e4ed3 100644 --- a/cpp/bench/CMakeLists.txt +++ b/cpp/bench/CMakeLists.txt @@ -125,11 +125,11 @@ if(BUILD_BENCH) bench/neighbors/knn/ivf_flat_float_int64_t.cu bench/neighbors/knn/ivf_flat_int8_t_int64_t.cu bench/neighbors/knn/ivf_flat_uint8_t_int64_t.cu - bench/neighbors/knn/ivf_pq_float_uint64_t.cu - bench/neighbors/knn/ivf_pq_int8_t_uint64_t.cu - bench/neighbors/knn/ivf_pq_uint8_t_uint64_t.cu - bench/neighbors/refine_float_uint64_t.cu - bench/neighbors/refine_uint8_t_uint64_t.cu + bench/neighbors/knn/ivf_pq_float_int64_t.cu + bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu + bench/neighbors/knn/ivf_pq_uint8_t_int64_t.cu + bench/neighbors/refine_float_int64_t.cu + bench/neighbors/refine_uint8_t_int64_t.cu bench/main.cpp OPTIONAL DIST diff --git a/cpp/bench/neighbors/knn/ivf_pq_uint8_t_uint64_t.cu b/cpp/bench/neighbors/knn/ivf_pq_float_int64_t.cu similarity index 89% rename from cpp/bench/neighbors/knn/ivf_pq_uint8_t_uint64_t.cu rename to cpp/bench/neighbors/knn/ivf_pq_float_int64_t.cu index a898f0523e..83c4973c3a 100644 --- a/cpp/bench/neighbors/knn/ivf_pq_uint8_t_uint64_t.cu +++ b/cpp/bench/neighbors/knn/ivf_pq_float_int64_t.cu @@ -18,6 +18,6 @@ namespace raft::bench::spatial { -KNN_REGISTER(uint8_t, uint32_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes); +KNN_REGISTER(float, int64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes); } // namespace raft::bench::spatial diff --git a/cpp/bench/neighbors/knn/ivf_pq_float_uint64_t.cu b/cpp/bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu similarity index 91% rename from cpp/bench/neighbors/knn/ivf_pq_float_uint64_t.cu rename to cpp/bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu index d25b48a6ed..4ea281b11a 100644 --- a/cpp/bench/neighbors/knn/ivf_pq_float_uint64_t.cu +++ b/cpp/bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu @@ -18,6 +18,6 @@ namespace raft::bench::spatial { -KNN_REGISTER(float, uint64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes); +KNN_REGISTER(int8_t, int64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes); } // namespace raft::bench::spatial diff --git a/cpp/bench/neighbors/knn/ivf_pq_int8_t_uint64_t.cu b/cpp/bench/neighbors/knn/ivf_pq_uint8_t_int64_t.cu similarity index 91% rename from cpp/bench/neighbors/knn/ivf_pq_int8_t_uint64_t.cu rename to cpp/bench/neighbors/knn/ivf_pq_uint8_t_int64_t.cu index 3e38fbdae8..3313a49ba2 100644 --- a/cpp/bench/neighbors/knn/ivf_pq_int8_t_uint64_t.cu +++ b/cpp/bench/neighbors/knn/ivf_pq_uint8_t_int64_t.cu @@ -18,6 +18,6 @@ namespace raft::bench::spatial { -KNN_REGISTER(int8_t, uint64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes); +KNN_REGISTER(uint8_t, int64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes); } // namespace raft::bench::spatial diff --git a/cpp/bench/neighbors/refine_float_uint64_t.cu b/cpp/bench/neighbors/refine_float_int64_t.cu similarity index 88% rename from cpp/bench/neighbors/refine_float_uint64_t.cu rename to cpp/bench/neighbors/refine_float_int64_t.cu index f0846685ab..1ff59b5d4d 100644 --- a/cpp/bench/neighbors/refine_float_uint64_t.cu +++ b/cpp/bench/neighbors/refine_float_int64_t.cu @@ -29,6 +29,6 @@ using namespace raft::neighbors; namespace raft::bench::neighbors { -using refine_float_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs()); +using refine_float_int64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs()); } // namespace raft::bench::neighbors diff --git a/cpp/bench/neighbors/refine_uint8_t_uint64_t.cu b/cpp/bench/neighbors/refine_uint8_t_int64_t.cu similarity index 88% rename from cpp/bench/neighbors/refine_uint8_t_uint64_t.cu rename to cpp/bench/neighbors/refine_uint8_t_int64_t.cu index f6e1fe7e48..92806f84a7 100644 --- a/cpp/bench/neighbors/refine_uint8_t_uint64_t.cu +++ b/cpp/bench/neighbors/refine_uint8_t_int64_t.cu @@ -29,6 +29,6 @@ using namespace raft::neighbors; namespace raft::bench::neighbors { -using refine_uint8_int64 = RefineAnn; -RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs()); +using refine_uint8_int64 = RefineAnn; +RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs()); } // namespace raft::bench::neighbors diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index 4bb617b526..db60af847a 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -21,6 +21,7 @@ #include #include +#include #include #include diff --git a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh index 3ff99fb4da..142cc7806f 100644 --- a/cpp/include/raft/neighbors/specializations/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/specializations/ivf_pq.cuh @@ -49,9 +49,9 @@ namespace raft::neighbors::ivf_pq { IdxT*, \ float*, \ rmm::mr::device_memory_resource*); -RAFT_INST(float, uint64_t); -RAFT_INST(int8_t, uint64_t); -RAFT_INST(uint8_t, uint64_t); +RAFT_INST(float, int64_t); +RAFT_INST(int8_t, int64_t); +RAFT_INST(uint8_t, int64_t); #undef RAFT_INST diff --git a/cpp/include/raft/neighbors/specializations/refine.cuh b/cpp/include/raft/neighbors/specializations/refine.cuh index 71e83a26f3..aef4834c9f 100644 --- a/cpp/include/raft/neighbors/specializations/refine.cuh +++ b/cpp/include/raft/neighbors/specializations/refine.cuh @@ -24,28 +24,28 @@ namespace raft::neighbors { #undef RAFT_INST #endif -#define RAFT_INST(T, IdxT) \ - extern template void refine( \ - raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbor_candidates, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - distance::DistanceType metric); \ - \ - extern template void refine( \ - raft::device_resources const& handle, \ - raft::host_matrix_view dataset, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbor_candidates, \ - raft::host_matrix_view indices, \ - raft::host_matrix_view distances, \ +#define RAFT_INST(T, IdxT) \ + extern template void refine( \ + raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + distance::DistanceType metric); \ + \ + extern template void refine( \ + raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ distance::DistanceType metric); -RAFT_INST(float, uint64_t); -RAFT_INST(uint8_t, uint64_t); -RAFT_INST(int8_t, uint64_t); +RAFT_INST(float, int64_t); +RAFT_INST(uint8_t, int64_t); +RAFT_INST(int8_t, int64_t); #undef RAFT_INST } // namespace raft::neighbors diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh index 066dcaaa6b..a27e36c25d 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized.cuh @@ -29,6 +29,8 @@ #include #include +#include + #include #include diff --git a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp index e4c228effe..00a97931fb 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp @@ -29,9 +29,9 @@ namespace raft::runtime::neighbors::ivf_pq { raft::device_matrix_view, \ raft::device_matrix_view); -RAFT_INST_SEARCH(float, uint64_t); -RAFT_INST_SEARCH(int8_t, uint64_t); -RAFT_INST_SEARCH(uint8_t, uint64_t); +RAFT_INST_SEARCH(float, int64_t); +RAFT_INST_SEARCH(int8_t, int64_t); +RAFT_INST_SEARCH(uint8_t, int64_t); #undef RAFT_INST_SEARCH @@ -60,9 +60,9 @@ RAFT_INST_SEARCH(uint8_t, uint64_t); raft::device_matrix_view new_vectors, \ raft::device_matrix_view new_indices); -RAFT_INST_BUILD_EXTEND(float, uint64_t) -RAFT_INST_BUILD_EXTEND(int8_t, uint64_t) -RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t) +RAFT_INST_BUILD_EXTEND(float, int64_t) +RAFT_INST_BUILD_EXTEND(int8_t, int64_t) +RAFT_INST_BUILD_EXTEND(uint8_t, int64_t) #undef RAFT_INST_BUILD_EXTEND @@ -78,7 +78,7 @@ RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t) */ void serialize(raft::device_resources const& handle, const std::string& filename, - const raft::neighbors::ivf_pq::index& index); + const raft::neighbors::ivf_pq::index& index); /** * Load index from file. @@ -92,6 +92,6 @@ void serialize(raft::device_resources const& handle, */ void deserialize(raft::device_resources const& handle, const std::string& filename, - raft::neighbors::ivf_pq::index* index); + raft::neighbors::ivf_pq::index* index); } // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/include/raft_runtime/neighbors/refine.hpp b/cpp/include/raft_runtime/neighbors/refine.hpp index e779d17ded..0171259bbb 100644 --- a/cpp/include/raft_runtime/neighbors/refine.hpp +++ b/cpp/include/raft_runtime/neighbors/refine.hpp @@ -22,26 +22,26 @@ namespace raft::runtime::neighbors { -#define RAFT_INST_REFINE(IDX_T, DATA_T) \ - void refine(raft::device_resources const& handle, \ - raft::device_matrix_view dataset, \ - raft::device_matrix_view queries, \ - raft::device_matrix_view neighbor_candidates, \ - raft::device_matrix_view indices, \ - raft::device_matrix_view distances, \ - distance::DistanceType metric); \ - \ - void refine(raft::device_resources const& handle, \ - raft::host_matrix_view dataset, \ - raft::host_matrix_view queries, \ - raft::host_matrix_view neighbor_candidates, \ - raft::host_matrix_view indices, \ - raft::host_matrix_view distances, \ +#define RAFT_INST_REFINE(IDX_T, DATA_T) \ + void refine(raft::device_resources const& handle, \ + raft::device_matrix_view dataset, \ + raft::device_matrix_view queries, \ + raft::device_matrix_view neighbor_candidates, \ + raft::device_matrix_view indices, \ + raft::device_matrix_view distances, \ + distance::DistanceType metric); \ + \ + void refine(raft::device_resources const& handle, \ + raft::host_matrix_view dataset, \ + raft::host_matrix_view queries, \ + raft::host_matrix_view neighbor_candidates, \ + raft::host_matrix_view indices, \ + raft::host_matrix_view distances, \ distance::DistanceType metric); -RAFT_INST_REFINE(uint64_t, float); -RAFT_INST_REFINE(uint64_t, uint8_t); -RAFT_INST_REFINE(uint64_t, int8_t); +RAFT_INST_REFINE(int64_t, float); +RAFT_INST_REFINE(int64_t, uint8_t); +RAFT_INST_REFINE(int64_t, int8_t); #undef RAFT_INST_REFINE diff --git a/cpp/src/distance/neighbors/ivfpq_build.cu b/cpp/src/distance/neighbors/ivfpq_build.cu index 96ba349d1d..bfc893dbd3 100644 --- a/cpp/src/distance/neighbors/ivfpq_build.cu +++ b/cpp/src/distance/neighbors/ivfpq_build.cu @@ -52,9 +52,9 @@ namespace raft::runtime::neighbors::ivf_pq { raft::neighbors::ivf_pq::extend(handle, idx, new_vectors, new_indices); \ } -RAFT_INST_BUILD_EXTEND(float, uint64_t); -RAFT_INST_BUILD_EXTEND(int8_t, uint64_t); -RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t); +RAFT_INST_BUILD_EXTEND(float, int64_t); +RAFT_INST_BUILD_EXTEND(int8_t, int64_t); +RAFT_INST_BUILD_EXTEND(uint8_t, int64_t); #undef RAFT_INST_BUILD_EXTEND diff --git a/cpp/src/distance/neighbors/ivfpq_deserialize.cu b/cpp/src/distance/neighbors/ivfpq_deserialize.cu index 8f71e5622b..e6b9a2176e 100644 --- a/cpp/src/distance/neighbors/ivfpq_deserialize.cu +++ b/cpp/src/distance/neighbors/ivfpq_deserialize.cu @@ -21,9 +21,9 @@ namespace raft::runtime::neighbors::ivf_pq { void deserialize(raft::device_resources const& handle, const std::string& filename, - raft::neighbors::ivf_pq::index* index) + raft::neighbors::ivf_pq::index* index) { if (!index) { RAFT_FAIL("Invalid index pointer"); } - *index = raft::neighbors::ivf_pq::deserialize(handle, filename); + *index = raft::neighbors::ivf_pq::deserialize(handle, filename); }; } // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu similarity index 97% rename from cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu rename to cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu index 9bd750a2e2..c38e27f196 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_float_int64_t.cu @@ -33,7 +33,7 @@ namespace raft::runtime::neighbors::ivf_pq { handle, params, idx, queries, k, neighbors, distances); \ } -RAFT_SEARCH_INST(float, uint64_t); +RAFT_SEARCH_INST(float, int64_t); #undef RAFT_INST_SEARCH diff --git a/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu similarity index 97% rename from cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu rename to cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu index 303c7009cf..5df3f15eb6 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu @@ -33,7 +33,7 @@ namespace raft::runtime::neighbors::ivf_pq { handle, params, idx, queries, k, neighbors, distances); \ } -RAFT_SEARCH_INST(int8_t, uint64_t); +RAFT_SEARCH_INST(int8_t, int64_t); #undef RAFT_INST_SEARCH diff --git a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu similarity index 97% rename from cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu rename to cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu index c057abd22e..0293d05246 100644 --- a/cpp/src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu @@ -33,7 +33,7 @@ namespace raft::runtime::neighbors::ivf_pq { handle, params, idx, queries, k, neighbors, distances); \ } -RAFT_SEARCH_INST(uint8_t, uint64_t); +RAFT_SEARCH_INST(uint8_t, int64_t); #undef RAFT_INST_SEARCH diff --git a/cpp/src/distance/neighbors/ivfpq_serialize.cu b/cpp/src/distance/neighbors/ivfpq_serialize.cu index b7ceb9150a..711240ec04 100644 --- a/cpp/src/distance/neighbors/ivfpq_serialize.cu +++ b/cpp/src/distance/neighbors/ivfpq_serialize.cu @@ -21,7 +21,7 @@ namespace raft::runtime::neighbors::ivf_pq { void serialize(raft::device_resources const& handle, const std::string& filename, - const raft::neighbors::ivf_pq::index& index) + const raft::neighbors::ivf_pq::index& index) { raft::neighbors::ivf_pq::serialize(handle, filename, index); }; diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_d_int64_t_float.cu similarity index 72% rename from cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu rename to cpp/src/distance/neighbors/refine_d_int64_t_float.cu index cf6d7a397a..101c756987 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_int8_t.cu +++ b/cpp/src/distance/neighbors/refine_d_int64_t_float.cu @@ -20,14 +20,14 @@ namespace raft::runtime::neighbors { void refine(raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, distance::DistanceType metric) { - raft::neighbors::refine( + raft::neighbors::refine( handle, dataset, queries, neighbor_candidates, indices, distances, metric); } diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_d_int64_t_int8_t.cu similarity index 72% rename from cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu rename to cpp/src/distance/neighbors/refine_d_int64_t_int8_t.cu index e9c4345e97..57cbc8a454 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_uint8_t.cu +++ b/cpp/src/distance/neighbors/refine_d_int64_t_int8_t.cu @@ -20,14 +20,14 @@ namespace raft::runtime::neighbors { void refine(raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, distance::DistanceType metric) { - raft::neighbors::refine( + raft::neighbors::refine( handle, dataset, queries, neighbor_candidates, indices, distances, metric); } diff --git a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_d_int64_t_uint8_t.cu similarity index 70% rename from cpp/src/distance/neighbors/refine_h_uint64_t_float.cu rename to cpp/src/distance/neighbors/refine_d_int64_t_uint8_t.cu index 8549d65dc5..1f9d93cd35 100644 --- a/cpp/src/distance/neighbors/refine_h_uint64_t_float.cu +++ b/cpp/src/distance/neighbors/refine_d_int64_t_uint8_t.cu @@ -1,4 +1,3 @@ - /* * Copyright (c) 2022-2023, NVIDIA CORPORATION. * @@ -21,14 +20,14 @@ namespace raft::runtime::neighbors { void refine(raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, distance::DistanceType metric) { - raft::neighbors::refine( + raft::neighbors::refine( handle, dataset, queries, neighbor_candidates, indices, distances, metric); } diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu deleted file mode 100644 index 3db07f0cdb..0000000000 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_int8_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::runtime::neighbors { - -void refine(raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric) -{ - raft::neighbors::refine( - handle, dataset, queries, neighbor_candidates, indices, distances, metric); -} - -} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu deleted file mode 100644 index 2ce43d5800..0000000000 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_uint8_t.cu +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include - -namespace raft::runtime::neighbors { - -void refine(raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric) -{ - raft::neighbors::refine( - handle, dataset, queries, neighbor_candidates, indices, distances, metric); -} - -} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/refine_h_int64_t_float.cu similarity index 67% rename from cpp/src/distance/neighbors/refine_d_uint64_t_float.cu rename to cpp/src/distance/neighbors/refine_h_int64_t_float.cu index d7b460180a..9d6f34312a 100644 --- a/cpp/src/distance/neighbors/refine_d_uint64_t_float.cu +++ b/cpp/src/distance/neighbors/refine_h_int64_t_float.cu @@ -1,3 +1,4 @@ + /* * Copyright (c) 2022-2023, NVIDIA CORPORATION. * @@ -20,14 +21,14 @@ namespace raft::runtime::neighbors { void refine(raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, distance::DistanceType metric) { - raft::neighbors::refine( + raft::neighbors::refine( handle, dataset, queries, neighbor_candidates, indices, distances, metric); } diff --git a/cpp/src/distance/neighbors/refine_h_int64_t_int8_t.cu b/cpp/src/distance/neighbors/refine_h_int64_t_int8_t.cu new file mode 100644 index 0000000000..8757b15956 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_int64_t_int8_t.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/refine_h_int64_t_uint8_t.cu b/cpp/src/distance/neighbors/refine_h_int64_t_uint8_t.cu new file mode 100644 index 0000000000..3d6c8aa201 --- /dev/null +++ b/cpp/src/distance/neighbors/refine_h_int64_t_uint8_t.cu @@ -0,0 +1,34 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +namespace raft::runtime::neighbors { + +void refine(raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric) +{ + raft::neighbors::refine( + handle, dataset, queries, neighbor_candidates, indices, distances, metric); +} + +} // namespace raft::runtime::neighbors diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_float_int64_t.cu similarity index 96% rename from cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu rename to cpp/src/distance/neighbors/specializations/ivfpq_build_float_int64_t.cu index 9563ea8a88..b1ebf65c6a 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_float_int64_t.cu @@ -24,7 +24,7 @@ namespace raft::neighbors::ivf_pq { raft::device_matrix_view dataset) \ ->index; -RAFT_MAKE_INSTANCE(float, uint64_t); +RAFT_MAKE_INSTANCE(float, int64_t); #undef RAFT_MAKE_INSTANCE diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu similarity index 96% rename from cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu rename to cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu index 40c84d2a73..8c6adcf7d5 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu @@ -24,7 +24,7 @@ namespace raft::neighbors::ivf_pq { raft::device_matrix_view dataset) \ ->index; -RAFT_MAKE_INSTANCE(int8_t, uint64_t); +RAFT_MAKE_INSTANCE(int8_t, int64_t); #undef RAFT_MAKE_INSTANCE diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu similarity index 96% rename from cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu rename to cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu index 8d406542e8..19b6ea3705 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu @@ -24,7 +24,7 @@ namespace raft::neighbors::ivf_pq { raft::device_matrix_view dataset) \ ->index; -RAFT_MAKE_INSTANCE(uint8_t, uint64_t); +RAFT_MAKE_INSTANCE(uint8_t, int64_t); #undef RAFT_MAKE_INSTANCE diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu similarity index 98% rename from cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu rename to cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu index 3a0690a2f1..b416589520 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu @@ -30,7 +30,7 @@ namespace raft::neighbors::ivf_pq { raft::device_matrix_view new_vectors, \ raft::device_matrix_view new_indices); -RAFT_MAKE_INSTANCE(float, uint64_t); +RAFT_MAKE_INSTANCE(float, int64_t); #undef RAFT_MAKE_INSTANCE diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu similarity index 97% rename from cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu rename to cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu index 83cb2d14e9..d4907b5c60 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu @@ -30,7 +30,7 @@ namespace raft::neighbors::ivf_pq { raft::device_matrix_view new_vectors, \ raft::device_matrix_view new_indices); -RAFT_MAKE_INSTANCE(int8_t, uint64_t); +RAFT_MAKE_INSTANCE(int8_t, int64_t); #undef RAFT_MAKE_INSTANCE diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu similarity index 97% rename from cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_uint64_t.cu rename to cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu index 0b218dbc6f..64a206b54d 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu @@ -30,7 +30,7 @@ namespace raft::neighbors::ivf_pq { raft::device_matrix_view new_vectors, \ raft::device_matrix_view new_indices); -RAFT_MAKE_INSTANCE(uint8_t, uint64_t); +RAFT_MAKE_INSTANCE(uint8_t, int64_t); #undef RAFT_MAKE_INSTANCE diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_float_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu similarity index 97% rename from cpp/src/distance/neighbors/specializations/ivfpq_search_float_uint64_t.cu rename to cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu index f28e854554..f530961c0d 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_float_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu @@ -27,7 +27,7 @@ namespace raft::neighbors::ivf_pq { raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); -RAFT_MAKE_INSTANCE(float, uint64_t); +RAFT_MAKE_INSTANCE(float, int64_t); #undef RAFT_MAKE_INSTANCE diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu similarity index 97% rename from cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_uint64_t.cu rename to cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu index 230001df75..249979763a 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu @@ -27,7 +27,7 @@ namespace raft::neighbors::ivf_pq { raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); -RAFT_MAKE_INSTANCE(int8_t, uint64_t); +RAFT_MAKE_INSTANCE(int8_t, int64_t); #undef RAFT_MAKE_INSTANCE diff --git a/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_uint64_t.cu b/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu similarity index 97% rename from cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_uint64_t.cu rename to cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu index c6ff5097dc..b63ac43a6b 100644 --- a/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_uint64_t.cu +++ b/cpp/src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu @@ -27,7 +27,7 @@ namespace raft::neighbors::ivf_pq { raft::device_matrix_view neighbors, \ raft::device_matrix_view distances); -RAFT_MAKE_INSTANCE(uint8_t, uint64_t); +RAFT_MAKE_INSTANCE(uint8_t, int64_t); #undef RAFT_MAKE_INSTANCE diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/specializations/refine_d_int64_t_float.cu similarity index 68% rename from cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu rename to cpp/src/distance/neighbors/specializations/refine_d_int64_t_float.cu index c8b0e4c1c2..a40f428291 100644 --- a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu +++ b/cpp/src/distance/neighbors/specializations/refine_d_int64_t_float.cu @@ -17,13 +17,14 @@ #include namespace raft::neighbors { -template void refine( + +template void refine( raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, distance::DistanceType metric); } // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/specializations/refine_d_int64_t_int8_t.cu similarity index 68% rename from cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu rename to cpp/src/distance/neighbors/specializations/refine_d_int64_t_int8_t.cu index b9e0f58ef6..26d3a7b455 100644 --- a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu +++ b/cpp/src/distance/neighbors/specializations/refine_d_int64_t_int8_t.cu @@ -18,13 +18,13 @@ namespace raft::neighbors { -template void refine( +template void refine( raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, distance::DistanceType metric); } // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu b/cpp/src/distance/neighbors/specializations/refine_d_int64_t_uint8_t.cu similarity index 67% rename from cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu rename to cpp/src/distance/neighbors/specializations/refine_d_int64_t_uint8_t.cu index b473924741..9080eb2297 100644 --- a/cpp/src/distance/neighbors/specializations/refine_h_uint64_t_float.cu +++ b/cpp/src/distance/neighbors/specializations/refine_d_int64_t_uint8_t.cu @@ -18,13 +18,13 @@ namespace raft::neighbors { -template void refine( +template void refine( raft::device_resources const& handle, - raft::host_matrix_view dataset, - raft::host_matrix_view queries, - raft::host_matrix_view neighbor_candidates, - raft::host_matrix_view indices, - raft::host_matrix_view distances, + raft::device_matrix_view dataset, + raft::device_matrix_view queries, + raft::device_matrix_view neighbor_candidates, + raft::device_matrix_view indices, + raft::device_matrix_view distances, distance::DistanceType metric); } // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu deleted file mode 100644 index 7e70ee5e29..0000000000 --- a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu b/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu deleted file mode 100644 index 53de106ef9..0000000000 --- a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu +++ /dev/null @@ -1,30 +0,0 @@ -/* - * Copyright (c) 2022-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include - -namespace raft::neighbors { - -template void refine( - raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, - distance::DistanceType metric); - -} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu b/cpp/src/distance/neighbors/specializations/refine_h_int64_t_float.cu similarity index 65% rename from cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu rename to cpp/src/distance/neighbors/specializations/refine_h_int64_t_float.cu index 6bb1985d94..03e387f543 100644 --- a/cpp/src/distance/neighbors/specializations/refine_d_uint64_t_float.cu +++ b/cpp/src/distance/neighbors/specializations/refine_h_int64_t_float.cu @@ -18,13 +18,13 @@ namespace raft::neighbors { -template void refine( +template void refine( raft::device_resources const& handle, - raft::device_matrix_view dataset, - raft::device_matrix_view queries, - raft::device_matrix_view neighbor_candidates, - raft::device_matrix_view indices, - raft::device_matrix_view distances, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, distance::DistanceType metric); } // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_int64_t_int8_t.cu b/cpp/src/distance/neighbors/specializations/refine_h_int64_t_int8_t.cu new file mode 100644 index 0000000000..045532e6a1 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_int64_t_int8_t.cu @@ -0,0 +1,29 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors { +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/src/distance/neighbors/specializations/refine_h_int64_t_uint8_t.cu b/cpp/src/distance/neighbors/specializations/refine_h_int64_t_uint8_t.cu new file mode 100644 index 0000000000..e33dae8ce8 --- /dev/null +++ b/cpp/src/distance/neighbors/specializations/refine_h_int64_t_uint8_t.cu @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2022-2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include + +namespace raft::neighbors { + +template void refine( + raft::device_resources const& handle, + raft::host_matrix_view dataset, + raft::host_matrix_view queries, + raft::host_matrix_view neighbor_candidates, + raft::host_matrix_view indices, + raft::host_matrix_view distances, + distance::DistanceType metric); + +} // namespace raft::neighbors diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 26ec8ebf74..df8fa54116 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -259,9 +259,9 @@ if(BUILD_TESTS) test/neighbors/ann_ivf_flat/test_uint8_t_int64_t.cu test/neighbors/ann_ivf_pq/test_float_int64_t.cu test/neighbors/ann_ivf_pq/test_float_uint32_t.cu - test/neighbors/ann_ivf_pq/test_float_uint64_t.cu - test/neighbors/ann_ivf_pq/test_int8_t_uint64_t.cu - test/neighbors/ann_ivf_pq/test_uint8_t_uint64_t.cu + test/neighbors/ann_ivf_pq/test_float_int64_t.cu + test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu + test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu test/neighbors/knn.cu test/neighbors/fused_l2_knn.cu test/neighbors/haversine.cu diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu index db42b1ee6a..9859061d70 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_float_int64_t.cu @@ -20,9 +20,8 @@ namespace raft::neighbors::ivf_pq { using f32_f32_i64 = ivf_pq_test; -TEST_BUILD_SEARCH(f32_f32_i64) TEST_BUILD_EXTEND_SEARCH(f32_f32_i64) -INSTANTIATE(f32_f32_i64, - enum_variety_l2() + enum_variety_ip() + big_dims_small_lut() + enum_variety_l2sqrt()); +TEST_BUILD_SERIALIZE_SEARCH(f32_f32_i64) +INSTANTIATE(f32_f32_i64, defaults() + small_dims() + big_dims_moderate_lut()); } // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_int8_t_uint64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu similarity index 79% rename from cpp/test/neighbors/ann_ivf_pq/test_int8_t_uint64_t.cu rename to cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu index 03514cc3c3..014e96a2db 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_int8_t_uint64_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_int8_t_int64_t.cu @@ -18,10 +18,10 @@ namespace raft::neighbors::ivf_pq { -using f32_i08_u64 = ivf_pq_test; +using f32_i08_i64 = ivf_pq_test; -TEST_BUILD_SEARCH(f32_i08_u64) -TEST_BUILD_SERIALIZE_SEARCH(f32_i08_u64) -INSTANTIATE(f32_i08_u64, defaults() + big_dims() + var_k()); +TEST_BUILD_SEARCH(f32_i08_i64) +TEST_BUILD_SERIALIZE_SEARCH(f32_i08_i64) +INSTANTIATE(f32_i08_i64, defaults() + big_dims() + var_k()); } // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_float_uint64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu similarity index 77% rename from cpp/test/neighbors/ann_ivf_pq/test_float_uint64_t.cu rename to cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu index 219ebfb790..e949c2f7ed 100644 --- a/cpp/test/neighbors/ann_ivf_pq/test_float_uint64_t.cu +++ b/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_int64_t.cu @@ -18,10 +18,10 @@ namespace raft::neighbors::ivf_pq { -using f32_f32_u64 = ivf_pq_test; +using f32_u08_i64 = ivf_pq_test; -TEST_BUILD_EXTEND_SEARCH(f32_f32_u64) -TEST_BUILD_SERIALIZE_SEARCH(f32_f32_u64) -INSTANTIATE(f32_f32_u64, defaults() + small_dims() + big_dims_moderate_lut()); +TEST_BUILD_SEARCH(f32_u08_i64) +TEST_BUILD_EXTEND_SEARCH(f32_u08_i64) +INSTANTIATE(f32_u08_i64, small_dims_per_cluster() + enum_variety()); } // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_uint64_t.cu b/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_uint64_t.cu deleted file mode 100644 index 729e99d22c..0000000000 --- a/cpp/test/neighbors/ann_ivf_pq/test_uint8_t_uint64_t.cu +++ /dev/null @@ -1,27 +0,0 @@ -/* - * Copyright (c) 2022, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "../ann_ivf_pq.cuh" - -namespace raft::neighbors::ivf_pq { - -using f32_u08_u64 = ivf_pq_test; - -TEST_BUILD_SEARCH(f32_u08_u64) -TEST_BUILD_EXTEND_SEARCH(f32_u08_u64) -INSTANTIATE(f32_u08_u64, small_dims_per_cluster() + enum_variety()); - -} // namespace raft::neighbors::ivf_pq diff --git a/cpp/test/neighbors/refine.cu b/cpp/test/neighbors/refine.cu index 174dce5a7f..f0e85e0428 100644 --- a/cpp/test/neighbors/refine.cu +++ b/cpp/test/neighbors/refine.cu @@ -107,26 +107,26 @@ class RefineTest : public ::testing::TestWithParam> { RefineHelper data; }; -const std::vector> inputs = - raft::util::itertools::product>( - {static_cast(137)}, - {static_cast(1000)}, - {static_cast(16)}, - {static_cast(1), static_cast(10), static_cast(33)}, - {static_cast(33)}, +const std::vector> inputs = + raft::util::itertools::product>( + {static_cast(137)}, + {static_cast(1000)}, + {static_cast(16)}, + {static_cast(1), static_cast(10), static_cast(33)}, + {static_cast(33)}, {raft::distance::DistanceType::L2Expanded, raft::distance::DistanceType::InnerProduct}, {false, true}); -typedef RefineTest RefineTestF; +typedef RefineTest RefineTestF; TEST_P(RefineTestF, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF, ::testing::ValuesIn(inputs)); -typedef RefineTest RefineTestF_uint8; +typedef RefineTest RefineTestF_uint8; TEST_P(RefineTestF_uint8, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_uint8, ::testing::ValuesIn(inputs)); -typedef RefineTest RefineTestF_int8; +typedef RefineTest RefineTestF_int8; TEST_P(RefineTestF_int8, AnnRefine) { this->testRefine(); } INSTANTIATE_TEST_CASE_P(RefineTest, RefineTestF_int8, ::testing::ValuesIn(inputs)); } // namespace raft::neighbors diff --git a/python/pylibraft/pylibraft/common/mdspan.pxd b/python/pylibraft/pylibraft/common/mdspan.pxd index 2a0bdaca62..98521e48fa 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pxd +++ b/python/pylibraft/pylibraft/common/mdspan.pxd @@ -19,21 +19,21 @@ # cython: embedsignature = True # cython: language_level = 3 -from libc.stdint cimport int8_t, uint8_t, uint64_t +from libc.stdint cimport int8_t, int64_t, uint8_t from libcpp.string cimport string from pylibraft.common.cpp.mdspan cimport device_matrix_view, row_major from pylibraft.common.handle cimport device_resources -cdef device_matrix_view[float, uint64_t, row_major] get_dmv_float( +cdef device_matrix_view[float, int64_t, row_major] get_dmv_float( array, check_shape) except * -cdef device_matrix_view[uint8_t, uint64_t, row_major] get_dmv_uint8( +cdef device_matrix_view[uint8_t, int64_t, row_major] get_dmv_uint8( array, check_shape) except * -cdef device_matrix_view[int8_t, uint64_t, row_major] get_dmv_int8( +cdef device_matrix_view[int8_t, int64_t, row_major] get_dmv_int8( array, check_shape) except * -cdef device_matrix_view[uint64_t, uint64_t, row_major] get_dmv_uint64( +cdef device_matrix_view[int64_t, int64_t, row_major] get_dmv_int64( array, check_shape) except * diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index 22afda043d..9f04545a0f 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -25,15 +25,7 @@ import numpy as np from cpython.object cimport PyObject from cython.operator cimport dereference as deref from libc.stddef cimport size_t -from libc.stdint cimport ( - int8_t, - int32_t, - int64_t, - uint8_t, - uint32_t, - uint64_t, - uintptr_t, -) +from libc.stdint cimport int8_t, int32_t, int64_t, uint8_t, uint32_t, uintptr_t from pylibraft.common.cpp.mdspan cimport ( col_major, @@ -156,45 +148,45 @@ def run_roundtrip_test_for_mdspan(X, fortran_order=False): assert np.all(X == X2) -cdef device_matrix_view[float, uint64_t, row_major] \ +cdef device_matrix_view[float, int64_t, row_major] \ get_dmv_float(cai, check_shape) except *: if cai.dtype != np.float32: raise TypeError("dtype %s not supported" % cai.dtype) if check_shape and len(cai.shape) != 2: raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) - return make_device_matrix_view[float, uint64_t, row_major]( + return make_device_matrix_view[float, int64_t, row_major]( cai.data, shape[0], shape[1]) -cdef device_matrix_view[uint8_t, uint64_t, row_major] \ +cdef device_matrix_view[uint8_t, int64_t, row_major] \ get_dmv_uint8(cai, check_shape) except *: if cai.dtype != np.uint8: raise TypeError("dtype %s not supported" % cai.dtype) if check_shape and len(cai.shape) != 2: raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) - return make_device_matrix_view[uint8_t, uint64_t, row_major]( + return make_device_matrix_view[uint8_t, int64_t, row_major]( cai.data, shape[0], shape[1]) -cdef device_matrix_view[int8_t, uint64_t, row_major] \ +cdef device_matrix_view[int8_t, int64_t, row_major] \ get_dmv_int8(cai, check_shape) except *: if cai.dtype != np.int8: raise TypeError("dtype %s not supported" % cai.dtype) if check_shape and len(cai.shape) != 2: raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) - return make_device_matrix_view[int8_t, uint64_t, row_major]( + return make_device_matrix_view[int8_t, int64_t, row_major]( cai.data, shape[0], shape[1]) -cdef device_matrix_view[uint64_t, uint64_t, row_major] \ - get_dmv_uint64(cai, check_shape) except *: - if cai.dtype != np.uint64: +cdef device_matrix_view[int64_t, int64_t, row_major] \ + get_dmv_int64(cai, check_shape) except *: + if cai.dtype != np.int64: raise TypeError("dtype %s not supported" % cai.dtype) if check_shape and len(cai.shape) != 2: raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) - return make_device_matrix_view[uint64_t, uint64_t, row_major]( - cai.data, shape[0], shape[1]) + return make_device_matrix_view[int64_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd index ca35f5b8ca..dcc0371421 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd @@ -23,14 +23,7 @@ import numpy as np import pylibraft.common.handle from cython.operator cimport dereference as deref -from libc.stdint cimport ( - int8_t, - int64_t, - uint8_t, - uint32_t, - uint64_t, - uintptr_t, -) +from libc.stdint cimport int8_t, int64_t, uint8_t, uint32_t, uintptr_t from libcpp cimport bool, nullptr from libcpp.string cimport string @@ -114,70 +107,70 @@ cdef extern from "raft_runtime/neighbors/ivf_pq.hpp" \ cdef void build( const device_resources& handle, const index_params& params, - device_matrix_view[float, uint64_t, row_major] dataset, - index[uint64_t]* index) except + + device_matrix_view[float, int64_t, row_major] dataset, + index[int64_t]* index) except + cdef void build( const device_resources& handle, const index_params& params, - device_matrix_view[int8_t, uint64_t, row_major] dataset, - index[uint64_t]* index) except + + device_matrix_view[int8_t, int64_t, row_major] dataset, + index[int64_t]* index) except + cdef void build( const device_resources& handle, const index_params& params, - device_matrix_view[uint8_t, uint64_t, row_major] dataset, - index[uint64_t]* index) except + + device_matrix_view[uint8_t, int64_t, row_major] dataset, + index[int64_t]* index) except + cdef void extend( const device_resources& handle, - index[uint64_t]* index, - device_matrix_view[float, uint64_t, row_major] new_vectors, - device_matrix_view[uint64_t, uint64_t, row_major] new_indices) except + + index[int64_t]* index, + device_matrix_view[float, int64_t, row_major] new_vectors, + device_matrix_view[int64_t, int64_t, row_major] new_indices) except + cdef void extend( const device_resources& handle, - index[uint64_t]* index, - device_matrix_view[int8_t, uint64_t, row_major] new_vectors, - device_matrix_view[uint64_t, uint64_t, row_major] new_indices) except + + index[int64_t]* index, + device_matrix_view[int8_t, int64_t, row_major] new_vectors, + device_matrix_view[int64_t, int64_t, row_major] new_indices) except + cdef void extend( const device_resources& handle, - index[uint64_t]* index, - device_matrix_view[uint8_t, uint64_t, row_major] new_vectors, - device_matrix_view[uint64_t, uint64_t, row_major] new_indices) except + + index[int64_t]* index, + device_matrix_view[uint8_t, int64_t, row_major] new_vectors, + device_matrix_view[int64_t, int64_t, row_major] new_indices) except + cdef void search( const device_resources& handle, const search_params& params, - const index[uint64_t]& index, - device_matrix_view[float, uint64_t, row_major] queries, + const index[int64_t]& index, + device_matrix_view[float, int64_t, row_major] queries, uint32_t k, - device_matrix_view[uint64_t, uint64_t, row_major] neighbors, - device_matrix_view[float, uint64_t, row_major] distances) except + + device_matrix_view[int64_t, int64_t, row_major] neighbors, + device_matrix_view[float, int64_t, row_major] distances) except + cdef void search( const device_resources& handle, const search_params& params, - const index[uint64_t]& index, - device_matrix_view[int8_t, uint64_t, row_major] queries, + const index[int64_t]& index, + device_matrix_view[int8_t, int64_t, row_major] queries, uint32_t k, - device_matrix_view[uint64_t, uint64_t, row_major] neighbors, - device_matrix_view[float, uint64_t, row_major] distances) except + + device_matrix_view[int64_t, int64_t, row_major] neighbors, + device_matrix_view[float, int64_t, row_major] distances) except + cdef void search( const device_resources& handle, const search_params& params, - const index[uint64_t]& index, - device_matrix_view[uint8_t, uint64_t, row_major] queries, + const index[int64_t]& index, + device_matrix_view[uint8_t, int64_t, row_major] queries, uint32_t k, - device_matrix_view[uint64_t, uint64_t, row_major] neighbors, - device_matrix_view[float, uint64_t, row_major] distances) except + + device_matrix_view[int64_t, int64_t, row_major] neighbors, + device_matrix_view[float, int64_t, row_major] distances) except + cdef void serialize(const device_resources& handle, const string& filename, - const index[uint64_t]& index) except + + const index[int64_t]& index) except + cdef void deserialize(const device_resources& handle, const string& filename, - index[uint64_t]* index) except + + index[int64_t]* index) except + diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index 47d8e94e5f..860a1ea27c 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -23,7 +23,7 @@ import warnings import numpy as np from cython.operator cimport dereference as deref -from libc.stdint cimport int32_t, int64_t, uint32_t, uint64_t, uintptr_t +from libc.stdint cimport int32_t, int64_t, uint32_t, uintptr_t from libcpp cimport bool, nullptr from libcpp.string cimport string @@ -54,8 +54,8 @@ from pylibraft.common.cpp.mdspan cimport device_matrix_view from pylibraft.common.mdspan cimport ( get_dmv_float, get_dmv_int8, + get_dmv_int64, get_dmv_uint8, - get_dmv_uint64, ) from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( index_params, @@ -259,7 +259,7 @@ cdef class IndexParams: cdef class Index: # We store a pointer to the index because it dose not have a trivial # constructor. - cdef c_ivf_pq.index[uint64_t] * index + cdef c_ivf_pq.index[int64_t] * index cdef readonly bool trained def __cinit__(self, handle=None): @@ -272,7 +272,7 @@ cdef class Index: # We create a placeholder object. The actual parameter values do # not matter, it will be replaced with a built index object later. - self.index = new c_ivf_pq.index[uint64_t]( + self.index = new c_ivf_pq.index[int64_t]( deref(handle_), _get_metric("sqeuclidean"), c_ivf_pq.codebook_gen.PER_SUBSPACE, 1, @@ -396,7 +396,7 @@ def build(IndexParams index_params, dataset, handle=None): _check_input_array(dataset_cai, [np.dtype('float32'), np.dtype('byte'), np.dtype('ubyte')]) - cdef uint64_t n_rows = dataset_cai.shape[0] + cdef int64_t n_rows = dataset_cai.shape[0] cdef uint32_t dim = dataset_cai.shape[1] if handle is None: @@ -449,7 +449,7 @@ def extend(Index index, new_vectors, new_indices, handle=None): new_vectors : array interface compliant matrix shape (n_samples, dim) Supported dtype [float, int8, uint8] new_indices : array interface compliant matrix shape (n_samples, dim) - Supported dtype [uint64] + Supported dtype [int64] {handle_docstring} Returns @@ -476,7 +476,7 @@ def extend(Index index, new_vectors, new_indices, handle=None): >>> n_rows = 100 >>> more_data = cp.random.random_sample((n_rows, n_features), ... dtype=cp.float32) - >>> indices = index.size + cp.arange(n_rows, dtype=cp.uint64) + >>> indices = index.size + cp.arange(n_rows, dtype=cp.int64) >>> index = ivf_pq.extend(index, more_data, indices) >>> # Search using the built index @@ -504,7 +504,7 @@ def extend(Index index, new_vectors, new_indices, handle=None): vecs_cai = wrap_array(new_vectors) vecs_dt = vecs_cai.dtype - cdef uint64_t n_rows = vecs_cai.shape[0] + cdef int64_t n_rows = vecs_cai.shape[0] cdef uint32_t dim = vecs_cai.shape[1] _check_input_array(vecs_cai, [np.dtype('float32'), np.dtype('byte'), @@ -512,7 +512,7 @@ def extend(Index index, new_vectors, new_indices, handle=None): exp_cols=index.dim) idx_cai = wrap_array(new_indices) - _check_input_array(idx_cai, [np.dtype('uint64')], exp_rows=n_rows) + _check_input_array(idx_cai, [np.dtype('int64')], exp_rows=n_rows) if len(idx_cai.shape)!=1: raise ValueError("Indices array is expected to be 1D") @@ -521,19 +521,19 @@ def extend(Index index, new_vectors, new_indices, handle=None): c_ivf_pq.extend(deref(handle_), index.index, get_dmv_float(vecs_cai, check_shape=True), - get_dmv_uint64(idx_cai, check_shape=False)) + get_dmv_int64(idx_cai, check_shape=False)) elif vecs_dt == np.int8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), index.index, get_dmv_int8(vecs_cai, check_shape=True), - get_dmv_uint64(idx_cai, check_shape=False)) + get_dmv_int64(idx_cai, check_shape=False)) elif vecs_dt == np.uint8: with cuda_interruptible(): c_ivf_pq.extend(deref(handle_), index.index, get_dmv_uint8(vecs_cai, check_shape=True), - get_dmv_uint64(idx_cai, check_shape=False)) + get_dmv_int64(idx_cai, check_shape=False)) else: raise TypeError("query dtype %s not supported" % vecs_dt) @@ -618,7 +618,7 @@ def search(SearchParams search_params, k : int The number of neighbors. neighbors : Optional CUDA array interface compliant matrix shape - (n_queries, k), dtype uint64_t. If supplied, neighbor + (n_queries, k), dtype int64_t. If supplied, neighbor indices will be written here in-place. (default None) distances : Optional CUDA array interface compliant matrix shape (n_queries, k) If supplied, the distances to the @@ -695,10 +695,10 @@ def search(SearchParams search_params, exp_cols=index.dim) if neighbors is None: - neighbors = device_ndarray.empty((n_queries, k), dtype='uint64') + neighbors = device_ndarray.empty((n_queries, k), dtype='int64') neighbors_cai = cai_wrapper(neighbors) - _check_input_array(neighbors_cai, [np.dtype('uint64')], + _check_input_array(neighbors_cai, [np.dtype('int64')], exp_rows=n_queries, exp_cols=k) if distances is None: @@ -724,7 +724,7 @@ def search(SearchParams search_params, deref(index.index), get_dmv_float(queries_cai, check_shape=True), k, - get_dmv_uint64(neighbors_cai, check_shape=True), + get_dmv_int64(neighbors_cai, check_shape=True), get_dmv_float(distances_cai, check_shape=True)) elif queries_dt == np.byte: with cuda_interruptible(): @@ -733,7 +733,7 @@ def search(SearchParams search_params, deref(index.index), get_dmv_int8(queries_cai, check_shape=True), k, - get_dmv_uint64(neighbors_cai, check_shape=True), + get_dmv_int64(neighbors_cai, check_shape=True), get_dmv_float(distances_cai, check_shape=True)) elif queries_dt == np.ubyte: with cuda_interruptible(): @@ -742,7 +742,7 @@ def search(SearchParams search_params, deref(index.index), get_dmv_uint8(queries_cai, check_shape=True), k, - get_dmv_uint64(neighbors_cai, check_shape=True), + get_dmv_int64(neighbors_cai, check_shape=True), get_dmv_float(distances_cai, check_shape=True)) else: raise ValueError("query dtype %s not supported" % queries_dt) diff --git a/python/pylibraft/pylibraft/neighbors/refine.pyx b/python/pylibraft/pylibraft/neighbors/refine.pyx index ddc6f115a3..8eb468c805 100644 --- a/python/pylibraft/pylibraft/neighbors/refine.pyx +++ b/python/pylibraft/pylibraft/neighbors/refine.pyx @@ -21,7 +21,7 @@ import numpy as np from cython.operator cimport dereference as deref -from libc.stdint cimport int8_t, int64_t, uint8_t, uint64_t, uintptr_t +from libc.stdint cimport int8_t, int64_t, uint8_t, uintptr_t from libcpp cimport bool, nullptr from pylibraft.distance.distance_type cimport DistanceType @@ -54,8 +54,8 @@ from pylibraft.common.cpp.mdspan cimport ( from pylibraft.common.mdspan cimport ( get_dmv_float, get_dmv_int8, + get_dmv_int64, get_dmv_uint8, - get_dmv_uint64, ) from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( index_params, @@ -70,56 +70,56 @@ cdef extern from "raft_runtime/neighbors/refine.hpp" \ cdef void c_refine "raft::runtime::neighbors::refine" ( const device_resources& handle, - device_matrix_view[float, uint64_t, row_major] dataset, - device_matrix_view[float, uint64_t, row_major] queries, - device_matrix_view[uint64_t, uint64_t, row_major] candidates, - device_matrix_view[uint64_t, uint64_t, row_major] indices, - device_matrix_view[float, uint64_t, row_major] distances, + device_matrix_view[float, int64_t, row_major] dataset, + device_matrix_view[float, int64_t, row_major] queries, + device_matrix_view[int64_t, int64_t, row_major] candidates, + device_matrix_view[int64_t, int64_t, row_major] indices, + device_matrix_view[float, int64_t, row_major] distances, DistanceType metric) except + cdef void c_refine "raft::runtime::neighbors::refine" ( const device_resources& handle, - device_matrix_view[uint8_t, uint64_t, row_major] dataset, - device_matrix_view[uint8_t, uint64_t, row_major] queries, - device_matrix_view[uint64_t, uint64_t, row_major] candidates, - device_matrix_view[uint64_t, uint64_t, row_major] indices, - device_matrix_view[float, uint64_t, row_major] distances, + device_matrix_view[uint8_t, int64_t, row_major] dataset, + device_matrix_view[uint8_t, int64_t, row_major] queries, + device_matrix_view[int64_t, int64_t, row_major] candidates, + device_matrix_view[int64_t, int64_t, row_major] indices, + device_matrix_view[float, int64_t, row_major] distances, DistanceType metric) except + cdef void c_refine "raft::runtime::neighbors::refine" ( const device_resources& handle, - device_matrix_view[int8_t, uint64_t, row_major] dataset, - device_matrix_view[int8_t, uint64_t, row_major] queries, - device_matrix_view[uint64_t, uint64_t, row_major] candidates, - device_matrix_view[uint64_t, uint64_t, row_major] indices, - device_matrix_view[float, uint64_t, row_major] distances, + device_matrix_view[int8_t, int64_t, row_major] dataset, + device_matrix_view[int8_t, int64_t, row_major] queries, + device_matrix_view[int64_t, int64_t, row_major] candidates, + device_matrix_view[int64_t, int64_t, row_major] indices, + device_matrix_view[float, int64_t, row_major] distances, DistanceType metric) except + cdef void c_refine "raft::runtime::neighbors::refine" ( const device_resources& handle, - host_matrix_view[float, uint64_t, row_major] dataset, - host_matrix_view[float, uint64_t, row_major] queries, - host_matrix_view[uint64_t, uint64_t, row_major] candidates, - host_matrix_view[uint64_t, uint64_t, row_major] indices, - host_matrix_view[float, uint64_t, row_major] distances, + host_matrix_view[float, int64_t, row_major] dataset, + host_matrix_view[float, int64_t, row_major] queries, + host_matrix_view[int64_t, int64_t, row_major] candidates, + host_matrix_view[int64_t, int64_t, row_major] indices, + host_matrix_view[float, int64_t, row_major] distances, DistanceType metric) except + cdef void c_refine "raft::runtime::neighbors::refine" ( const device_resources& handle, - host_matrix_view[uint8_t, uint64_t, row_major] dataset, - host_matrix_view[uint8_t, uint64_t, row_major] queries, - host_matrix_view[uint64_t, uint64_t, row_major] candidates, - host_matrix_view[uint64_t, uint64_t, row_major] indices, - host_matrix_view[float, uint64_t, row_major] distances, + host_matrix_view[uint8_t, int64_t, row_major] dataset, + host_matrix_view[uint8_t, int64_t, row_major] queries, + host_matrix_view[int64_t, int64_t, row_major] candidates, + host_matrix_view[int64_t, int64_t, row_major] indices, + host_matrix_view[float, int64_t, row_major] distances, DistanceType metric) except + cdef void c_refine "raft::runtime::neighbors::refine" ( const device_resources& handle, - host_matrix_view[int8_t, uint64_t, row_major] dataset, - host_matrix_view[int8_t, uint64_t, row_major] queries, - host_matrix_view[uint64_t, uint64_t, row_major] candidates, - host_matrix_view[uint64_t, uint64_t, row_major] indices, - host_matrix_view[float, uint64_t, row_major] distances, + host_matrix_view[int8_t, int64_t, row_major] dataset, + host_matrix_view[int8_t, int64_t, row_major] queries, + host_matrix_view[int64_t, int64_t, row_major] candidates, + host_matrix_view[int64_t, int64_t, row_major] indices, + host_matrix_view[float, int64_t, row_major] distances, DistanceType metric) except + @@ -134,35 +134,35 @@ def _get_array_params(array_interface, check_dtype=None): return (shape, dtype, data) -cdef host_matrix_view[float, uint64_t, row_major] \ +cdef host_matrix_view[float, int64_t, row_major] \ get_host_matrix_view_float(array) except *: shape, dtype, data = _get_array_params( array.__array_interface__, check_dtype=np.float32) - return make_host_matrix_view[float, uint64_t, row_major]( + return make_host_matrix_view[float, int64_t, row_major]( data, shape[0], shape[1]) -cdef host_matrix_view[uint64_t, uint64_t, row_major] \ - get_host_matrix_view_uint64(array) except *: +cdef host_matrix_view[int64_t, int64_t, row_major] \ + get_host_matrix_view_int64_t(array) except *: shape, dtype, data = _get_array_params( - array.__array_interface__, check_dtype=np.uint64) - return make_host_matrix_view[uint64_t, uint64_t, row_major]( - data, shape[0], shape[1]) + array.__array_interface__, check_dtype=np.int64) + return make_host_matrix_view[int64_t, int64_t, row_major]( + data, shape[0], shape[1]) -cdef host_matrix_view[uint8_t, uint64_t, row_major] \ +cdef host_matrix_view[uint8_t, int64_t, row_major] \ get_host_matrix_view_uint8(array) except *: shape, dtype, data = _get_array_params( array.__array_interface__, check_dtype=np.uint8) - return make_host_matrix_view[uint8_t, uint64_t, row_major]( + return make_host_matrix_view[uint8_t, int64_t, row_major]( data, shape[0], shape[1]) -cdef host_matrix_view[int8_t, uint64_t, row_major] \ +cdef host_matrix_view[int8_t, int64_t, row_major] \ get_host_matrix_view_int8(array) except *: shape, dtype, data = _get_array_params( array.__array_interface__, check_dtype=np.int8) - return make_host_matrix_view[int8_t, uint64_t, row_major]( + return make_host_matrix_view[int8_t, int64_t, row_major]( data, shape[0], shape[1]) @@ -191,15 +191,15 @@ def refine(dataset, queries, candidates, k=None, indices=None, distances=None, queries : array interface compliant matrix, shape (n_queries, dim) Supported dtype [float, int8, uint8] candidates : array interface compliant matrix, shape (n_queries, k0) - dtype uint64 + dtype int64 k : int Number of neighbors to search (k <= k0). Optional if indices or distances arrays are given (in which case their second dimension is k). indices : Optional array interface compliant matrix shape - (n_queries, k), dtype uint64. If supplied, neighbor + (n_queries, k), dtype int64. If supplied, neighbor indices will be written here in-place. (default None) - Supported dtype uint64 + Supported dtype int64 distances : Optional array interface compliant matrix shape (n_queries, k), dtype float. If supplied, neighbor indices will be written here in-place. (default None) @@ -278,7 +278,7 @@ def _refine_device(dataset, queries, candidates, k, indices, distances, n_queries = cai_wrapper(queries).shape[0] if indices is None: - indices = device_ndarray.empty((n_queries, k), dtype='uint64') + indices = device_ndarray.empty((n_queries, k), dtype='int64') if distances is None: distances = device_ndarray.empty((n_queries, k), dtype='float32') @@ -293,8 +293,8 @@ def _refine_device(dataset, queries, candidates, k, indices, distances, c_refine(deref(handle_), get_dmv_float(dataset_cai, check_shape=True), get_dmv_float(queries_cai, check_shape=True), - get_dmv_uint64(candidates_cai, check_shape=True), - get_dmv_uint64(indices_cai, check_shape=True), + get_dmv_int64(candidates_cai, check_shape=True), + get_dmv_int64(indices_cai, check_shape=True), get_dmv_float(distances_cai, check_shape=True), c_metric) elif dataset_cai.dtype == np.int8: @@ -302,8 +302,8 @@ def _refine_device(dataset, queries, candidates, k, indices, distances, c_refine(deref(handle_), get_dmv_int8(dataset_cai, check_shape=True), get_dmv_int8(queries_cai, check_shape=True), - get_dmv_uint64(candidates_cai, check_shape=True), - get_dmv_uint64(indices_cai, check_shape=True), + get_dmv_int64(candidates_cai, check_shape=True), + get_dmv_int64(indices_cai, check_shape=True), get_dmv_float(distances_cai, check_shape=True), c_metric) elif dataset_cai.dtype == np.uint8: @@ -311,8 +311,8 @@ def _refine_device(dataset, queries, candidates, k, indices, distances, c_refine(deref(handle_), get_dmv_uint8(dataset_cai, check_shape=True), get_dmv_uint8(queries_cai, check_shape=True), - get_dmv_uint64(candidates_cai, check_shape=True), - get_dmv_uint64(indices_cai, check_shape=True), + get_dmv_int64(candidates_cai, check_shape=True), + get_dmv_int64(indices_cai, check_shape=True), get_dmv_float(distances_cai, check_shape=True), c_metric) else: @@ -338,7 +338,7 @@ def _refine_host(dataset, queries, candidates, k, indices, distances, n_queries = queries.__array_interface__["shape"][0] if indices is None: - indices = np.empty((n_queries, k), dtype='uint64') + indices = np.empty((n_queries, k), dtype='int64') if distances is None: distances = np.empty((n_queries, k), dtype='float32') @@ -352,8 +352,8 @@ def _refine_host(dataset, queries, candidates, k, indices, distances, c_refine(deref(handle_), get_host_matrix_view_float(dataset), get_host_matrix_view_float(queries), - get_host_matrix_view_uint64(candidates), - get_host_matrix_view_uint64(indices), + get_host_matrix_view_int64_t(candidates), + get_host_matrix_view_int64_t(indices), get_host_matrix_view_float(distances), c_metric) elif dtype == np.int8: @@ -361,8 +361,8 @@ def _refine_host(dataset, queries, candidates, k, indices, distances, c_refine(deref(handle_), get_host_matrix_view_int8(dataset), get_host_matrix_view_int8(queries), - get_host_matrix_view_uint64(candidates), - get_host_matrix_view_uint64(indices), + get_host_matrix_view_int64_t(candidates), + get_host_matrix_view_int64_t(indices), get_host_matrix_view_float(distances), c_metric) elif dtype == np.uint8: @@ -370,8 +370,8 @@ def _refine_host(dataset, queries, candidates, k, indices, distances, c_refine(deref(handle_), get_host_matrix_view_uint8(dataset), get_host_matrix_view_uint8(queries), - get_host_matrix_view_uint64(candidates), - get_host_matrix_view_uint64(indices), + get_host_matrix_view_int64_t(candidates), + get_host_matrix_view_int64_t(indices), get_host_matrix_view_float(distances), c_metric) else: diff --git a/python/pylibraft/pylibraft/test/test_ivf_pq.py b/python/pylibraft/pylibraft/test/test_ivf_pq.py index 54c9ec716d..aa58e2a8fc 100644 --- a/python/pylibraft/pylibraft/test/test_ivf_pq.py +++ b/python/pylibraft/pylibraft/test/test_ivf_pq.py @@ -128,8 +128,8 @@ def run_ivf_pq_build_search_test( if not add_data_on_build: dataset_1 = dataset[: n_rows // 2, :] dataset_2 = dataset[n_rows // 2 :, :] - indices_1 = np.arange(n_rows // 2, dtype=np.uint64) - indices_2 = np.arange(n_rows // 2, n_rows, dtype=np.uint64) + indices_1 = np.arange(n_rows // 2, dtype=np.int64) + indices_2 = np.arange(n_rows // 2, n_rows, dtype=np.int64) if array_type == "device": dataset_1_device = device_ndarray(dataset_1) dataset_2_device = device_ndarray(dataset_2) @@ -144,7 +144,7 @@ def run_ivf_pq_build_search_test( assert index.size >= n_rows queries = generate_data((n_queries, n_cols), dtype) - out_idx = np.zeros((n_queries, k), dtype=np.uint64) + out_idx = np.zeros((n_queries, k), dtype=np.int64) out_dist = np.zeros((n_queries, k), dtype=np.float32) queries_device = device_ndarray(queries) @@ -397,7 +397,7 @@ def test_build_assertions(): index = ivf_pq.Index() queries = generate_data((n_queries, n_cols), np.float32) - out_idx = np.zeros((n_queries, k), dtype=np.uint64) + out_idx = np.zeros((n_queries, k), dtype=np.int64) out_dist = np.zeros((n_queries, k), dtype=np.float32) queries_device = device_ndarray(queries) @@ -420,7 +420,7 @@ def test_build_assertions(): index = ivf_pq.build(index_params, dataset_device) assert index.trained - indices = np.arange(n_rows + 1, dtype=np.uint64) + indices = np.arange(n_rows + 1, dtype=np.int64) indices_device = device_ndarray(indices) with pytest.raises(ValueError): @@ -463,7 +463,7 @@ def test_search_inputs(params): ).astype(q_dt, order=q_order) queries_device = device_ndarray(queries) - idx_dt = params.get("idx_dt", np.uint64) + idx_dt = params.get("idx_dt", np.int64) idx_order = params.get("idx_order", "C") out_idx = np.zeros( (params.get("idx_rows", n_queries), params.get("idx_cols", k)), diff --git a/python/pylibraft/pylibraft/test/test_refine.py b/python/pylibraft/pylibraft/test/test_refine.py index 8502d0575c..397ea70ec7 100644 --- a/python/pylibraft/pylibraft/test/test_refine.py +++ b/python/pylibraft/pylibraft/test/test_refine.py @@ -57,10 +57,10 @@ def run_refine( ) nn_skl.fit(dataset) skl_dist, candidates = nn_skl.kneighbors(queries) - candidates = candidates.astype(np.uint64) + candidates = candidates.astype(np.int64) candidates_device = device_ndarray(candidates) - out_idx = np.zeros((n_queries, k), dtype=np.uint64) + out_idx = np.zeros((n_queries, k), dtype=np.int64) out_dist = np.zeros((n_queries, k), dtype=np.float32) out_idx_device = device_ndarray(out_idx) if inplace else None out_dist_device = device_ndarray(out_dist) if inplace else None @@ -196,12 +196,12 @@ def test_input_assertions(params, memory_type): queries_device = device_ndarray(queries) candidates = np.random.randint( - 0, 500, size=(n_queries, k0), dtype=np.uint64 + 0, 500, size=(n_queries, k0), dtype=np.int64 ) candidates_device = device_ndarray(candidates) if params["idx_shape"] is not None: - out_idx = np.zeros(params["idx_shape"], dtype=np.uint64) + out_idx = np.zeros(params["idx_shape"], dtype=np.int64) out_idx_device = device_ndarray(out_idx) else: out_idx_device = None