Skip to content

Commit

Permalink
Accept host_mdspan for IVF-PQ build and extend (rapidsai#148)
Browse files Browse the repository at this point in the history
This PR enables host input arrays for `ivf_pq::build` and `ivf_pq::extend`.

closes rapidsai#120 
closes rapidsai#143

Authors:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai/cuvs#148
  • Loading branch information
tfeher authored May 28, 2024
1 parent b9e8a4b commit 1912355
Show file tree
Hide file tree
Showing 11 changed files with 527 additions and 200 deletions.
380 changes: 369 additions & 11 deletions cpp/include/cuvs/neighbors/ivf_pq.hpp

Large diffs are not rendered by default.

3 changes: 1 addition & 2 deletions cpp/src/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,7 @@ void build_knn_graph(
}();

RAFT_LOG_DEBUG("# Building IVF-PQ index %s", model_name.c_str());
auto index = cuvs::neighbors::ivf_pq::detail::build<DataT, int64_t>(
res, *build_params, dataset.data_handle(), dataset.extent(0), dataset.extent(1));
auto index = cuvs::neighbors::ivf_pq::detail::build<DataT, int64_t>(res, *build_params, dataset);

//
// search top (k + 1) neighbors
Expand Down
39 changes: 2 additions & 37 deletions cpp/src/neighbors/ivf_pq/detail/generate_ivf_pq.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@
"""

build_include_macro = """
#include "../ivf_pq_build.cuh"
#include "ivf_pq_build_extend_inst.cuh"
"""
search_include_macro = """
#include "../ivf_pq_search.cuh"
Expand All @@ -61,42 +61,7 @@
uint8_t_int64_t=("uint8_t", "int64_t"),
)

build_extend_macro = """
#define CUVS_INST_IVF_PQ_BUILD_EXTEND(T, IdxT) \\
auto build(raft::resources const& handle, \\
const cuvs::neighbors::ivf_pq::index_params& params, \\
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset) \\
->cuvs::neighbors::ivf_pq::index<IdxT> \\
{ \\
return cuvs::neighbors::ivf_pq::detail::build(handle, params, dataset); \\
} \\
\\
void build(raft::resources const& handle, \\
const cuvs::neighbors::ivf_pq::index_params& params, \\
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset, \\
cuvs::neighbors::ivf_pq::index<IdxT>* idx) \\
{ \\
cuvs::neighbors::ivf_pq::detail::build(handle, params, dataset, idx); \\
} \\
auto extend(raft::resources const& handle, \\
raft::device_matrix_view<const T, IdxT, raft::row_major> new_vectors, \\
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \\
const cuvs::neighbors::ivf_pq::index<IdxT>& orig_index) \\
->cuvs::neighbors::ivf_pq::index<IdxT> \\
{ \\
return cuvs::neighbors::ivf_pq::detail::extend( \\
handle, new_vectors, new_indices, orig_index); \\
} \\
\\
void extend(raft::resources const& handle, \\
raft::device_matrix_view<const T, IdxT, raft::row_major> new_vectors, \\
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \\
cuvs::neighbors::ivf_pq::index<IdxT>* idx) \\
{ \\
cuvs::neighbors::ivf_pq::detail::extend( \\
handle, new_vectors, new_indices, idx); \\
}
"""
build_extend_macro = "" # moved to header ivf_pq_build_extend_inst.cuh

