Python API for CAGRA+HNSW #246

Merged
merged 26 commits into from
Oct 3, 2024
Changes from 13 commits
Commits
09eb3ab
c api and tests
divyegala Jul 22, 2024
8bc035e
remove unneeded comment
divyegala Jul 22, 2024
f8327f5
Update ann_hnsw_c.cu
divyegala Jul 23, 2024
6c0df11
Update ann_hnsw_c.cu
divyegala Jul 23, 2024
8860a09
rename test
divyegala Jul 23, 2024
081eba5
passing python tests
divyegala Jul 24, 2024
5ba4fad
documentation
divyegala Jul 24, 2024
aa6c057
Merge branch 'branch-24.10' into hnsw-python-api
divyegala Jul 24, 2024
0c3d053
more docs
divyegala Jul 24, 2024
b47f92f
merging upstream
divyegala Sep 10, 2024
53bcf5d
Merge branch 'branch-24.10' into hnsw-python-api
cjnolet Sep 11, 2024
0c2d082
passing tests
divyegala Sep 11, 2024
360ac24
Merge branch 'branch-24.10' into hnsw-python-api
cjnolet Sep 16, 2024
8c53e22
address review
divyegala Sep 16, 2024
09de6d3
fix merge conflicts
divyegala Sep 27, 2024
ef98a4e
address review
divyegala Sep 27, 2024
97215f2
revert some changes
divyegala Sep 30, 2024
4acd22b
fix failing tests
divyegala Oct 1, 2024
6f86848
Merge branch 'branch-24.10' into hnsw-python-api
cjnolet Oct 2, 2024
006e77c
add some stream syncs in nn_descent
divyegala Oct 3, 2024
4d36c80
Merge branch 'branch-24.10' into hnsw-python-api
divyegala Oct 3, 2024
8d4d1a2
add more syncs, use thrust_policy
divyegala Oct 3, 2024
0409d12
Revert "add some stream syncs in nn_descent"
divyegala Oct 3, 2024
366af06
Revert "add more syncs, use thrust_policy"
divyegala Oct 3, 2024
ad40942
1000 rows in test
divyegala Oct 3, 2024
6460d2a
Merge remote-tracking branch 'upstream/branch-24.10' into hnsw-python…
divyegala Oct 3, 2024
11 changes: 10 additions & 1 deletion cpp/include/cuvs/neighbors/cagra.h
@@ -267,6 +267,15 @@ cuvsError_t cuvsCagraIndexCreate(cuvsCagraIndex_t* index);
  */
 cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index);

+/**
+ * @brief Get dimension of the CAGRA index
+ *
+ * @param[in] index CAGRA index
+ * @param[out] dim return dimension of the index
+ * @return cuvsError_t
+ */
+cuvsError_t cuvsCagraIndexDim(cuvsCagraIndex_t index, int* dim);
+
 /**
  * @}
  */
@@ -338,7 +347,7 @@ cuvsError_t cuvsCagraBuild(cuvsResources_t res,
  * with the same type of `queries`, such that `index.dtype.code ==
  * queries.dl_tensor.dtype.code` Types for input are:
  * 1. `queries`:
- *` a. kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
+ * a. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
  * b. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
  * c. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
  * 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 32`
6 changes: 4 additions & 2 deletions cpp/include/cuvs/neighbors/hnsw.h
@@ -105,8 +105,10 @@ cuvsError_t cuvsHnswIndexDestroy(cuvsHnswIndex_t index);
  * with the same type of `queries`, such that `index.dtype.code ==
  * queries.dl_tensor.dtype.code`
  * Supported types for input are:
- * 1. `queries`: `kDLDataType.code == kDLFloat` or `kDLDataType.code == kDLInt` and
- * `kDLDataType.bits = 32`
+ * 1. `queries`:
+ * a. `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
+ * b. `kDLDataType.code == kDLInt` and `kDLDataType.bits = 8`
+ * c. `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 8`
  * 2. `neighbors`: `kDLDataType.code == kDLUInt` and `kDLDataType.bits = 64`
  * 3. `distances`: `kDLDataType.code == kDLFloat` and `kDLDataType.bits = 32`
  * NOTE: The HNSW index can only be searched by the hnswlib wrapper in cuVS,
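The dtype contract in the updated `hnsw.h` comment can be restated as a small check. The following is only a sketch in plain Python, not part of the PR: `DLCode`, `SUPPORTED_QUERY_DTYPES`, and `check_search_dtypes` are hypothetical names, and dtypes are modelled as `(code, bits)` tuples rather than real DLPack tensors. The numeric codes follow DLPack's `DLDataTypeCode` enum.

```python
from enum import IntEnum


class DLCode(IntEnum):
    """Subset of DLPack's DLDataTypeCode values."""
    kDLInt = 0
    kDLUInt = 1
    kDLFloat = 2


# (code, bits) pairs the updated comment lists as valid for `queries`.
SUPPORTED_QUERY_DTYPES = {
    (DLCode.kDLFloat, 32),
    (DLCode.kDLInt, 8),
    (DLCode.kDLUInt, 8),
}


def check_search_dtypes(index_code, queries, neighbors, distances):
    """Raise ValueError if the (code, bits) pairs break the documented contract.

    Hypothetical helper for illustration; cuVS performs its own checks in C++.
    """
    if queries not in SUPPORTED_QUERY_DTYPES:
        raise ValueError(f"unsupported queries dtype: {queries}")
    if queries[0] != index_code:
        raise ValueError("type mismatch between index and queries")
    if neighbors != (DLCode.kDLUInt, 64):
        raise ValueError("neighbors must be 64-bit unsigned integers")
    if distances != (DLCode.kDLFloat, 32):
        raise ValueError("distances must be 32-bit floats")
```

Under this sketch, an 8-bit unsigned index searched with 8-bit unsigned queries passes, while the 32-bit integer queries that the old comment implied would be rejected.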
4 changes: 2 additions & 2 deletions cpp/include/cuvs/neighbors/hnsw.hpp
@@ -260,7 +260,7 @@ void search(raft::resources const& res,
 void search(raft::resources const& res,
             const search_params& params,
             const index<uint8_t>& idx,
-            raft::host_matrix_view<const int, int64_t, raft::row_major> queries,
+            raft::host_matrix_view<const uint8_t, int64_t, raft::row_major> queries,
             raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors,
             raft::host_matrix_view<float, int64_t, raft::row_major> distances);

@@ -303,7 +303,7 @@ void search(raft::resources const& res,
 void search(raft::resources const& res,
             const search_params& params,
             const index<int8_t>& idx,
-            raft::host_matrix_view<const int, int64_t, raft::row_major> queries,
+            raft::host_matrix_view<const int8_t, int64_t, raft::row_major> queries,
             raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors,
             raft::host_matrix_view<float, int64_t, raft::row_major> distances);

8 changes: 8 additions & 0 deletions cpp/src/neighbors/cagra_c.cpp
@@ -176,6 +176,14 @@ extern "C" cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index_c_ptr)
   });
 }

+extern "C" cuvsError_t cuvsCagraIndexDim(cuvsCagraIndex_t index, int* dim)
+{
+  return cuvs::core::translate_exceptions([=] {
+    auto index_ptr = reinterpret_cast<cuvs::neighbors::cagra::index<float, uint32_t>*>(index->addr);
+    *dim = index_ptr->dim();
+  });
+}
+
 extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res,
                                       cuvsCagraIndexParams_t params,
                                       DLManagedTensor* dataset_tensor,
19 changes: 6 additions & 13 deletions cpp/src/neighbors/detail/cagra/cagra_serialize.cuh
@@ -119,9 +119,9 @@ void serialize_to_hnswlib(raft::resources const& res,
   os.write(reinterpret_cast<char*>(&curr_element_count), sizeof(std::size_t));
   // Example:M: 16, dim = 128, data_t = float, index_t = uint32_t, list_size_type = uint32_t,
   // labeltype: size_t size_data_per_element_ = M * 2 * sizeof(index_t) + sizeof(list_size_type) +
-  // dim * 4 + sizeof(labeltype)
-  auto size_data_per_element =
-    static_cast<std::size_t>(index_.graph_degree() * sizeof(IdxT) + 4 + index_.dim() * 4 + 8);
+  // dim * sizeof(T) + sizeof(labeltype)
+  auto size_data_per_element = static_cast<std::size_t>(index_.graph_degree() * sizeof(IdxT) + 4 +
+                                                        index_.dim() * sizeof(T) + 8);
   os.write(reinterpret_cast<char*>(&size_data_per_element), sizeof(std::size_t));
   // label_offset
   std::size_t label_offset = size_data_per_element - 8;
@@ -184,16 +184,9 @@ void serialize_to_hnswlib(raft::resources const& res,
     }

     auto data_row = host_dataset.data_handle() + (index_.dim() * i);
-    if constexpr (std::is_same_v<T, float>) {
-      for (std::size_t j = 0; j < index_.dim(); ++j) {
-        auto data_elem = static_cast<float>(host_dataset(i, j));
-        os.write(reinterpret_cast<char*>(&data_elem), sizeof(float));
-      }
-    } else if constexpr (std::is_same_v<T, std::int8_t> or std::is_same_v<T, std::uint8_t>) {
-      for (std::size_t j = 0; j < index_.dim(); ++j) {
-        auto data_elem = static_cast<int>(host_dataset(i, j));
-        os.write(reinterpret_cast<char*>(&data_elem), sizeof(int));
-      }
+    for (std::size_t j = 0; j < index_.dim(); ++j) {
+      auto data_elem = static_cast<T>(host_dataset(i, j));
+      os.write(reinterpret_cast<char*>(&data_elem), sizeof(T));
     }

     os.write(reinterpret_cast<char*>(&i), sizeof(std::size_t));
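To make the size change in `serialize_to_hnswlib` concrete, here is a rough worked example in Python. It assumes 4-byte graph indices (`uint32_t`), a 4-byte link count, and an 8-byte label, as the diff's own comment describes. The helper names and the values `graph_degree = 32`, `dim = 128` are illustrative assumptions, not taken from the PR.

```python
import struct

# struct format characters for the three element types the index supports.
ELEMENT_FORMATS = {"float32": "f", "int8": "b", "uint8": "B"}


def size_data_per_element(graph_degree, dim, dtype, index_bytes=4):
    """Bytes per element: neighbor links + link count + vector + label."""
    elem_bytes = struct.calcsize("<" + ELEMENT_FORMATS[dtype])
    return graph_degree * index_bytes + 4 + dim * elem_bytes + 8


def pack_row(row, dtype):
    """Serialize one dataset row at its native element width."""
    return struct.pack(f"<{len(row)}{ELEMENT_FORMATS[dtype]}", *row)


# graph_degree = 32 and dim = 128 are illustrative values only.
print(size_data_per_element(32, 128, "float32"))  # 32*4 + 4 + 128*4 + 8 = 652
print(size_data_per_element(32, 128, "uint8"))    # 32*4 + 4 + 128*1 + 8 = 268
print(len(pack_row([1, 2, 3], "uint8")))          # 3
```

With the previous `dim * 4` term, both calls would have given 652, because 8-bit elements were widened to `int` before being written. The label offset in the diff is this size minus 8 in either case.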
13 changes: 7 additions & 6 deletions cpp/src/neighbors/detail/hnsw.hpp
@@ -110,9 +110,9 @@ std::unique_ptr<index<T>> from_cagra(raft::resources const& res,
   return std::unique_ptr<index<T>>(hnsw_index);
 }

-template <typename QueriesT>
-void get_search_knn_results(hnswlib::HierarchicalNSW<QueriesT> const* idx,
-                            const QueriesT* query,
+template <typename T>
+void get_search_knn_results(hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const* idx,
+                            const T* query,
                             int k,
                             uint64_t* indices,
                             float* distances)
@@ -127,11 +127,11 @@ void get_search_knn_results(hnswlib::HierarchicalNSW<QueriesT> const* idx,
   }
 }

-template <typename T, typename QueriesT>
+template <typename T>
 void search(raft::resources const& res,
             const search_params& params,
             const index<T>& idx,
-            raft::host_matrix_view<const QueriesT, int64_t, raft::row_major> queries,
+            raft::host_matrix_view<const T, int64_t, raft::row_major> queries,
             raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors,
             raft::host_matrix_view<float, int64_t, raft::row_major> distances)
 {
@@ -146,7 +146,8 @@ void search(raft::resources const& res,

   idx.set_ef(params.ef);
   auto const* hnswlib_index =
-    reinterpret_cast<hnswlib::HierarchicalNSW<QueriesT> const*>(idx.get_index());
+    reinterpret_cast<hnswlib::HierarchicalNSW<typename hnsw_dist_t<T>::type> const*>(
+      idx.get_index());

   // when num_threads == 0, automatically maximize parallelism
   if (params.num_threads) {
24 changes: 12 additions & 12 deletions cpp/src/neighbors/hnsw.cpp
@@ -34,20 +34,20 @@ CUVS_INST_HNSW_FROM_CAGRA(int8_t);

 #undef CUVS_INST_HNSW_FROM_CAGRA

-#define CUVS_INST_HNSW_SEARCH(T, QueriesT) \
-  void search(raft::resources const& res, \
-              const search_params& params, \
-              const index<T>& idx, \
-              raft::host_matrix_view<const QueriesT, int64_t, raft::row_major> queries, \
-              raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors, \
-              raft::host_matrix_view<float, int64_t, raft::row_major> distances) \
-  { \
-    detail::search<T, QueriesT>(res, params, idx, queries, neighbors, distances); \
+#define CUVS_INST_HNSW_SEARCH(T) \
+  void search(raft::resources const& res, \
+              const search_params& params, \
+              const index<T>& idx, \
+              raft::host_matrix_view<const T, int64_t, raft::row_major> queries, \
+              raft::host_matrix_view<uint64_t, int64_t, raft::row_major> neighbors, \
+              raft::host_matrix_view<float, int64_t, raft::row_major> distances) \
+  { \
+    detail::search<T>(res, params, idx, queries, neighbors, distances); \
   }

-CUVS_INST_HNSW_SEARCH(float, float);
-CUVS_INST_HNSW_SEARCH(uint8_t, int);
-CUVS_INST_HNSW_SEARCH(int8_t, int);
+CUVS_INST_HNSW_SEARCH(float);
+CUVS_INST_HNSW_SEARCH(uint8_t);
+CUVS_INST_HNSW_SEARCH(int8_t);

 #undef CUVS_INST_HNSW_SEARCH

16 changes: 5 additions & 11 deletions cpp/src/neighbors/hnsw_c.cpp
@@ -31,7 +31,7 @@
 #include <cuvs/neighbors/hnsw.hpp>

 namespace {
-template <typename T, typename QueriesT>
+template <typename T>
 void _search(cuvsResources_t res,
              cuvsHnswSearchParams params,
              cuvsHnswIndex index,
@@ -46,7 +46,7 @@ void _search(cuvsResources_t res,
   search_params.ef = params.ef;
   search_params.num_threads = params.numThreads;

-  using queries_mdspan_type = raft::host_matrix_view<QueriesT const, int64_t, raft::row_major>;
+  using queries_mdspan_type = raft::host_matrix_view<T const, int64_t, raft::row_major>;
   using neighbors_mdspan_type = raft::host_matrix_view<uint64_t, int64_t, raft::row_major>;
   using distances_mdspan_type = raft::host_matrix_view<float, int64_t, raft::row_major>;
   auto queries_mds = cuvs::core::from_dlpack<queries_mdspan_type>(queries_tensor);
@@ -127,16 +127,13 @@ extern "C" cuvsError_t cuvsHnswSearch(cuvsResources_t res,

     auto index = *index_c_ptr;
     RAFT_EXPECTS(queries.dtype.code == index.dtype.code, "type mismatch between index and queries");
-    RAFT_EXPECTS(queries.dtype.bits == 32, "number of bits in queries dtype should be 32");

     if (index.dtype.code == kDLFloat) {
-      _search<float, float>(
-        res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
+      _search<float>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
     } else if (index.dtype.code == kDLUInt) {
-      _search<uint8_t, int>(
-        res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
+      _search<uint8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
     } else if (index.dtype.code == kDLInt) {
-      _search<int8_t, int>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
+      _search<int8_t>(res, *params, index, queries_tensor, neighbors_tensor, distances_tensor);
     } else {
       RAFT_FAIL("Unsupported index dtype: %d and bits: %d", queries.dtype.code, queries.dtype.bits);
     }
@@ -152,13 +149,10 @@ extern "C" cuvsError_t cuvsHnswDeserialize(cuvsResources_t res,
   return cuvs::core::translate_exceptions([=] {
     if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
       index->addr = reinterpret_cast<uintptr_t>(_deserialize<float>(res, filename, dim, metric));
-      index->dtype.code = kDLFloat;
     } else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) {
       index->addr = reinterpret_cast<uintptr_t>(_deserialize<uint8_t>(res, filename, dim, metric));
-      index->dtype.code = kDLInt;
     } else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
       index->addr = reinterpret_cast<uintptr_t>(_deserialize<int8_t>(res, filename, dim, metric));
-      index->dtype.code = kDLUInt;
     } else {
       RAFT_FAIL("Unsupported dtype in file %s", filename);
     }
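The interaction between the two dispatch sites in `hnsw_c.cpp` can be modelled with two lookup tables. This is a hypothetical Python model, not cuVS code: the table and function names are invented for illustration, and `REMOVED_REASSIGNMENT` only restates the three `index->dtype.code = ...` assignments that this diff deletes. The numeric codes follow DLPack's `DLDataTypeCode` enum.

```python
K_DL_INT, K_DL_UINT, K_DL_FLOAT = 0, 1, 2  # DLPack DLDataTypeCode values

# (code, bits) -> element type, mirroring the branches in cuvsHnswDeserialize.
DESERIALIZE_DISPATCH = {
    (K_DL_FLOAT, 32): "float",
    (K_DL_UINT, 8): "uint8_t",
    (K_DL_INT, 8): "int8_t",
}

# code -> element type, mirroring the updated branches in cuvsHnswSearch.
SEARCH_DISPATCH = {K_DL_FLOAT: "float", K_DL_UINT: "uint8_t", K_DL_INT: "int8_t"}

# The assignments deleted by this diff, as "code stored after deserialize".
REMOVED_REASSIGNMENT = {K_DL_FLOAT: K_DL_FLOAT, K_DL_UINT: K_DL_INT, K_DL_INT: K_DL_UINT}


def roundtrip(code, bits, reassign=False):
    """Return (type built by deserialize, type later selected by search)."""
    try:
        built = DESERIALIZE_DISPATCH[(code, bits)]
    except KeyError:
        raise ValueError(f"Unsupported dtype: code={code}, bits={bits}") from None
    code_after = REMOVED_REASSIGNMENT[code] if reassign else code
    return built, SEARCH_DISPATCH[code_after]


print(roundtrip(K_DL_UINT, 8))                 # ('uint8_t', 'uint8_t')
print(roundtrip(K_DL_UINT, 8, reassign=True))  # ('uint8_t', 'int8_t')
```

In this model, keeping the old reassignments alongside the new search dispatch would pair a `uint8_t` index with the `int8_t` search instantiation, which suggests why they were dropped together with the `QueriesT` parameter.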
1 change: 1 addition & 0 deletions python/cuvs/cuvs/neighbors/CMakeLists.txt
@@ -14,6 +14,7 @@

 add_subdirectory(brute_force)
 add_subdirectory(cagra)
+add_subdirectory(hnsw)
 add_subdirectory(ivf_flat)
 add_subdirectory(ivf_pq)
 add_subdirectory(filters)
17 changes: 17 additions & 0 deletions python/cuvs/cuvs/neighbors/cagra/cagra.pxd
@@ -17,6 +17,7 @@

 from libc.stdint cimport (
     int8_t,
+    int32_t,
     int64_t,
     uint8_t,
     uint32_t,
@@ -100,6 +101,8 @@ cdef extern from "cuvs/neighbors/cagra.h" nogil:

     cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index)

+    cuvsError_t cuvsCagraIndexDim(cuvsCagraIndex_t index, int32_t* dim)
+
     cuvsError_t cuvsCagraBuild(cuvsResources_t res,
                                cuvsCagraIndexParams* params,
                                DLManagedTensor* dataset,
@@ -117,6 +120,20 @@ cdef extern from "cuvs/neighbors/cagra.h" nogil:
                                    cuvsCagraIndex_t index,
                                    bool include_dataset) except +

+    cuvsError_t cuvsCagraSerializeToHnswlib(cuvsResources_t res,
+                                            const char * filename,
+                                            cuvsCagraIndex_t index) except +
+
     cuvsError_t cuvsCagraDeserialize(cuvsResources_t res,
                                      const char * filename,
                                      cuvsCagraIndex_t index) except +
+
+cdef class Index:
+    """
+    CAGRA index object. This object stores the trained CAGRA index state
+    which can be used to perform nearest neighbors searches.
+    """
+
+    cdef cuvsCagraIndex_t index
+    cdef bool trained
+    cdef str active_index_type
17 changes: 9 additions & 8 deletions python/cuvs/cuvs/neighbors/cagra/cagra.pyx
@@ -36,6 +36,7 @@ from pylibraft.neighbors.common import _check_input_array

 from libc.stdint cimport (
     int8_t,
+    int32_t,
     int64_t,
     uint8_t,
     uint32_t,
@@ -206,16 +207,9 @@ cdef class IndexParams:


 cdef class Index:
-    """
-    CAGRA index object. This object stores the trained CAGRA index state
-    which can be used to perform nearest neighbors searches.
-    """
-
-    cdef cuvsCagraIndex_t index
-    cdef bool trained
-
     def __cinit__(self):
         self.trained = False
+        self.active_index_type = None
         check_cuvs(cuvsCagraIndexCreate(&self.index))

     def __dealloc__(self):
@@ -226,6 +220,12 @@ cdef class Index:
     def trained(self):
         return self.trained

+    @property
+    def dim(self):
+        cdef int32_t dim
+        check_cuvs(cuvsCagraIndexDim(self.index, &dim))
+        return dim
+
     def __repr__(self):
         # todo(dgd): update repr as we expose data through C API
         attr_str = []
@@ -299,6 +299,7 @@ def build(IndexParams index_params, dataset, resources=None):
         idx.index
     ))
     idx.trained = True
+    idx.active_index_type = dataset_ai.dtype.name

     return idx

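The new `dim` property follows the usual out-parameter pattern: the caller owns an integer, passes its address to the C function, checks the returned status, and reads the value back. The sketch below imitates that flow with `ctypes` in plain Python. `FakeCagraIndex`, `fake_cuvs_cagra_index_dim`, `check_status`, and the `STATUS_OK` value are all hypothetical stand-ins, and the status convention is not cuVS's actual `cuvsError_t` enum.

```python
import ctypes

STATUS_OK = 0  # hypothetical status value, not cuVS's real enum


class FakeCagraIndex:
    """Hypothetical stand-in for the opaque cuvsCagraIndex_t handle."""

    def __init__(self, dim):
        self._dim = dim


def fake_cuvs_cagra_index_dim(index, dim_ptr):
    """Mimic the C signature: write the result through dim_ptr, return a status."""
    dim_ptr[0] = index._dim
    return STATUS_OK


def check_status(status):
    """Raise when a C-style call reports failure, in the spirit of check_cuvs."""
    if status != STATUS_OK:
        raise RuntimeError(f"call failed with status {status}")


class Index:
    def __init__(self, dim):
        self._handle = FakeCagraIndex(dim)

    @property
    def dim(self):
        out = ctypes.c_int32()  # caller-owned storage for the out-parameter
        check_status(fake_cuvs_cagra_index_dim(self._handle, ctypes.pointer(out)))
        return out.value


print(Index(128).dim)  # 128
```

Declaring the out-parameter as a fixed-width `int32_t` on the Cython side, as the `.pxd` change does, keeps the Python-visible type explicit even though the C header uses plain `int`.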
24 changes: 24 additions & 0 deletions python/cuvs/cuvs/neighbors/hnsw/CMakeLists.txt
@@ -0,0 +1,24 @@
+# =============================================================================
+# Copyright (c) 2024, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
+# in compliance with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software distributed under the License
+# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+# or implied. See the License for the specific language governing permissions and limitations under
+# the License.
+# =============================================================================
+
+# Set the list of Cython files to build
+set(cython_sources hnsw.pyx)
+set(linked_libraries cuvs::cuvs cuvs::c_api)
+
+# Build all of the Cython targets
+rapids_cython_create_modules(
+  CXX
+  SOURCE_FILES "${cython_sources}"
+  LINKED_LIBRARIES "${linked_libraries}" ASSOCIATED_TARGETS cuvs MODULE_PREFIX neighbors_hnsw_
+)
Empty file.
25 changes: 25 additions & 0 deletions python/cuvs/cuvs/neighbors/hnsw/__init__.py
@@ -0,0 +1,25 @@
+# Copyright (c) 2024, NVIDIA CORPORATION.
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+
+# http://www.apache.org/licenses/LICENSE-2.0
+
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from .hnsw import Index, SearchParams, from_cagra, load, save, search
+
+__all__ = [
+    "Index",
+    "SearchParams",
+    "load",
+    "save",
+    "search",
+    "from_cagra",
+]