Skip to content

Commit

Permalink
Expose serialization to the python / c-api (rapidsai#164)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai/cuvs#164
  • Loading branch information
benfred authored May 30, 2024
1 parent a4c3a6f commit 35f8b84
Show file tree
Hide file tree
Showing 10 changed files with 258 additions and 30 deletions.
45 changes: 45 additions & 0 deletions cpp/include/cuvs/neighbors/cagra.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include <cuvs/core/c_api.h>
#include <dlpack/dlpack.h>
#include <stdbool.h>
#include <stdint.h>

#ifdef __cplusplus
Expand Down Expand Up @@ -382,6 +383,50 @@ cuvsError_t cuvsCagraSearch(cuvsResources_t res,
* @}
*/

/**
* @defgroup cagra_c_serialize CAGRA C-API serialize functions
* @{
*/
/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* The element type of the index is taken from `index->dtype`; only float32, int8 and
* uint8 indexes can be serialized, any other dtype is reported as an error.
*
* @code{.cpp}
* #include <cuvs/neighbors/cagra.h>
*
* // Create cuvsResources_t
* cuvsResources_t res;
* cuvsError_t res_create_status = cuvsResourcesCreate(&res);
*
* // create an index with `cuvsCagraBuild`
* cuvsCagraSerialize(res, "/path/to/index", index, true);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the file name for saving the index
* @param[in] index CAGRA index
* @param[in] include_dataset Whether or not to write out the dataset to the file.
*
* @return cuvsError_t status of the operation
*/
cuvsError_t cuvsCagraSerialize(cuvsResources_t res,
const char* filename,
cuvsCagraIndex_t index,
bool include_dataset);

/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* The element type is read from the file itself; only files holding a float32, int8 or
* uint8 index are accepted, any other dtype is reported as an error.
*
* `index` must already point to an allocated cuvsCagraIndex (for example one obtained
* from `cuvsCagraIndexCreate`); its `addr` and `dtype` fields are overwritten with the
* loaded index.
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the name of the file that stores the index
* @param[out] index CAGRA index loaded from disk
*
* @return cuvsError_t status of the operation
*/
cuvsError_t cuvsCagraDeserialize(cuvsResources_t res, const char* filename, cuvsCagraIndex_t index);
/**
* @}
*/
#ifdef __cplusplus
}
#endif
76 changes: 71 additions & 5 deletions cpp/src/neighbors/cagra_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/core/error.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resources.hpp>
#include <raft/core/serialize.hpp>

#include <cuvs/core/c_api.h>
#include <cuvs/core/exceptions.hpp>
Expand Down Expand Up @@ -104,6 +105,27 @@ void _search(cuvsResources_t res,
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
}

/**
 * Writes the CAGRA index referenced by the C handle `index` to `filename`.
 *
 * @tparam T element type of the stored index; the caller picks it from `index->dtype`
 * @param[in] res opaque resources handle, reinterpreted as `raft::resources*`
 * @param[in] filename destination path
 * @param[in] index C handle whose `addr` field holds a `cagra::index<T, uint32_t>*`
 * @param[in] include_dataset forwarded unchanged to `serialize_file`
 */
template <typename T>
void _serialize(cuvsResources_t res,
                const char* filename,
                cuvsCagraIndex_t index,
                bool include_dataset)
{
  using index_type = cuvs::neighbors::cagra::index<T, uint32_t>;

  auto& handle      = *reinterpret_cast<raft::resources*>(res);
  auto& typed_index = *reinterpret_cast<index_type*>(index->addr);

  cuvs::neighbors::cagra::serialize_file(
    handle, std::string(filename), typed_index, include_dataset);
}

/**
 * Allocates a `cagra::index<T, uint32_t>` and fills it from `filename`.
 *
 * @tparam T element type of the index stored in the file
 * @param[in] res opaque resources handle, reinterpreted as `raft::resources*`
 * @param[in] filename path of the serialized index
 * @return type-erased pointer to the heap-allocated index; ownership passes to the caller
 *
 * If `deserialize_file` throws, the partially constructed index is destroyed before the
 * exception is propagated, so a failed load does not leak the allocation.
 */
