Skip to content

Commit

Permalink
Using int64_t specializations for ivf_pq and refine (#1325)
Browse files Browse the repository at this point in the history
Since FAISS and Milvus are both using `int64_t` everywhere and we haven't explicitly encountered places where `uint64_t` is being used for ann indices, we're opting to just use `int64_t` instead.

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Ben Frederickson (https://github.com/benfred)

URL: #1325
  • Loading branch information
cjnolet authored Mar 10, 2023
1 parent 362dc93 commit 6b30de9
Show file tree
Hide file tree
Showing 57 changed files with 441 additions and 481 deletions.
48 changes: 24 additions & 24 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -299,18 +299,18 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/cluster/update_centroids_double.cu
src/distance/cluster/cluster_cost_float.cu
src/distance/cluster/cluster_cost_double.cu
src/distance/neighbors/refine_d_uint64_t_float.cu
src/distance/neighbors/refine_d_uint64_t_int8_t.cu
src/distance/neighbors/refine_d_uint64_t_uint8_t.cu
src/distance/neighbors/refine_h_uint64_t_float.cu
src/distance/neighbors/refine_h_uint64_t_int8_t.cu
src/distance/neighbors/refine_h_uint64_t_uint8_t.cu
src/distance/neighbors/specializations/refine_d_uint64_t_float.cu
src/distance/neighbors/specializations/refine_d_uint64_t_int8_t.cu
src/distance/neighbors/specializations/refine_d_uint64_t_uint8_t.cu
src/distance/neighbors/specializations/refine_h_uint64_t_float.cu
src/distance/neighbors/specializations/refine_h_uint64_t_int8_t.cu
src/distance/neighbors/specializations/refine_h_uint64_t_uint8_t.cu
src/distance/neighbors/refine_d_int64_t_float.cu
src/distance/neighbors/refine_d_int64_t_int8_t.cu
src/distance/neighbors/refine_d_int64_t_uint8_t.cu
src/distance/neighbors/refine_h_int64_t_float.cu
src/distance/neighbors/refine_h_int64_t_int8_t.cu
src/distance/neighbors/refine_h_int64_t_uint8_t.cu
src/distance/neighbors/specializations/refine_d_int64_t_float.cu
src/distance/neighbors/specializations/refine_d_int64_t_int8_t.cu
src/distance/neighbors/specializations/refine_d_int64_t_uint8_t.cu
src/distance/neighbors/specializations/refine_h_int64_t_float.cu
src/distance/neighbors/specializations/refine_h_int64_t_int8_t.cu
src/distance/neighbors/specializations/refine_h_int64_t_uint8_t.cu
src/distance/cluster/kmeans_fit_float.cu
src/distance/cluster/kmeans_fit_double.cu
src/distance/cluster/kmeans_init_plus_plus_double.cu
Expand Down Expand Up @@ -367,18 +367,18 @@ if(RAFT_COMPILE_DIST_LIBRARY)
src/distance/neighbors/ivfpq_build.cu
src/distance/neighbors/ivfpq_deserialize.cu
src/distance/neighbors/ivfpq_serialize.cu
src/distance/neighbors/ivfpq_search_float_uint64_t.cu
src/distance/neighbors/ivfpq_search_int8_t_uint64_t.cu
src/distance/neighbors/ivfpq_search_uint8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_build_float_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_build_int8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_build_uint8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_extend_float_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_extend_int8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_extend_uint8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_search_float_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_search_int8_t_uint64_t.cu
src/distance/neighbors/specializations/ivfpq_search_uint8_t_uint64_t.cu
src/distance/neighbors/ivfpq_search_float_int64_t.cu
src/distance/neighbors/ivfpq_search_int8_t_int64_t.cu
src/distance/neighbors/ivfpq_search_uint8_t_int64_t.cu
src/distance/neighbors/specializations/ivfpq_build_float_int64_t.cu
src/distance/neighbors/specializations/ivfpq_build_int8_t_int64_t.cu
src/distance/neighbors/specializations/ivfpq_build_uint8_t_int64_t.cu
src/distance/neighbors/specializations/ivfpq_extend_float_int64_t.cu
src/distance/neighbors/specializations/ivfpq_extend_int8_t_int64_t.cu
src/distance/neighbors/specializations/ivfpq_extend_uint8_t_int64_t.cu
src/distance/neighbors/specializations/ivfpq_search_float_int64_t.cu
src/distance/neighbors/specializations/ivfpq_search_int8_t_int64_t.cu
src/distance/neighbors/specializations/ivfpq_search_uint8_t_int64_t.cu
src/distance/neighbors/specializations/detail/compute_similarity_float_float_fast.cu
src/distance/neighbors/specializations/detail/compute_similarity_float_float_no_basediff.cu
src/distance/neighbors/specializations/detail/compute_similarity_float_float_no_smem_lut.cu
Expand Down
10 changes: 5 additions & 5 deletions cpp/bench/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -125,11 +125,11 @@ if(BUILD_BENCH)
bench/neighbors/knn/ivf_flat_float_int64_t.cu
bench/neighbors/knn/ivf_flat_int8_t_int64_t.cu
bench/neighbors/knn/ivf_flat_uint8_t_int64_t.cu
bench/neighbors/knn/ivf_pq_float_uint64_t.cu
bench/neighbors/knn/ivf_pq_int8_t_uint64_t.cu
bench/neighbors/knn/ivf_pq_uint8_t_uint64_t.cu
bench/neighbors/refine_float_uint64_t.cu
bench/neighbors/refine_uint8_t_uint64_t.cu
bench/neighbors/knn/ivf_pq_float_int64_t.cu
bench/neighbors/knn/ivf_pq_int8_t_int64_t.cu
bench/neighbors/knn/ivf_pq_uint8_t_int64_t.cu
bench/neighbors/refine_float_int64_t.cu
bench/neighbors/refine_uint8_t_int64_t.cu
bench/main.cpp
OPTIONAL
DIST
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@

namespace raft::bench::spatial {

KNN_REGISTER(uint8_t, uint32_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes);
KNN_REGISTER(float, int64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes);

} // namespace raft::bench::spatial
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@

namespace raft::bench::spatial {

KNN_REGISTER(float, uint64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes);
KNN_REGISTER(int8_t, int64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes);

} // namespace raft::bench::spatial
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,6 @@

namespace raft::bench::spatial {

KNN_REGISTER(int8_t, uint64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes);
KNN_REGISTER(uint8_t, int64_t, ivf_pq_knn, kInputs, kNoCopyOnly, kAllScopes);

} // namespace raft::bench::spatial
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@
using namespace raft::neighbors;

namespace raft::bench::neighbors {
using refine_float_int64 = RefineAnn<float, float, uint64_t>;
RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs<uint64_t>());
using refine_float_int64 = RefineAnn<float, float, int64_t>;
RAFT_BENCH_REGISTER(refine_float_int64, "", getInputs<int64_t>());
} // namespace raft::bench::neighbors
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@
using namespace raft::neighbors;

namespace raft::bench::neighbors {
using refine_uint8_int64 = RefineAnn<uint8_t, float, uint64_t>;
RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs<uint64_t>());
using refine_uint8_int64 = RefineAnn<uint8_t, float, int64_t>;
RAFT_BENCH_REGISTER(refine_uint8_int64, "", getInputs<int64_t>());
} // namespace raft::bench::neighbors
1 change: 1 addition & 0 deletions cpp/include/raft/neighbors/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/neighbors/ivf_pq_serialize.cuh>
#include <raft/neighbors/ivf_pq_types.hpp>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>

