Skip to content

Commit

Permalink
Python API for CAGRA+HNSW (#246)
Browse files Browse the repository at this point in the history
Authors:
  - Divye Gala (https://github.com/divyegala)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #246
  • Loading branch information
divyegala authored Oct 3, 2024
1 parent 496427f commit 3c7f117
Show file tree
Hide file tree
Showing 21 changed files with 695 additions and 57 deletions.
11 changes: 10 additions & 1 deletion cpp/include/cuvs/neighbors/cagra.h
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,15 @@ cuvsError_t cuvsCagraIndexCreate(cuvsCagraIndex_t* index);
*/
cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index);

/**
* @brief Get dimension of the CAGRA index
*
* On success the dimension reported by the index is written to the integer
* pointed to by `dim`.
*
* @param[in] index CAGRA index to query
* @param[out] dim pointer to an `int` that receives the dimension of the index
* @return cuvsError_t
*/
cuvsError_t cuvsCagraIndexGetDims(cuvsCagraIndex_t index, int* dim);

/**
* @}
*/
Expand Down Expand Up @@ -338,7 +347,7 @@ cuvsError_t cuvsCagraBuild(cuvsResources_t res,
* with the same type of `queries`, such that `index.dtype.code ==
* queries.dl_tensor.dtype.code` Types for input are:
* 1. `queries`:
*` a. kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
* a. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
* b. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
* c. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
* 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 32`
Expand Down
6 changes: 4 additions & 2 deletions cpp/include/cuvs/neighbors/hnsw.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,10 @@ cuvsError_t cuvsHnswIndexDestroy(cuvsHnswIndex_t index);
* with the same type of `queries`, such that `index.dtype.code ==
* queries.dl_tensor.dtype.code`
* Supported types for input are:
* 1. `queries`: `kDLDataType.code == kDLFloat` or `kDLDataType.code == kDLInt` and
* `kDLDataType.bits = 32`
* 1. `queries`:
* a. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
* b. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
* c. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
* 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 64`
* 3. `distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
* NOTE: The HNSW index can only be searched by the hnswlib wrapper in cuVS,
Expand Down
6 changes: 4 additions & 2 deletions cpp/include/cuvs/neighbors/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ std::unique_ptr<index<int8_t>> from_cagra(

/**@}*/

// TODO: Filtered Search APIs: https://github.com/rapidsai/cuvs/issues/363

/**
* @defgroup hnsw_cpp_index_search Search hnswlib index
* @{
Expand Down Expand Up @@ -260,7 +262,7 @@ void search(raft::resources const& res,
void search(raft::resources const& res,
const search_params& params,
const index<uint8_t>& idx,
raft::host_matrix_view<const int, int64_t, raft::row_major> queries,
raft::host_matrix_view<const uint8_t, int64_t, raft::row_major> queries,
raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors,
raft::host_matrix_view<float, int64_t, raft::row_major> distances);

Expand Down Expand Up @@ -303,7 +305,7 @@ void search(raft::resources const& res,
void search(raft::resources const& res,
const search_params& params,
const index<int8_t>& idx,
raft::host_matrix_view<const int, int64_t, raft::row_major> queries,
raft::host_matrix_view<const int8_t, int64_t, raft::row_major> queries,
raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors,
raft::host_matrix_view<float, int64_t, raft::row_major> distances);

Expand Down
8 changes: 8 additions & 0 deletions cpp/src/neighbors/cagra_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,14 @@ extern "C" cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index_c_ptr)
});
}

// C entry point: stores the dimension reported by the CAGRA index behind
// `index->addr` into `*dim`. The work runs inside
// cuvs::core::translate_exceptions, whose result is returned as the
// cuvsError_t status.
//
// NOTE(review): the handle is always reinterpreted as index<float, uint32_t>,
// whatever element type the index was built with. Presumably dim() reads the
// same way for every element type -- confirm, or dispatch on the index dtype.
extern "C" cuvsError_t cuvsCagraIndexGetDims(cuvsCagraIndex_t index, int* dim)
{
  using float_index_t = cuvs::neighbors::cagra::index<float, uint32_t>;

  return cuvs::core::translate_exceptions([=] {
    auto* typed_index = reinterpret_cast<float_index_t*>(index->addr);
    // Same implicit conversion to `int` as a plain assignment, spelled out.
    *dim = static_cast<int>(typed_index->dim());
  });
}

extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res,
cuvsCagraIndexParams_t params,
DLManagedTensor* dataset_tensor,
Expand Down
21 changes: 6 additions & 15 deletions cpp/src/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,9 @@ void serialize_to_hnswlib(raft::resources const& res,
os.write(reinterpret_cast<char*>(&curr_element_count), sizeof(std::size_t));
// Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t,
// labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) +
// dim * 4 + sizeof(labeltype)
auto size_data_per_element =
static_cast<std::size_t>(index_.graph_degree() * sizeof(IdxT) + 4 + index_.dim() * 4 + 8);
// dim * sizeof(T) + sizeof(labeltype)
auto size_data_per_element = static_cast<std::size_t>(index_.graph_degree() * sizeof(IdxT) + 4 +
index_.dim() * sizeof(T) + 8);
os.write(reinterpret_cast<char*>(&size_data_per_element), sizeof(std::size_t));
// label_offset
std::size_t label_offset = size_data_per_element - 8;
Expand Down Expand Up @@ -185,18 +185,9 @@ void serialize_to_hnswlib(raft::resources const& res,
}

auto data_row = host_dataset.data_handle() + (index_.dim() * i);
if constexpr (std::is_same_v<T, float>) {
for (std::size_t j = 0; j < index_.dim(); ++j) {
auto data_elem = static_cast<float>(host_dataset(i, j));
os.write(reinterpret_cast<char*>(&data_elem), sizeof(float));
}
} else if constexpr (std::is_same_v<T, std::int8_t> or std::is_same_v<T, std::uint8_t>) {
for (std::size_t j = 0; j < index_.dim(); ++j) {
auto data_elem = static_cast<int>(host_dataset(i, j));
os.write(reinterpret_cast<char*>(&data_elem), sizeof(int));
}
} else {
RAFT_FAIL("Unsupported dataset type while saving CAGRA dataset to HNSWlib format");
for (std::size_t j = 0; j < index_.dim(); ++j) {
auto data_elem = static_cast<T>(host_dataset(i, j));
os.write(reinterpret_cast<char*>(&data_elem), sizeof(T));
}

os.write(reinterpret_cast<char*>(&i), sizeof(std::size_t));
Expand Down
13 changes: 7 additions & 6 deletions cpp/src/neighbors/detail/hnsw.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,9 @@ std::unique_ptr<index<T>> from_cagra(raft::resources const& res,
return std::unique_ptr<index<T>>(hnsw_index);
}

template <typename QueriesT>
void get_search_knn_results(hnswlib::HierarchicalNSW<QueriesT> const* idx,
const QueriesT* query,
template <typename T>
void get_search_knn_results(hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const* idx,
const T* query,
int k,
uint64_t* indices,
float* distances)
Expand All @@ -127,11 +127,11 @@ void get_search_knn_results(hnswlib::HierarchicalNSW<QueriesT> const* idx,
}
}

template <typename T, typename QueriesT>
template <typename T>
void search(raft::resources const& res,
const search_params& params,
const index<T>& idx,
raft::host_matrix_view<const QueriesT, int64_t, raft::row_major> queries,
raft::host_matrix_view<const T, int64_t, raft::row_major> queries,
raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors,
raft::host_matrix_view<float, int64_t, raft::row_major> distances)
{
Expand All @@ -146,7 +146,8 @@ void search(raft::resources const& res,

idx.set_ef(params.ef);
auto const* hnswlib_index =
reinterpret_cast<hnswlib::HierarchicalNSW<QueriesT> const*>(idx.get_index());
reinterpret_cast<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const*>(
idx.get_index());

// when num_threads == 0, automatically maximize parallelism
if (params.num_threads) {
Expand Down
24 changes: 12 additions & 12 deletions cpp/src/neighbors/hnsw.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,20 @@ CUVS_INST_HNSW_FROM_CAGRA(int8_t);

#undef CUVS_INST_HNSW_FROM_CAGRA

#define CUVS_INST_HNSW_SEARCH(T, QueriesT) \
void search(raft::resources const& res, \
const search_params& params, \
const index<T>& idx, \
raft::host_matrix_view<const QueriesT, int64_t, raft::row_major> queries, \
raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors, \
raft::host_matrix_view<float, int64_t, raft::row_major> distances) \
{ \
detail::search<T, QueriesT>(res, params, idx, queries, neighbors, distances); \
#define CUVS_INST_HNSW_SEARCH(T) \
void search(raft::resources const& res, \
const search_params& params, \
const index<T>& idx, \
raft::host_matrix_view<const T, int64_t, raft::row_major> queries, \
raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors, \
raft::host_matrix_view<float, int64_t, raft::row_major> distances) \
{ \
detail::search<T>(res, params, idx, queries, neighbors, distances); \
}

CUVS_INST_HNSW_SEARCH(float, float);
CUVS_INST_HNSW_SEARCH(uint8_t, int);
CUVS_INST_HNSW_SEARCH(int8_t, int);
CUVS_INST_HNSW_SEARCH(float);
CUVS_INST_HNSW_SEARCH(uint8_t);
CUVS_INST_HNSW_SEARCH(int8_t);

#undef CUVS_INST_HNSW_SEARCH

Expand Down
16 changes: 5 additions & 11 deletions cpp/src/neighbors/hnsw_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
#include <cuvs/neighbors/hnsw.hpp>

namespace {
template <typename T, typename QueriesT>
template <typename T>
void _search(cuvsResources_t res,
cuvsHnswSearchParams params,
cuvsHnswIndex index,
Expand All @@ -46,7 +46,7 @@ void _search(cuvsResources_t res,
search_params.ef = params.ef;
search_params.num_threads = params.numThreads;

using queries_mdspan_type = raft::host_matrix_view<QueriesT const, int64_t, raft::row_major>;
using queries_mdspan_type = raft::host_matrix_view<T const, int64_t, raft::row_major>;
using neighbors_mdspan_type = raft::host_matrix_view<uint64_t, int64_t, raft::row_major>;
using distances_mdspan_type = raft::host_matrix_view<float, int64_t, raft::row_major>;
auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
Expand Down Expand Up @@ -127,16 +127,13 @@ extern "C" cuvsError_t cuvsHnswSearch(cuvsResources_t res,

auto index = *index_c_ptr;
RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries");
RAFT_EXPECTS(queries.dtype.bits == 32, "number of bits in queries dtype should be 32");

if (index.dtype.code == kDLFloat) {
_search<float, float>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
_search<float>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
} else if (index.dtype.code == kDLUInt) {
_search<uint8_t, int>(
res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
_search<uint8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
} else if (index.dtype.code == kDLInt) {
_search<int8_t, int>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
_search<int8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
} else {
RAFT_FAIL("Unsupported index dtype: %d and bits: %d", queries.dtype.code, queries.dtype.bits);
}
Expand All @@ -152,13 +149,10 @@ extern "C" cuvsError_t cuvsHnswDeserialize(cuvsResources_t res,
return cuvs::core::translate_exceptions([=] {
if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
index->addr = reinterpret_cast<uintptr_t>(_deserialize<float>(res, filename, dim, metric));
index->dtype.code = kDLFloat;
} else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) {
index->addr = reinterpret_cast<uintptr_t>(_deserialize<uint8_t>(res, filename, dim, metric));
index->dtype.code = kDLInt;
} else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
index->addr = reinterpret_cast<uintptr_t>(_deserialize<int8_t>(res, filename, dim, metric));
index->dtype.code = kDLUInt;
} else {
RAFT_FAIL("Unsupported dtype in file %s", filename);
}
Expand Down
1 change: 1 addition & 0 deletions docs/source/c_api/neighbors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ Nearest Neighbors
neighbors_ivf_flat_c.rst
neighbors_ivf_pq_c.rst
neighbors_cagra_c.rst
neighbors_hnsw_c.rst
1 change: 1 addition & 0 deletions docs/source/cpp_api/neighbors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ Nearest Neighbors

neighbors_bruteforce.rst
neighbors_cagra.rst
neighbors_hnsw.rst
neighbors_ivf_flat.rst
neighbors_ivf_pq.rst
neighbors_nn_descent.rst
Expand Down
1 change: 1 addition & 0 deletions docs/source/python_api/neighbors.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,6 @@ Nearest Neighbors

neighbors_brute_force.rst
neighbors_cagra.rst
neighbors_hnsw.rst
neighbors_ivf_flat.rst
neighbors_ivf_pq.rst
30 changes: 30 additions & 0 deletions docs/source/python_api/neighbors_hnsw.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
HNSW
====

This is a wrapper for hnswlib that loads a CAGRA index as an immutable HNSW index. The loaded HNSW index is only compatible with cuVS and can be searched using the wrapper functions.

.. role:: py(code)
:language: python
:class: highlight

Index search parameters
#######################

.. autoclass:: cuvs.neighbors.hnsw.SearchParams
:members:

Index
#####

.. autoclass:: cuvs.neighbors.hnsw.Index
:members:

Index Conversion
################

.. autofunction:: cuvs.neighbors.hnsw.from_cagra

Index search
############

.. autofunction:: cuvs.neighbors.hnsw.search
1 change: 1 addition & 0 deletions python/cuvs/cuvs/neighbors/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

add_subdirectory(brute_force)
add_subdirectory(cagra)
add_subdirectory(hnsw)
add_subdirectory(ivf_flat)
add_subdirectory(ivf_pq)
add_subdirectory(filters)
Expand Down
17 changes: 17 additions & 0 deletions python/cuvs/cuvs/neighbors/cagra/cagra.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

from libc.stdint cimport (
int8_t,
int32_t,
int64_t,
uint8_t,
uint32_t,
Expand Down Expand Up @@ -100,6 +101,8 @@ cdef extern from "cuvs/neighbors/cagra.h" nogil:

cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index)

cuvsError_t cuvsCagraIndexGetDims(cuvsCagraIndex_t index, int32_t* dim)

cuvsError_t cuvsCagraBuild(cuvsResources_t res,
cuvsCagraIndexParams* params,
DLManagedTensor* dataset,
Expand All @@ -117,6 +120,20 @@ cdef extern from "cuvs/neighbors/cagra.h" nogil:
cuvsCagraIndex_t index,
bool include_dataset) except +

cuvsError_t cuvsCagraSerializeToHnswlib(cuvsResources_t res,
const char * filename,
cuvsCagraIndex_t index) except +

cuvsError_t cuvsCagraDeserialize(cuvsResources_t res,
const char * filename,
cuvsCagraIndex_t index) except +

cdef class Index:
"""
CAGRA index object. This object stores the trained CAGRA index state
which can be used to perform nearest neighbors searches.
"""

cdef cuvsCagraIndex_t index
cdef bool trained
cdef str active_index_type
17 changes: 9 additions & 8 deletions python/cuvs/cuvs/neighbors/cagra/cagra.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ from pylibraft.neighbors.common import _check_input_array

from libc.stdint cimport (
int8_t,
int32_t,
int64_t,
uint8_t,
uint32_t,
Expand Down Expand Up @@ -206,16 +207,9 @@ cdef class IndexParams:


cdef class Index:
"""
CAGRA index object. This object stores the trained CAGRA index state
which can be used to perform nearest neighbors searches.
"""

cdef cuvsCagraIndex_t index
cdef bool trained

def __cinit__(self):
self.trained = False
self.active_index_type = None
check_cuvs(cuvsCagraIndexCreate(&self.index))

def __dealloc__(self):
Expand All @@ -226,6 +220,12 @@ cdef class Index:
def trained(self):
return self.trained

# Read-only property: dimension of the CAGRA index, obtained from the C API
# through cuvsCagraIndexGetDims. check_cuvs wraps the returned status code.
@property
def dim(self):
cdef int32_t dim
check_cuvs(cuvsCagraIndexGetDims(self.index, &dim))
return dim

def __repr__(self):
# todo(dgd): update repr as we expose data through C API
attr_str = []
Expand Down Expand Up @@ -299,6 +299,7 @@ def build(IndexParams index_params, dataset, resources=None):
idx.index
))
idx.trained = True
idx.active_index_type = dataset_ai.dtype.name

return idx

Expand Down
Loading

0 comments on commit 3c7f117

Please sign in to comment.