template <typename T>
void* _deserialize(cuvsResources_t res, const char* filename)
{
  auto res_ptr = reinterpret_cast<raft::resources*>(res);
  auto index   = new cuvs::neighbors::cagra::index<T, uint32_t>(*res_ptr);
  try {
    cuvs::neighbors::cagra::deserialize_file(*res_ptr, std::string(filename), index);
  } catch (...) {
    // Nobody else holds this pointer yet; release it before rethrowing.
    delete index;
    throw;
  }
  return index;
}

} // namespace

extern "C" cuvsError_t cuvsCagraIndexCreate(cuvsCagraIndex_t* index)
Expand Down Expand Up @@ -140,15 +162,13 @@ extern "C" cuvsError_t cuvsCagraBuild(cuvsResources_t res,
{
return cuvs::core::translate_exceptions([=] {
auto dataset = dataset_tensor->dl_tensor;
index->dtype = dataset.dtype;
if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
index->addr = reinterpret_cast<uintptr_t>(_build<float>(res, *params, dataset_tensor));
index->dtype.code = kDLFloat;
index->addr = reinterpret_cast<uintptr_t>(_build<float>(res, *params, dataset_tensor));
} else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) {
index->addr = reinterpret_cast<uintptr_t>(_build<int8_t>(res, *params, dataset_tensor));
index->dtype.code = kDLInt;
index->addr = reinterpret_cast<uintptr_t>(_build<int8_t>(res, *params, dataset_tensor));
} else if (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8) {
index->addr = reinterpret_cast<uintptr_t>(_build<uint8_t>(res, *params, dataset_tensor));
index->dtype.code = kDLUInt;
} else {
RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d",
dataset.dtype.code,
Expand Down Expand Up @@ -247,3 +267,49 @@ extern "C" cuvsError_t cuvsCagraSearchParamsDestroy(cuvsCagraSearchParams_t para
{
return cuvs::core::translate_exceptions([=] { delete params; });
}

/**
 * C-API entry point: loads a CAGRA index from `filename` into `index`.
 *
 * The first four bytes of the file are parsed as a numpy dtype descriptor to choose the
 * element type (float32, int8 or uint8) before the file is deserialized.
 * `index->addr` and `index->dtype` are only written once the load has succeeded, so a
 * failed call leaves `index` unchanged.
 *
 * NOTE(review): whatever `index->addr` pointed to before the call is overwritten without
 * being freed; callers are presumably expected to pass a freshly created index.
 *
 * Exceptions are converted to a cuvsError_t by `translate_exceptions`.
 */
extern "C" cuvsError_t cuvsCagraDeserialize(cuvsResources_t res,
                                            const char* filename,
                                            cuvsCagraIndex_t index)
{
  return cuvs::core::translate_exceptions([=] {
    // read the numpy dtype from the beginning of the file
    std::ifstream is(filename, std::ios::in | std::ios::binary);
    if (!is) { RAFT_FAIL("Cannot open file %s", filename); }
    char dtype_string[4] = {};
    // A truncated or empty file must not reach parse_descr with a partially filled buffer.
    if (!is.read(dtype_string, 4)) { RAFT_FAIL("Cannot read dtype from file %s", filename); }
    auto dtype = raft::detail::numpy_serializer::parse_descr(std::string(dtype_string, 4));

    if (dtype.kind == 'f' && dtype.itemsize == 4) {
      index->addr       = reinterpret_cast<uintptr_t>(_deserialize<float>(res, filename));
      index->dtype.code = kDLFloat;
    } else if (dtype.kind == 'i' && dtype.itemsize == 1) {
      index->addr       = reinterpret_cast<uintptr_t>(_deserialize<int8_t>(res, filename));
      index->dtype.code = kDLInt;
    } else if (dtype.kind == 'u' && dtype.itemsize == 1) {
      index->addr       = reinterpret_cast<uintptr_t>(_deserialize<uint8_t>(res, filename));
      index->dtype.code = kDLUInt;
    } else {
      RAFT_FAIL("Unsupported dtype in file %s", filename);
    }
    // Set last so that an unsupported dtype or a failed load leaves index->dtype untouched.
    index->dtype.bits = dtype.itemsize * 8;
  });
}

/**
 * C-API entry point: writes `index` to `filename`.
 *
 * Dispatches on `index->dtype` to the matching `_serialize<T>` instantiation
 * (float32, int8 or uint8) and fails for any other dtype. Exceptions are converted to a
 * cuvsError_t by `translate_exceptions`.
 */
extern "C" cuvsError_t cuvsCagraSerialize(cuvsResources_t res,
                                          const char* filename,
                                          cuvsCagraIndex_t index,
                                          bool include_dataset)
{
  return cuvs::core::translate_exceptions([=] {
    const bool is_float32 = index->dtype.code == kDLFloat && index->dtype.bits == 32;
    const bool is_int8    = index->dtype.code == kDLInt && index->dtype.bits == 8;
    const bool is_uint8   = index->dtype.code == kDLUInt && index->dtype.bits == 8;

    if (is_float32) {
      _serialize<float>(res, filename, index, include_dataset);
      return;
    }
    if (is_int8) {
      _serialize<int8_t>(res, filename, index, include_dataset);
      return;
    }
    if (is_uint8) {
      _serialize<uint8_t>(res, filename, index, include_dataset);
      return;
    }
    RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits);
  });
}
4 changes: 0 additions & 4 deletions docs/source/working_with_ann_indexes_c.rst
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,3 @@ Searching an index
cuvsCagraIndexDestroy(index);
cuvsCagraIndexParamsDestroy(index_params);
cuvsResourcesDestroy(res);
Serializing an index
--------------------
6 changes: 0 additions & 6 deletions docs/source/working_with_ann_indexes_cpp.rst
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,3 @@ Searching an index
cagra::search_params search_params;

cagra::search(res, search_params, index, queries, neighbors, distances);


Serializing an index
--------------------


5 changes: 0 additions & 5 deletions docs/source/working_with_ann_indexes_python.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,3 @@ Searching an index
index = // ... build index ...
neighbors, distances = cagra.search(search_params, index, queries, k)
Serializing an index
--------------------
8 changes: 0 additions & 8 deletions docs/source/working_with_ann_indexes_rust.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,3 @@ Building an index
Ok(())
}
Searching an index
------------------


Serializing an index
--------------------
4 changes: 4 additions & 0 deletions python/cuvs/cuvs/neighbors/cagra/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
IndexParams,
SearchParams,
build_index,
load,
save,
search,
)

Expand All @@ -28,5 +30,7 @@
"IndexParams",
"SearchParams",
"build_index",
"load",
"save",
"search",
]
10 changes: 10 additions & 0 deletions python/cuvs/cuvs/neighbors/cagra/cagra.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ from libc.stdint cimport (
uint64_t,
uintptr_t,
)
from libcpp cimport bool