search_macro = """
#define CUVS_INST_IVF_PQ_SEARCH(T, IdxT) \\
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,42 +25,9 @@

#include <cuvs/neighbors/ivf_pq.hpp>

#include "../ivf_pq_build.cuh"
#include "ivf_pq_build_extend_inst.cuh"

namespace cuvs::neighbors::ivf_pq {

#define CUVS_INST_IVF_PQ_BUILD_EXTEND(T, IdxT) \
auto build(raft::resources const& handle, \
const cuvs::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset) \
->cuvs::neighbors::ivf_pq::index<IdxT> \
{ \
return cuvs::neighbors::ivf_pq::detail::build(handle, params, dataset); \
} \
\
void build(raft::resources const& handle, \
const cuvs::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset, \
cuvs::neighbors::ivf_pq::index<IdxT>* idx) \
{ \
cuvs::neighbors::ivf_pq::detail::build(handle, params, dataset, idx); \
} \
auto extend(raft::resources const& handle, \
raft::device_matrix_view<const T, IdxT, raft::row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
const cuvs::neighbors::ivf_pq::index<IdxT>& orig_index) \
->cuvs::neighbors::ivf_pq::index<IdxT> \
{ \
return cuvs::neighbors::ivf_pq::detail::extend(handle, new_vectors, new_indices, orig_index); \
} \
\
void extend(raft::resources const& handle, \
raft::device_matrix_view<const T, IdxT, raft::row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
cuvs::neighbors::ivf_pq::index<IdxT>* idx) \
{ \
cuvs::neighbors::ivf_pq::detail::extend(handle, new_vectors, new_indices, idx); \
}
CUVS_INST_IVF_PQ_BUILD_EXTEND(float, int64_t);

#undef CUVS_INST_IVF_PQ_BUILD_EXTEND
Expand Down
93 changes: 93 additions & 0 deletions cpp/src/neighbors/ivf_pq/detail/ivf_pq_build_extend_inst.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/*
* NOTE: this file is used by generate_ivf_pq.py
*
*/

#include <cuvs/neighbors/ivf_pq.hpp>

#include "../ivf_pq_build.cuh"

namespace cuvs::neighbors::ivf_pq {

#define CUVS_INST_IVF_PQ_BUILD_EXTEND(T, IdxT) \
auto build(raft::resources const& handle, \
const cuvs::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::ivf_pq::index<IdxT> \
{ \
return cuvs::neighbors::ivf_pq::detail::build(handle, params, dataset); \
} \
\
void build(raft::resources const& handle, \
const cuvs::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, int64_t, raft::row_major> dataset, \
cuvs::neighbors::ivf_pq::index<IdxT>* idx) \
{ \
cuvs::neighbors::ivf_pq::detail::build(handle, params, dataset, idx); \
} \
\
auto build(raft::resources const& handle, \
const cuvs::neighbors::ivf_pq::index_params& params, \
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset) \
->cuvs::neighbors::ivf_pq::index<IdxT> \
{ \
return cuvs::neighbors::ivf_pq::detail::build(handle, params, dataset); \
} \
\
void build(raft::resources const& handle, \
const cuvs::neighbors::ivf_pq::index_params& params, \
raft::host_matrix_view<const T, int64_t, raft::row_major> dataset, \
cuvs::neighbors::ivf_pq::index<IdxT>* idx) \
{ \
cuvs::neighbors::ivf_pq::detail::build(handle, params, dataset, idx); \
} \
auto extend( \
raft::resources const& handle, \
raft::device_matrix_view<const T, int64_t, raft::row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, int64_t, raft::row_major>> new_indices, \
const cuvs::neighbors::ivf_pq::index<IdxT>& orig_index) \
->cuvs::neighbors::ivf_pq::index<IdxT> \
{ \
return cuvs::neighbors::ivf_pq::detail::extend(handle, new_vectors, new_indices, orig_index); \
} \
void extend(raft::resources const& handle, \
raft::device_matrix_view<const T, int64_t, raft::row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, int64_t>> new_indices, \
cuvs::neighbors::ivf_pq::index<IdxT>* idx) \
{ \
cuvs::neighbors::ivf_pq::detail::extend(handle, new_vectors, new_indices, idx); \
} \
auto extend(raft::resources const& handle, \
raft::host_matrix_view<const T, int64_t, raft::row_major> new_vectors, \
std::optional<raft::host_vector_view<const IdxT, int64_t>> new_indices, \
const cuvs::neighbors::ivf_pq::index<IdxT>& orig_index) \
->cuvs::neighbors::ivf_pq::index<IdxT> \
{ \
return cuvs::neighbors::ivf_pq::detail::extend(handle, new_vectors, new_indices, orig_index); \
} \
\
void extend(raft::resources const& handle, \
raft::host_matrix_view<const T, int64_t, raft::row_major> new_vectors, \
std::optional<raft::host_vector_view<const IdxT, int64_t>> new_indices, \
cuvs::neighbors::ivf_pq::index<IdxT>* idx) \
{ \
cuvs::neighbors::ivf_pq::detail::extend(handle, new_vectors, new_indices, idx); \
}

} // namespace cuvs::neighbors::ivf_pq
Original file line number Diff line number Diff line change
Expand Up @@ -25,42 +25,9 @@

#include <cuvs/neighbors/ivf_pq.hpp>

#include "../ivf_pq_build.cuh"
#include "ivf_pq_build_extend_inst.cuh"

namespace cuvs::neighbors::ivf_pq {

#define CUVS_INST_IVF_PQ_BUILD_EXTEND(T, IdxT) \
auto build(raft::resources const& handle, \
const cuvs::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset) \
->cuvs::neighbors::ivf_pq::index<IdxT> \
{ \
return cuvs::neighbors::ivf_pq::detail::build(handle, params, dataset); \
} \
\
void build(raft::resources const& handle, \
const cuvs::neighbors::ivf_pq::index_params& params, \
raft::device_matrix_view<const T, IdxT, raft::row_major> dataset, \
cuvs::neighbors::ivf_pq::index<IdxT>* idx) \
{ \
cuvs::neighbors::ivf_pq::detail::build(handle, params, dataset, idx); \
} \
auto extend(raft::resources const& handle, \
raft::device_matrix_view<const T, IdxT, raft::row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
const cuvs::neighbors::ivf_pq::index<IdxT>& orig_index) \
->cuvs::neighbors::ivf_pq::index<IdxT> \
{ \
return cuvs::neighbors::ivf_pq::detail::extend(handle, new_vectors, new_indices, orig_index); \
} \
\
void extend(raft::resources const& handle, \
raft::device_matrix_view<const T, IdxT, raft::row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
cuvs::neighbors::ivf_pq::index<IdxT>* idx) \
{ \
cuvs::neighbors::ivf_pq::detail::extend(handle, new_vectors, new_indices, idx); \
}
CUVS_INST_IVF_PQ_BUILD_EXTEND(int8_t, int64_t);

#undef CUVS_INST_IVF_PQ_BUILD_EXTEND
Expand Down
Loading

0 comments on commit 1912355

Please sign in to comment.