Skip to content

Commit

Permalink
Fix ivf flat specialization header IdxT from uint64_t -> int64_t (#1358)
Browse files Browse the repository at this point in the history
The ivf_flat specialization header declarations used a wrong index type. 

The specializations for ivf flat are defined for int64_t. The raft_runtime interface also uses the int64_t instances. The ivf_flat specialization header, however, declared an interface using uint64_t.

This is fixed with this PR. Should also reduce compile times for `src/distance/neighbors/ivf_flat_search.cu.o`

Authors:
  - Allard Hendriksen (https://github.com/ahendriksen)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Divye Gala (https://github.com/divyegala)

URL: #1358
  • Loading branch information
ahendriksen authored Mar 20, 2023
1 parent 97f8ad7 commit 56ac43a
Showing 1 changed file with 27 additions and 27 deletions.
54 changes: 27 additions & 27 deletions cpp/include/raft/neighbors/specializations/ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -20,35 +20,35 @@

namespace raft::neighbors::ivf_flat {

#define RAFT_INST(T, IdxT) \
extern template auto build(raft::device_resources const& handle, \
const index_params& params, \
raft::device_matrix_view<const T, uint64_t, row_major> dataset) \
->index<T, IdxT>; \
\
extern template auto extend( \
raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
const index<T, IdxT>& orig_index) \
->index<T, IdxT>; \
\
extern template void extend( \
raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
raft::neighbors::ivf_flat::index<T, IdxT>* idx); \
\
extern template void search(raft::device_resources const&, \
raft::neighbors::ivf_flat::search_params const&, \
const raft::neighbors::ivf_flat::index<T, IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
#define RAFT_INST(T, IdxT) \
extern template auto build(raft::device_resources const& handle, \
const index_params& params, \
raft::device_matrix_view<const T, IdxT, row_major> dataset) \
->index<T, IdxT>; \
\
extern template auto extend( \
raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
const index<T, IdxT>& orig_index) \
->index<T, IdxT>; \
\
extern template void extend( \
raft::device_resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
raft::neighbors::ivf_flat::index<T, IdxT>* idx); \
\
extern template void search(raft::device_resources const&, \
raft::neighbors::ivf_flat::search_params const&, \
const raft::neighbors::ivf_flat::index<T, IdxT>&, \
raft::device_matrix_view<const T, IdxT, row_major>, \
raft::device_matrix_view<IdxT, IdxT, row_major>, \
raft::device_matrix_view<float, IdxT, row_major>);

RAFT_INST(float, uint64_t);
RAFT_INST(int8_t, uint64_t);
RAFT_INST(uint8_t, uint64_t);
RAFT_INST(float, int64_t);
RAFT_INST(int8_t, int64_t);
RAFT_INST(uint8_t, int64_t);

#undef RAFT_INST
} // namespace raft::neighbors::ivf_flat

0 comments on commit 56ac43a

Please sign in to comment.