from cuvs.common.c_api cimport cuvsError_t, cuvsResources_t
from cuvs.common.cydlpack cimport DLDataType, DLManagedTensor
Expand Down Expand Up @@ -110,3 +111,12 @@ cdef extern from "cuvs/neighbors/cagra.h" nogil:
DLManagedTensor* queries,
DLManagedTensor* neighbors,
DLManagedTensor* distances) except +

# Declaration of the C-API function cuvsCagraSerialize from cuvs/neighbors/cagra.h.
# `except +` makes Cython turn a C++ exception raised by the call into a Python exception.
cuvsError_t cuvsCagraSerialize(cuvsResources_t res,
const char * filename,
cuvsCagraIndex_t index,
bool include_dataset) except +

# Declaration of the C-API function cuvsCagraDeserialize from cuvs/neighbors/cagra.h.
# `except +` makes Cython turn a C++ exception raised by the call into a Python exception.
cuvsError_t cuvsCagraDeserialize(cuvsResources_t res,
const char * filename,
cuvsCagraIndex_t index) except +
80 changes: 78 additions & 2 deletions python/cuvs/cuvs/neighbors/cagra/cagra.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ from cuvs.common.resources import auto_sync_resources

from cython.operator cimport dereference as deref
from libcpp cimport bool, cast
from libcpp.string cimport string

