Skip to content

Commit

Permalink
CAGRA Python wrappers (#1665)
Browse files Browse the repository at this point in the history
First verstion of a CAGRA API in pylibraft. 

Todos:

- [x] C++ raft_runtime instantiations and void overloads
- [x] Cython API
- [x] Solve issue of `cagra_types.hpp` including `#include <raft/util/pow2_utils.cuh>` that makes it need nvcc, blocking a clean C++ only cython build
- [x] Check in pytests
- [x] Add examples to docstrings 
- [x] Accommodate for parameter rename of #1676
- [x] Accomodate changes of #1664 
- [x] Move out of experimental namespace

Authors:
   - Dante Gama Dessavre (https://github.com/dantegd)
   - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
   - Corey J. Nolet (https://github.com/cjnolet)
  • Loading branch information
dantegd authored Jul 31, 2023
1 parent 8d9d682 commit c87af2d
Show file tree
Hide file tree
Showing 27 changed files with 1,832 additions and 33 deletions.
3 changes: 3 additions & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,9 @@ if(RAFT_COMPILE_LIBRARY)
src/raft_runtime/distance/pairwise_distance.cu
src/raft_runtime/matrix/select_k_float_int64_t.cu
src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu
src/raft_runtime/neighbors/cagra_build.cu
src/raft_runtime/neighbors/cagra_search.cu
src/raft_runtime/neighbors/cagra_serialize.cu
src/raft_runtime/neighbors/ivf_flat_build.cu
src/raft_runtime/neighbors/ivf_flat_search.cu
src/raft_runtime/neighbors/ivf_flat_serialize.cu
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/cagra.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ namespace raft::neighbors::cagra {
* optimized_graph.view());
* @endcode
*
* @tparam T data element type
* @tparam DataT data element type
* @tparam IdxT type of the dataset vector indices
*
* @param[in] res raft resources
Expand Down
8 changes: 3 additions & 5 deletions cpp/include/raft/neighbors/cagra_types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
#include <raft/core/resources.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/util/integer_utils.hpp>
#include <raft/util/pow2_utils.cuh>

#include <memory>
#include <optional>
Expand Down Expand Up @@ -113,7 +112,6 @@ static_assert(std::is_aggregate_v<search_params>);
*/
template <typename T, typename IdxT>
struct index : ann::index {
using AlignDim = raft::Pow2<16 / sizeof(T)>;
static_assert(!raft::is_narrowing_v<uint32_t, IdxT>,
"IdxT must be able to represent all values of uint32_t");

Expand Down Expand Up @@ -252,7 +250,7 @@ struct index : ann::index {
void update_dataset(raft::resources const& res,
raft::device_matrix_view<const T, int64_t, row_major> dataset)
{
if (dataset.extent(1) % AlignDim::Value != 0) {
if (dataset.extent(1) * sizeof(T) % 16 != 0) {
RAFT_LOG_DEBUG("Creating a padded copy of CAGRA dataset in device memory");
copy_padded(res, dataset);
} else {
Expand Down Expand Up @@ -308,8 +306,8 @@ struct index : ann::index {
void copy_padded(raft::resources const& res,
mdspan<const T, matrix_extent<int64_t>, row_major, data_accessor> dataset)
{
dataset_ =
make_device_matrix<T, int64_t>(res, dataset.extent(0), AlignDim::roundUp(dataset.extent(1)));
size_t padded_dim = round_up_safe<size_t>(dataset.extent(1) * sizeof(T), 16) / sizeof(T);
dataset_ = make_device_matrix<T, int64_t>(res, dataset.extent(0), padded_dim);
if (dataset_.extent(1) == dataset.extent(1)) {
raft::copy(dataset_.data_handle(),
dataset.data_handle(),
Expand Down
2 changes: 1 addition & 1 deletion cpp/include/raft/neighbors/detail/cagra/cagra_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ void build_knn_graph(raft::resources const& res,
1e-6;
const auto throughput = num_queries_done / time;

RAFT_LOG_INFO(
RAFT_LOG_DEBUG(
"# Search %12lu / %12lu (%3.2f %%), %e queries/sec, %.2f minutes ETA, self included = "
"%3.2f %% \r",
num_queries_done,
Expand Down
91 changes: 91 additions & 0 deletions cpp/include/raft_runtime/neighbors/cagra.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <raft/neighbors/cagra_types.hpp>
#include <raft/neighbors/ivf_pq_types.hpp>
#include <string>

#include <raft/core/device_mdspan.hpp>
#include <raft/core/host_device_accessor.hpp>
#include <raft/core/mdspan.hpp>

namespace raft::runtime::neighbors::cagra {

// Using device and host_matrix_view avoids needing to typedef mutltiple mdspans based on accessors
#define RAFT_INST_CAGRA_FUNCS(T, IdxT) \
auto build(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::device_matrix_view<const T, int64_t, row_major> dataset) \
->raft::neighbors::cagra::index<T, IdxT>; \
\
auto build(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::host_matrix_view<const T, int64_t, row_major> dataset) \
->raft::neighbors::cagra::index<T, IdxT>; \
\
void build_device(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::device_matrix_view<const T, int64_t, row_major> dataset, \
raft::neighbors::cagra::index<T, IdxT>& idx); \
\
void build_host(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::host_matrix_view<const T, int64_t, row_major> dataset, \
raft::neighbors::cagra::index<T, IdxT>& idx); \
\
void search(raft::resources const& handle, \
raft::neighbors::cagra::search_params const& params, \
const raft::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors, \
raft::device_matrix_view<float, int64_t, row_major> distances); \
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<T, IdxT>& index); \
\
void deserialize_file(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::cagra::index<T, IdxT>* index); \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<T, IdxT>& index); \
\
void deserialize(raft::resources const& handle, \
const std::string& str, \
raft::neighbors::cagra::index<T, IdxT>* index);

RAFT_INST_CAGRA_FUNCS(float, uint32_t);
RAFT_INST_CAGRA_FUNCS(int8_t, uint32_t);
RAFT_INST_CAGRA_FUNCS(uint8_t, uint32_t);

#undef RAFT_INST_CAGRA_FUNCS

#define RAFT_INST_CAGRA_OPTIMIZE(IdxT) \
void optimize_device(raft::resources const& res, \
raft::device_matrix_view<IdxT, int64_t, row_major> knn_graph, \
raft::host_matrix_view<IdxT, int64_t, row_major> new_graph); \
\
void optimize_host(raft::resources const& res, \
raft::host_matrix_view<IdxT, int64_t, row_major> knn_graph, \
raft::host_matrix_view<IdxT, int64_t, row_major> new_graph);

RAFT_INST_CAGRA_OPTIMIZE(uint32_t);

#undef RAFT_INST_CAGRA_OPTIMIZE

} // namespace raft::runtime::neighbors::cagra
81 changes: 81 additions & 0 deletions cpp/src/raft_runtime/neighbors/cagra_build.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/neighbors/cagra.cuh>
#include <raft/neighbors/ivf_pq.cuh>
#include <raft/neighbors/ivf_pq_types.hpp>
#include <raft_runtime/neighbors/cagra.hpp>

namespace raft::runtime::neighbors::cagra {

#define RAFT_INST_CAGRA_BUILD(T, IdxT) \
auto build(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::device_matrix_view<const T, int64_t, row_major> dataset) \
->raft::neighbors::cagra::index<T, IdxT> \
{ \
return raft::neighbors::cagra::build<T, IdxT>(handle, params, dataset); \
} \
\
auto build(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::host_matrix_view<const T, int64_t, row_major> dataset) \
->raft::neighbors::cagra::index<T, IdxT> \
{ \
return raft::neighbors::cagra::build<T, IdxT>(handle, params, dataset); \
} \
\
void build_device(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::device_matrix_view<const T, int64_t, row_major> dataset, \
raft::neighbors::cagra::index<T, IdxT>& idx) \
{ \
idx = build(handle, params, dataset); \
} \
\
void build_host(raft::resources const& handle, \
const raft::neighbors::cagra::index_params& params, \
raft::host_matrix_view<const T, int64_t, row_major> dataset, \
raft::neighbors::cagra::index<T, IdxT>& idx) \
{ \
idx = build(handle, params, dataset); \
}

RAFT_INST_CAGRA_BUILD(float, uint32_t);
RAFT_INST_CAGRA_BUILD(int8_t, uint32_t);
RAFT_INST_CAGRA_BUILD(uint8_t, uint32_t);

#undef RAFT_INST_CAGRA_BUILD

#define RAFT_INST_CAGRA_OPTIMIZE(IdxT) \
void optimize_device(raft::resources const& handle, \
raft::device_matrix_view<IdxT, int64_t, row_major> knn_graph, \
raft::host_matrix_view<IdxT, int64_t, row_major> new_graph) \
{ \
raft::neighbors::cagra::optimize(handle, knn_graph, new_graph); \
} \
void optimize_host(raft::resources const& handle, \
raft::host_matrix_view<IdxT, int64_t, row_major> knn_graph, \
raft::host_matrix_view<IdxT, int64_t, row_major> new_graph) \
{ \
raft::neighbors::cagra::optimize(handle, knn_graph, new_graph); \
}

RAFT_INST_CAGRA_OPTIMIZE(uint32_t);

#undef RAFT_INST_CAGRA_OPTIMIZE

} // namespace raft::runtime::neighbors::cagra
39 changes: 39 additions & 0 deletions cpp/src/raft_runtime/neighbors/cagra_search.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <raft/neighbors/cagra.cuh>
#include <raft_runtime/neighbors/cagra.hpp>

namespace raft::runtime::neighbors::cagra {

#define RAFT_INST_CAGRA_SEARCH(T, IdxT) \
void search(raft::resources const& handle, \
raft::neighbors::cagra::search_params const& params, \
const raft::neighbors::cagra::index<T, IdxT>& index, \
raft::device_matrix_view<const T, int64_t, row_major> queries, \
raft::device_matrix_view<IdxT, int64_t, row_major> neighbors, \
raft::device_matrix_view<float, int64_t, row_major> distances) \
{ \
raft::neighbors::cagra::search<T, IdxT>(handle, params, index, queries, neighbors, distances); \
}

RAFT_INST_CAGRA_SEARCH(float, uint32_t);
RAFT_INST_CAGRA_SEARCH(int8_t, uint32_t);
RAFT_INST_CAGRA_SEARCH(uint8_t, uint32_t);

#undef RAFT_INST_CAGRA_SEARCH

} // namespace raft::runtime::neighbors::cagra
65 changes: 65 additions & 0 deletions cpp/src/raft_runtime/neighbors/cagra_serialize.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <sstream>
#include <string>

#include <raft/core/device_resources.hpp>
#include <raft/neighbors/cagra_serialize.cuh>
#include <raft/neighbors/cagra_types.hpp>
#include <raft_runtime/neighbors/cagra.hpp>

namespace raft::runtime::neighbors::cagra {

#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index) \
{ \
raft::neighbors::cagra::serialize(handle, filename, index); \
}; \
\
void deserialize_file(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::cagra::index<DTYPE, uint32_t>* index) \
{ \
if (!index) { RAFT_FAIL("Invalid index pointer"); } \
*index = raft::neighbors::cagra::deserialize<DTYPE, uint32_t>(handle, filename); \
}; \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index) \
{ \
std::stringstream os; \
raft::neighbors::cagra::serialize(handle, os, index); \
str = os.str(); \
} \
\
void deserialize(raft::resources const& handle, \
const std::string& str, \
raft::neighbors::cagra::index<DTYPE, uint32_t>* index) \
{ \
std::istringstream is(str); \
if (!index) { RAFT_FAIL("Invalid index pointer"); } \
*index = raft::neighbors::cagra::deserialize<DTYPE, uint32_t>(handle, is); \
}

RAFT_INST_CAGRA_SERIALIZE(float);
RAFT_INST_CAGRA_SERIALIZE(int8_t);
RAFT_INST_CAGRA_SERIALIZE(uint8_t);

#undef RAFT_INST_CAGRA_SERIALIZE
} // namespace raft::runtime::neighbors::cagra
14 changes: 14 additions & 0 deletions docs/source/pylibraft_api/neighbors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,20 @@ Brute Force
.. autofunction:: pylibraft.neighbors.brute_force.knn


CAGRA
#####

.. autoclass:: pylibraft.neighbors.cagra.IndexParams
:members:

.. autofunction:: pylibraft.neighbors.cagra.build

.. autoclass:: pylibraft.neighbors.cagra.SearchParams
:members:

.. autofunction:: pylibraft.neighbors.cagra.search


IVF-Flat
########

Expand Down
1 change: 1 addition & 0 deletions python/pylibraft/pylibraft/common/ai_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def __init__(self, ai_arr):
ai_arr : array interface array
"""
self.ai_ = ai_arr.__array_interface__
self.from_cai = False

@property
def dtype(self):
Expand Down
1 change: 1 addition & 0 deletions python/pylibraft/pylibraft/common/cai_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def __init__(self, cai_arr):
__array_interface__=cai_arr.__cuda_array_interface__
)
super().__init__(helper)
self.from_cai = True


def wrap_array(array):
Expand Down
Loading

0 comments on commit c87af2d

Please sign in to comment.