#include <rmm/cuda_stream_view.hpp>
Expand Down
6 changes: 3 additions & 3 deletions cpp/include/raft/neighbors/specializations/ivf_pq.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ namespace raft::neighbors::ivf_pq {
IdxT*, \
float*, \
rmm::mr::device_memory_resource*);
RAFT_INST(float, uint64_t);
RAFT_INST(int8_t, uint64_t);
RAFT_INST(uint8_t, uint64_t);
RAFT_INST(float, int64_t);
RAFT_INST(int8_t, int64_t);
RAFT_INST(uint8_t, int64_t);

#undef RAFT_INST

Expand Down
40 changes: 20 additions & 20 deletions cpp/include/raft/neighbors/specializations/refine.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -24,28 +24,28 @@ namespace raft::neighbors {
#undef RAFT_INST
#endif

#define RAFT_INST(T, IdxT) \
extern template void refine<IdxT, T, float, uint64_t>( \
raft::device_resources const& handle, \
raft::device_matrix_view<const T, uint64_t, row_major> dataset, \
raft::device_matrix_view<const T, uint64_t, row_major> queries, \
raft::device_matrix_view<const IdxT, uint64_t, row_major> neighbor_candidates, \
raft::device_matrix_view<IdxT, uint64_t, row_major> indices, \
raft::device_matrix_view<float, uint64_t, row_major> distances, \
distance::DistanceType metric); \
\
extern template void refine<IdxT, T, float, uint64_t>( \
raft::device_resources const& handle, \
raft::host_matrix_view<const T, uint64_t, row_major> dataset, \
raft::host_matrix_view<const T, uint64_t, row_major> queries, \
raft::host_matrix_view<const IdxT, uint64_t, row_major> neighbor_candidates, \
raft::host_matrix_view<IdxT, uint64_t, row_major> indices, \
raft::host_matrix_view<float, uint64_t, row_major> distances, \
#define RAFT_INST(T, IdxT) \
extern template void refine<IdxT, T, float, int64_t>( \
raft::device_resources const& handle, \
raft::device_matrix_view<const T, int64_t, row_major> dataset, \
raft::device_matrix_view<const T, int64_t, row_major> queries, \
raft::device_matrix_view<const IdxT, int64_t, row_major> neighbor_candidates, \
raft::device_matrix_view<IdxT, int64_t, row_major> indices, \
raft::device_matrix_view<float, int64_t, row_major> distances, \
distance::DistanceType metric); \
\
extern template void refine<IdxT, T, float, int64_t>( \
raft::device_resources const& handle, \
raft::host_matrix_view<const T, int64_t, row_major> dataset, \
raft::host_matrix_view<const T, int64_t, row_major> queries, \
raft::host_matrix_view<const IdxT, int64_t, row_major> neighbor_candidates, \
raft::host_matrix_view<IdxT, int64_t, row_major> indices, \
raft::host_matrix_view<float, int64_t, row_major> distances, \
distance::DistanceType metric);

RAFT_INST(float, uint64_t);
RAFT_INST(uint8_t, uint64_t);
RAFT_INST(int8_t, uint64_t);
RAFT_INST(float, int64_t);
RAFT_INST(uint8_t, int64_t);
RAFT_INST(int8_t, int64_t);

#undef RAFT_INST
} // namespace raft::neighbors
2 changes: 2 additions & 0 deletions cpp/include/raft/spatial/knn/detail/ann_quantized.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@
#include <raft/label/classlabels.cuh>
#include <raft/neighbors/ivf_pq.cuh>

#include <raft/core/device_mdspan.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <thrust/iterator/transform_iterator.h>
Expand Down
16 changes: 8 additions & 8 deletions cpp/include/raft_runtime/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ namespace raft::runtime::neighbors::ivf_pq {
raft::device_matrix_view<IdxT, IdxT, row_major>, \
raft::device_matrix_view<float, IdxT, row_major>);

RAFT_INST_SEARCH(float, uint64_t);
RAFT_INST_SEARCH(int8_t, uint64_t);
RAFT_INST_SEARCH(uint8_t, uint64_t);
RAFT_INST_SEARCH(float, int64_t);
RAFT_INST_SEARCH(int8_t, int64_t);
RAFT_INST_SEARCH(uint8_t, int64_t);

#undef RAFT_INST_SEARCH

Expand Down Expand Up @@ -60,9 +60,9 @@ RAFT_INST_SEARCH(uint8_t, uint64_t);
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
raft::device_matrix_view<const IdxT, IdxT, row_major> new_indices);