from cuvs.common cimport cydlpack

Expand Down Expand Up @@ -282,7 +283,6 @@ def build_index(IndexParams index_params, dataset, resources=None):
np.dtype('ubyte')])

cdef Index idx = Index()
cdef cuvsError_t build_status
cdef cydlpack.DLManagedTensor* dataset_dlpack = \
cydlpack.dlpack_c(dataset_ai)
cdef cuvsCagraIndexParams* params = index_params.params
Expand Down Expand Up @@ -544,7 +544,6 @@ def search(SearchParams search_params,
exp_rows=n_queries, exp_cols=k)

cdef cuvsCagraSearchParams* params = &search_params.params
cdef cuvsError_t search_status
cdef cydlpack.DLManagedTensor* queries_dlpack = \
cydlpack.dlpack_c(queries_cai)
cdef cydlpack.DLManagedTensor* neighbors_dlpack = \
Expand All @@ -564,3 +563,80 @@ def search(SearchParams search_params,
))

return (distances, neighbors)


@auto_sync_resources
def save(filename, Index index, bool include_dataset=True, resources=None):
    """
    Saves the index to a file.

    Saving / loading the index is experimental. The serialization format is
    subject to change.

    Parameters
    ----------
    filename : string
        Name of the file.
    index : Index
        Trained CAGRA index.
    include_dataset : bool
        Whether or not to write out the dataset along with the index. Including
        the dataset in the serialized index will use extra disk space, and
        might not be desired if you already have a copy of the dataset on
        disk. If this option is set to false, you will have to call
        `index.update_dataset(dataset)` after loading the index.
    {resources_docstring}

    Examples
    --------
    >>> import cupy as cp
    >>> from cuvs.neighbors import cagra
    >>> n_samples = 50000
    >>> n_features = 50
    >>> dataset = cp.random.random_sample((n_samples, n_features),
    ...                                   dtype=cp.float32)
    >>> # Build index
    >>> index = cagra.build_index(cagra.IndexParams(), dataset)
    >>> # Serialize and deserialize the cagra index built
    >>> cagra.save("my_index.bin", index)
    >>> index_loaded = cagra.load("my_index.bin")
    """
    # The C API takes a NUL-terminated path; keep the encoded bytes alive in
    # a std::string for the duration of the call.
    cdef string encoded_path = filename.encode('utf-8')
    cdef cuvsResources_t handle = <cuvsResources_t>resources.get_c_obj()

    check_cuvs(
        cuvsCagraSerialize(
            handle, encoded_path.c_str(), index.index, include_dataset
        )
    )


@auto_sync_resources
def load(filename, resources=None):
    """
    Loads index from file.

    Saving / loading the index is experimental. The serialization format is
    subject to change, therefore loading an index saved with a previous
    version of cuvs is not guaranteed to work.

    Parameters
    ----------
    filename : string
        Name of the file.
    {resources_docstring}

    Returns
    -------
    index : Index
    """
    cdef Index loaded = Index()
    cdef cuvsResources_t handle = <cuvsResources_t>resources.get_c_obj()
    # The C API takes a NUL-terminated path; keep the encoded bytes alive in
    # a std::string for the duration of the call.
    cdef string encoded_path = filename.encode('utf-8')

    check_cuvs(
        cuvsCagraDeserialize(handle, encoded_path.c_str(), loaded.index)
    )

    # Mark the wrapper as trained only after the C call has succeeded.
    loaded.trained = True
    return loaded
Loading

0 comments on commit 35f8b84

Please sign in to comment.