RAFT_INST_BUILD_EXTEND(float, uint64_t)
RAFT_INST_BUILD_EXTEND(int8_t, uint64_t)
RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t)
RAFT_INST_BUILD_EXTEND(float, int64_t)
RAFT_INST_BUILD_EXTEND(int8_t, int64_t)
RAFT_INST_BUILD_EXTEND(uint8_t, int64_t)

#undef RAFT_INST_BUILD_EXTEND

Expand All @@ -78,7 +78,7 @@ RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t)
*/
void serialize(raft::device_resources const& handle,
const std::string& filename,
const raft::neighbors::ivf_pq::index<uint64_t>& index);
const raft::neighbors::ivf_pq::index<int64_t>& index);

/**
* Load index from file.
Expand All @@ -92,6 +92,6 @@ void serialize(raft::device_resources const& handle,
*/
void deserialize(raft::device_resources const& handle,
const std::string& filename,
raft::neighbors::ivf_pq::index<uint64_t>* index);
raft::neighbors::ivf_pq::index<int64_t>* index);

} // namespace raft::runtime::neighbors::ivf_pq
36 changes: 18 additions & 18 deletions cpp/include/raft_runtime/neighbors/refine.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,26 +22,26 @@

namespace raft::runtime::neighbors {

#define RAFT_INST_REFINE(IDX_T, DATA_T) \
void refine(raft::device_resources const& handle, \
raft::device_matrix_view<const DATA_T, uint64_t, row_major> dataset, \
raft::device_matrix_view<const DATA_T, uint64_t, row_major> queries, \
raft::device_matrix_view<const IDX_T, uint64_t, row_major> neighbor_candidates, \
raft::device_matrix_view<IDX_T, uint64_t, row_major> indices, \
raft::device_matrix_view<float, uint64_t, row_major> distances, \
distance::DistanceType metric); \
\
void refine(raft::device_resources const& handle, \
raft::host_matrix_view<const DATA_T, uint64_t, row_major> dataset, \
raft::host_matrix_view<const DATA_T, uint64_t, row_major> queries, \
raft::host_matrix_view<const IDX_T, uint64_t, row_major> neighbor_candidates, \
raft::host_matrix_view<IDX_T, uint64_t, row_major> indices, \
raft::host_matrix_view<float, uint64_t, row_major> distances, \
#define RAFT_INST_REFINE(IDX_T, DATA_T) \
void refine(raft::device_resources const& handle, \
raft::device_matrix_view<const DATA_T, int64_t, row_major> dataset, \
raft::device_matrix_view<const DATA_T, int64_t, row_major> queries, \
raft::device_matrix_view<const IDX_T, int64_t, row_major> neighbor_candidates, \
raft::device_matrix_view<IDX_T, int64_t, row_major> indices, \
raft::device_matrix_view<float, int64_t, row_major> distances, \
distance::DistanceType metric); \
\
void refine(raft::device_resources const& handle, \
raft::host_matrix_view<const DATA_T, int64_t, row_major> dataset, \
raft::host_matrix_view<const DATA_T, int64_t, row_major> queries, \
raft::host_matrix_view<const IDX_T, int64_t, row_major> neighbor_candidates, \
raft::host_matrix_view<IDX_T, int64_t, row_major> indices, \
raft::host_matrix_view<float, int64_t, row_major> distances, \
distance::DistanceType metric);

RAFT_INST_REFINE(uint64_t, float);
RAFT_INST_REFINE(uint64_t, uint8_t);
RAFT_INST_REFINE(uint64_t, int8_t);
RAFT_INST_REFINE(int64_t, float);
RAFT_INST_REFINE(int64_t, uint8_t);
RAFT_INST_REFINE(int64_t, int8_t);

#undef RAFT_INST_REFINE

Expand Down
6 changes: 3 additions & 3 deletions cpp/src/distance/neighbors/ivfpq_build.cu
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ namespace raft::runtime::neighbors::ivf_pq {
raft::neighbors::ivf_pq::extend<T, IdxT>(handle, idx, new_vectors, new_indices); \
}

RAFT_INST_BUILD_EXTEND(float, uint64_t);
RAFT_INST_BUILD_EXTEND(int8_t, uint64_t);
RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t);
RAFT_INST_BUILD_EXTEND(float, int64_t);
RAFT_INST_BUILD_EXTEND(int8_t, int64_t);
RAFT_INST_BUILD_EXTEND(uint8_t, int64_t);

#undef RAFT_INST_BUILD_EXTEND

Expand Down
4 changes: 2 additions & 2 deletions cpp/src/distance/neighbors/ivfpq_deserialize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ namespace raft::runtime::neighbors::ivf_pq {

void deserialize(raft::device_resources const& handle,
const std::string& filename,
raft::neighbors::ivf_pq::index<uint64_t>* index)
raft::neighbors::ivf_pq::index<int64_t>* index)
{
if (!index) { RAFT_FAIL("Invalid index pointer"); }
*index = raft::neighbors::ivf_pq::deserialize<uint64_t>(handle, filename);
*index = raft::neighbors::ivf_pq::deserialize<int64_t>(handle, filename);
};
} // namespace raft::runtime::neighbors::ivf_pq
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace raft::runtime::neighbors::ivf_pq {
handle, params, idx, queries, k, neighbors, distances); \
}

RAFT_SEARCH_INST(float, uint64_t);
RAFT_SEARCH_INST(float, int64_t);

#undef RAFT_INST_SEARCH

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace raft::runtime::neighbors::ivf_pq {
handle, params, idx, queries, k, neighbors, distances); \
}

RAFT_SEARCH_INST(int8_t, uint64_t);
RAFT_SEARCH_INST(int8_t, int64_t);

#undef RAFT_INST_SEARCH

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ namespace raft::runtime::neighbors::ivf_pq {
handle, params, idx, queries, k, neighbors, distances); \
}

RAFT_SEARCH_INST(uint8_t, uint64_t);
RAFT_SEARCH_INST(uint8_t, int64_t);

#undef RAFT_INST_SEARCH

Expand Down
2 changes: 1 addition & 1 deletion cpp/src/distance/neighbors/ivfpq_serialize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ namespace raft::runtime::neighbors::ivf_pq {

void serialize(raft::device_resources const& handle,
const std::string& filename,
const raft::neighbors::ivf_pq::index<uint64_t>& index)
const raft::neighbors::ivf_pq::index<int64_t>& index)
{
raft::neighbors::ivf_pq::serialize(handle, filename, index);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
namespace raft::runtime::neighbors {

void refine(raft::device_resources const& handle,
raft::host_matrix_view<const int8_t, uint64_t, row_major> dataset,
raft::host_matrix_view<const int8_t, uint64_t, row_major> queries,
raft::host_matrix_view<const uint64_t, uint64_t, row_major> neighbor_candidates,
raft::host_matrix_view<uint64_t, uint64_t, row_major> indices,
raft::host_matrix_view<float, uint64_t, row_major> distances,
raft::device_matrix_view<const float, int64_t, row_major> dataset,
raft::device_matrix_view<const float, int64_t, row_major> queries,
raft::device_matrix_view<const int64_t, int64_t, row_major> neighbor_candidates,
raft::device_matrix_view<int64_t, int64_t, row_major> indices,
raft::device_matrix_view<float, int64_t, row_major> distances,
distance::DistanceType metric)
{
raft::neighbors::refine<uint64_t, int8_t, float, uint64_t>(
raft::neighbors::refine<int64_t, float, float, int64_t>(
handle, dataset, queries, neighbor_candidates, indices, distances, metric);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,14 @@
namespace raft::runtime::neighbors {

void refine(raft::device_resources const& handle,
raft::host_matrix_view<const uint8_t, uint64_t, row_major> dataset,
raft::host_matrix_view<const uint8_t, uint64_t, row_major> queries,
raft::host_matrix_view<const uint64_t, uint64_t, row_major> neighbor_candidates,
raft::host_matrix_view<uint64_t, uint64_t, row_major> indices,
raft::host_matrix_view<float, uint64_t, row_major> distances,
raft::device_matrix_view<const int8_t, int64_t, row_major> dataset,
raft::device_matrix_view<const int8_t, int64_t, row_major> queries,
raft::device_matrix_view<const int64_t, int64_t, row_major> neighbor_candidates,
raft::device_matrix_view<int64_t, int64_t, row_major> indices,
raft::device_matrix_view<float, int64_t, row_major> distances,
distance::DistanceType metric)
{
raft::neighbors::refine<uint64_t, uint8_t, float, uint64_t>(
raft::neighbors::refine<int64_t, int8_t, float, int64_t>(
handle, dataset, queries, neighbor_candidates, indices, distances, metric);
}

Expand Down
Loading

0 comments on commit 6b30de9

Please sign in to comment.