diff --git a/cpp/include/raft/neighbors/cagra_serialize.cuh b/cpp/include/raft/neighbors/cagra_serialize.cuh index 2242629409..0a806402d2 100644 --- a/cpp/include/raft/neighbors/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_serialize.cuh @@ -47,12 +47,16 @@ namespace raft::neighbors::cagra { * @param[in] handle the raft handle * @param[in] os output stream * @param[in] index CAGRA index + * @param[in] include_dataset Whether or not to write out the dataset to the file. * */ template -void serialize(raft::resources const& handle, std::ostream& os, const index& index) +void serialize(raft::resources const& handle, + std::ostream& os, + const index& index, + bool include_dataset = true) { - detail::serialize(handle, os, index); + detail::serialize(handle, os, index, include_dataset); } /** @@ -77,14 +81,16 @@ void serialize(raft::resources const& handle, std::ostream& os, const index void serialize(raft::resources const& handle, const std::string& filename, - const index& index) + const index& index, + bool include_dataset = true) { - detail::serialize(handle, filename, index); + detail::serialize(handle, filename, index, include_dataset); } /** @@ -158,4 +164,4 @@ namespace raft::neighbors::experimental::cagra { using raft::neighbors::cagra::deserialize; using raft::neighbors::cagra::serialize; -} // namespace raft::neighbors::experimental::cagra \ No newline at end of file +} // namespace raft::neighbors::experimental::cagra diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 8d040c352b..2c9cbd2563 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -24,8 +24,7 @@ namespace raft::neighbors::cagra::detail { -// Serialization version 1. -constexpr int serialization_version = 2; +constexpr int serialization_version = 3; // NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error // message. @@ -50,41 +49,53 @@ template struct check_index_layout), expecte * */ template -void serialize(raft::resources const& res, std::ostream& os, const index& index_) +void serialize(raft::resources const& res, + std::ostream& os, + const index& index_, + bool include_dataset) { RAFT_LOG_DEBUG( "Saving CAGRA index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); + std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); + dtype_string.resize(4); + os << dtype_string; + serialize_scalar(res, os, serialization_version); serialize_scalar(res, os, index_.size()); serialize_scalar(res, os, index_.dim()); serialize_scalar(res, os, index_.graph_degree()); serialize_scalar(res, os, index_.metric()); - auto dataset = index_.dataset(); - // Remove padding before saving the dataset - auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), - sizeof(T) * host_dataset.extent(1), - dataset.data_handle(), - sizeof(T) * dataset.stride(0), - sizeof(T) * host_dataset.extent(1), - dataset.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - resource::sync_stream(res); - serialize_mdspan(res, os, host_dataset.view()); serialize_mdspan(res, os, index_.graph()); + + serialize_scalar(res, os, include_dataset); + if (include_dataset) { + auto dataset = index_.dataset(); + // Remove padding before saving the dataset + auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), + sizeof(T) * host_dataset.extent(1), + dataset.data_handle(), + sizeof(T) * dataset.stride(0), + sizeof(T) * host_dataset.extent(1), + dataset.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + resource::sync_stream(res); + serialize_mdspan(res, os, host_dataset.view()); + } } template void serialize(raft::resources const& res, const std::string& filename, - const index& index_) + const index& index_, + bool include_dataset) { std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - detail::serialize(res, of, index_); + detail::serialize(res, of, index_, include_dataset); of.close(); if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } @@ -102,6 +113,9 @@ void serialize(raft::resources const& res, template auto deserialize(raft::resources const& res, std::istream& is) -> index { + char dtype_string[4]; + is.read(dtype_string, 4); + auto ver = deserialize_scalar(res, is); if (ver != serialization_version) { RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); @@ -113,9 +127,11 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index auto dataset = raft::make_host_matrix(n_rows, dim); auto graph = raft::make_host_matrix(n_rows, graph_degree); - deserialize_mdspan(res, is, dataset.view()); deserialize_mdspan(res, is, graph.view()); + bool has_dataset = deserialize_scalar(res, is); + if (has_dataset) { deserialize_mdspan(res, is, dataset.view()); } + return index( res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view())); } diff --git a/cpp/include/raft_runtime/neighbors/cagra.hpp b/cpp/include/raft_runtime/neighbors/cagra.hpp index 6f56302776..c54ed32b77 100644 --- a/cpp/include/raft_runtime/neighbors/cagra.hpp +++ b/cpp/include/raft_runtime/neighbors/cagra.hpp @@ -56,14 +56,16 @@ namespace raft::runtime::neighbors::cagra { raft::device_matrix_view distances); \ void serialize_file(raft::resources const& handle, \ const std::string& filename, \ - const raft::neighbors::cagra::index& index); \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ \ void deserialize_file(raft::resources const& handle, \ const std::string& filename, \ raft::neighbors::cagra::index* index); \ void serialize(raft::resources const& handle, \ std::string& str, \ - const raft::neighbors::cagra::index& index); \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ \ void deserialize(raft::resources const& handle, \ const std::string& str, \ diff --git a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu index be9788562a..69b48b93a4 100644 --- a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu @@ -27,9 +27,10 @@ namespace raft::runtime::neighbors::cagra { #define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ void serialize_file(raft::resources const& handle, \ const std::string& filename, \ - const raft::neighbors::cagra::index& index) \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ { \ - raft::neighbors::cagra::serialize(handle, filename, index); \ + raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ }; \ \ void deserialize_file(raft::resources const& handle, \ @@ -41,10 +42,11 @@ namespace raft::runtime::neighbors::cagra { }; \ void serialize(raft::resources const& handle, \ std::string& str, \ - const raft::neighbors::cagra::index& index) \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ { \ std::stringstream os; \ - raft::neighbors::cagra::serialize(handle, os, index); \ + raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \ str = os.str(); \ } \ \ diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 89cb070afc..ea905d2089 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -137,6 +137,7 @@ struct AnnCagraInputs { int search_width; raft::distance::DistanceType metric; bool host_dataset; + bool include_serialized_dataset; // std::optional double min_recall; // = std::nullopt; }; @@ -217,9 +218,11 @@ class AnnCagraTest : public ::testing::TestWithParam { } else { index = cagra::build(handle_, index_params, database_view); }; - cagra::serialize(handle_, "cagra_index", index); + cagra::serialize(handle_, "cagra_index", index, ps.include_serialized_dataset); } + auto index = cagra::deserialize(handle_, "cagra_index"); + if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.n_queries, ps.dim); @@ -340,9 +343,7 @@ class AnnCagraSortTest : public ::testing::TestWithParam { void SetUp() override { - std::cout << "Resizing database: " << ps.n_rows * ps.dim << std::endl; database.resize(((size_t)ps.n_rows) * ps.dim, handle_.get_stream()); - std::cout << "Done.\nRuning rng" << std::endl; raft::random::Rng r(1234ULL); if constexpr (std::is_same{}) { GenerateRoundingErrorFreeDataset(database.data(), ps.n_rows, ps.dim, r, handle_.get_stream()); @@ -379,6 +380,7 @@ inline std::vector generate_inputs() {1}, {raft::distance::DistanceType::L2Expanded}, {false}, + {true}, {0.995}); auto inputs2 = raft::util::itertools::product( @@ -393,6 +395,7 @@ inline std::vector generate_inputs() {1}, {raft::distance::DistanceType::L2Expanded}, {false}, + {true}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); inputs2 = @@ -407,6 +410,7 @@ inline std::vector generate_inputs() {1}, {raft::distance::DistanceType::L2Expanded}, {false}, + {false}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); @@ -422,6 +426,7 @@ inline std::vector generate_inputs() {1}, {raft::distance::DistanceType::L2Expanded}, {false}, + {true}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); @@ -437,6 +442,7 @@ inline std::vector generate_inputs() {1}, {raft::distance::DistanceType::L2Expanded}, {false, true}, + {false}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); @@ -452,6 +458,7 @@ inline std::vector generate_inputs() {1}, {raft::distance::DistanceType::L2Expanded}, {false, true}, + {true}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); diff --git a/python/pylibraft/pylibraft/common/mdspan.pxd b/python/pylibraft/pylibraft/common/mdspan.pxd index 6b202c2b69..17dd2d8bfd 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pxd +++ b/python/pylibraft/pylibraft/common/mdspan.pxd @@ -30,6 +30,12 @@ from pylibraft.common.cpp.mdspan cimport ( from pylibraft.common.handle cimport device_resources from pylibraft.common.optional cimport make_optional, optional +# Cython doesn't like `const float` inside template parameters +# hack around this with using typedefs +ctypedef const float const_float +ctypedef const int8_t const_int8_t +ctypedef const uint8_t const_uint8_t + cdef device_matrix_view[float, int64_t, row_major] get_dmv_float( array, check_shape) except * @@ -49,6 +55,15 @@ cdef optional[device_matrix_view[int64_t, int64_t, row_major]] make_optional_vie cdef device_matrix_view[uint32_t, int64_t, row_major] get_dmv_uint32( array, check_shape) except * +cdef device_matrix_view[const_float, int64_t, row_major] get_const_dmv_float( + array, check_shape) except * + +cdef device_matrix_view[const_uint8_t, int64_t, row_major] get_const_dmv_uint8( + array, check_shape) except * + +cdef device_matrix_view[const_int8_t, int64_t, row_major] get_const_dmv_int8( + array, check_shape) except * + cdef host_matrix_view[float, int64_t, row_major] get_hmv_float( array, check_shape) except * @@ -63,3 +78,12 @@ cdef host_matrix_view[int64_t, int64_t, row_major] get_hmv_int64( cdef host_matrix_view[uint32_t, int64_t, row_major] get_hmv_uint32( array, check_shape) except * + +cdef host_matrix_view[const_float, int64_t, row_major] get_const_hmv_float( + array, check_shape) except * + +cdef host_matrix_view[const_uint8_t, int64_t, row_major] get_const_hmv_uint8( + array, check_shape) except * + +cdef host_matrix_view[const_int8_t, int64_t, row_major] get_const_hmv_int8( + array, check_shape) except * diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index 1219b1612d..7442a6bb89 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -193,6 +193,39 @@ cdef device_matrix_view[int64_t, int64_t, row_major] \ cai.data, shape[0], shape[1]) +cdef device_matrix_view[const_float, int64_t, row_major] \ + get_const_dmv_float(cai, check_shape) except *: + if cai.dtype != np.float32: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[const_float, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef device_matrix_view[const_uint8_t, int64_t, row_major] \ + get_const_dmv_uint8(cai, check_shape) except *: + if cai.dtype != np.uint8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[const_uint8_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef device_matrix_view[const_int8_t, int64_t, row_major] \ + get_const_dmv_int8(cai, check_shape) except *: + if cai.dtype != np.int8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[const_int8_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + cdef optional[device_matrix_view[int64_t, int64_t, row_major]] \ make_optional_view_int64(device_matrix_view[int64_t, int64_t, row_major]& dmv) except *: # noqa: E501 return make_optional[device_matrix_view[int64_t, int64_t, row_major]](dmv) @@ -222,7 +255,6 @@ cdef host_matrix_view[float, int64_t, row_major] \ return make_host_matrix_view[float, int64_t, row_major]( cai.data, shape[0], shape[1]) - cdef host_matrix_view[uint8_t, int64_t, row_major] \ get_hmv_uint8(cai, check_shape) except *: if cai.dtype != np.uint8: @@ -265,3 +297,36 @@ cdef host_matrix_view[uint32_t, int64_t, row_major] \ shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) return make_host_matrix_view[uint32_t, int64_t, row_major]( cai.data, shape[0], shape[1]) + + +cdef host_matrix_view[const_float, int64_t, row_major] \ + get_const_hmv_float(cai, check_shape) except *: + if cai.dtype != np.float32: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[const_float, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef host_matrix_view[const_uint8_t, int64_t, row_major] \ + get_const_hmv_uint8(cai, check_shape) except *: + if cai.dtype != np.uint8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[const_uint8_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef host_matrix_view[const_int8_t, int64_t, row_major] \ + get_const_hmv_int8(cai, check_shape) except *: + if cai.dtype != np.int8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[const_int8_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx index fbc1623cac..e0c59a5ed3 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx @@ -69,6 +69,12 @@ from pylibraft.common.cpp.mdspan cimport ( row_major, ) from pylibraft.common.mdspan cimport ( + get_const_dmv_float, + get_const_dmv_int8, + get_const_dmv_uint8, + get_const_hmv_float, + get_const_hmv_int8, + get_const_hmv_uint8, get_dmv_float, get_dmv_int8, get_dmv_int64, @@ -162,6 +168,31 @@ cdef class IndexFloat(Index): attr_str = [m_str] + attr_str return "Index(type=CAGRA, " + (", ".join(attr_str)) + ")" + @auto_sync_handle + def update_dataset(self, dataset, handle=None): + """ Replace the dataset with a new dataset. + + Parameters + ---------- + dataset : array interface compliant matrix shape (n_samples, dim) + {handle_docstring} + """ + cdef device_resources* handle_ = \ + handle.getHandle() + + dataset_ai = wrap_array(dataset) + dataset_dt = dataset_ai.dtype + _check_input_array(dataset_ai, [np.dtype("float32")]) + + if dataset_ai.from_cai: + self.index[0].update_dataset(deref(handle_), + get_const_dmv_float(dataset_ai, + check_shape=True)) + else: + self.index[0].update_dataset(deref(handle_), + get_const_hmv_float(dataset_ai, + check_shape=True)) + @property def metric(self): return self.index[0].metric() @@ -195,6 +226,31 @@ cdef class IndexInt8(Index): self.index = new c_cagra.index[int8_t, uint32_t]( deref(handle_)) + @auto_sync_handle + def update_dataset(self, dataset, handle=None): + """ Replace the dataset with a new dataset. + + Parameters + ---------- + dataset : array interface compliant matrix shape (n_samples, dim) + {handle_docstring} + """ + cdef device_resources* handle_ = \ + handle.getHandle() + + dataset_ai = wrap_array(dataset) + dataset_dt = dataset_ai.dtype + _check_input_array(dataset_ai, [np.dtype("byte")]) + + if dataset_ai.from_cai: + self.index[0].update_dataset(deref(handle_), + get_const_dmv_int8(dataset_ai, + check_shape=True)) + else: + self.index[0].update_dataset(deref(handle_), + get_const_hmv_int8(dataset_ai, + check_shape=True)) + def __repr__(self): m_str = "metric=" + _get_metric_string(self.index.metric()) attr_str = [attr + "=" + str(getattr(self, attr)) @@ -235,6 +291,31 @@ cdef class IndexUint8(Index): self.index = new c_cagra.index[uint8_t, uint32_t]( deref(handle_)) + @auto_sync_handle + def update_dataset(self, dataset, handle=None): + """ Replace the dataset with a new dataset. + + Parameters + ---------- + dataset : array interface compliant matrix shape (n_samples, dim) + {handle_docstring} + """ + cdef device_resources* handle_ = \ + handle.getHandle() + + dataset_ai = wrap_array(dataset) + dataset_dt = dataset_ai.dtype + _check_input_array(dataset_ai, [np.dtype("ubyte")]) + + if dataset_ai.from_cai: + self.index[0].update_dataset(deref(handle_), + get_const_dmv_uint8(dataset_ai, + check_shape=True)) + else: + self.index[0].update_dataset(deref(handle_), + get_const_hmv_uint8(dataset_ai, + check_shape=True)) + def __repr__(self): m_str = "metric=" + _get_metric_string(self.index.metric()) attr_str = [attr + "=" + str(getattr(self, attr)) @@ -693,7 +774,7 @@ def search(SearchParams search_params, @auto_sync_handle -def save(filename, Index index, handle=None): +def save(filename, Index index, bool include_dataset=True, handle=None): """ Saves the index to a file. @@ -706,6 +787,12 @@ def save(filename, Index index, handle=None): Name of the file. index : Index Trained CAGRA index. + include_dataset : bool + Whether or not to write out the dataset along with the index. Including + the dataset in the serialized index will use extra disk space, and + might not be desired if you already have a copy of the dataset on + disk. If this option is set to false, you will have to call + `index.update_dataset(dataset)` after loading the index. {handle_docstring} Examples @@ -741,15 +828,17 @@ def save(filename, Index index, handle=None): if index.active_index_type == "float32": idx_float = index c_cagra.serialize_file( - deref(handle_), c_filename, deref(idx_float.index)) + deref(handle_), c_filename, deref(idx_float.index), + include_dataset) elif index.active_index_type == "byte": idx_int8 = index c_cagra.serialize_file( - deref(handle_), c_filename, deref(idx_int8.index)) + deref(handle_), c_filename, deref(idx_int8.index), include_dataset) elif index.active_index_type == "ubyte": idx_uint8 = index c_cagra.serialize_file( - deref(handle_), c_filename, deref(idx_uint8.index)) + deref(handle_), c_filename, deref(idx_uint8.index), + include_dataset) else: raise ValueError( "Index dtype %s not supported" % index.active_index_type) @@ -785,12 +874,9 @@ def load(filename, handle=None): cdef IndexInt8 idx_int8 cdef IndexUint8 idx_uint8 - # we extract the dtype from the array interfaces in the file - with open(filename, 'rb') as f: - type_str = f.read(700).decode("utf-8", errors='ignore') - - # Read description of the 6th element to get the datatype - dataset_dt = np.dtype(type_str.split('descr')[6][5:7]) + with open(filename, "rb") as f: + type_str = f.read(3).decode("utf8") + dataset_dt = np.dtype(type_str) if dataset_dt == np.float32: idx_float = IndexFloat(handle) diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd index 284c75b771..0c683bcd9b 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd @@ -36,6 +36,7 @@ from pylibraft.common.cpp.mdspan cimport ( row_major, ) from pylibraft.common.handle cimport device_resources +from pylibraft.common.mdspan cimport const_float, const_int8_t, const_uint8_t from pylibraft.common.optional cimport optional from pylibraft.distance.distance_type cimport DistanceType from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( @@ -90,6 +91,17 @@ cdef extern from "raft/neighbors/cagra_types.hpp" \ device_matrix_view[T, IdxT, row_major] dataset() device_matrix_view[T, IdxT, row_major] graph() + # hack: can't use the T template param here because of issues handling + # const w/ cython. introduce a new template param to get around this + void update_dataset[ValueT](const device_resources & handle, + host_matrix_view[ValueT, + int64_t, + row_major] dataset) + void update_dataset[ValueT](const device_resources & handle, + device_matrix_view[ValueT, + int64_t, + row_major] dataset) + cdef extern from "raft_runtime/neighbors/cagra.hpp" \ namespace "raft::runtime::neighbors::cagra" nogil: @@ -155,7 +167,8 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ cdef void serialize(const device_resources& handle, string& str, - const index[float, uint32_t]& index) except + + const index[float, uint32_t]& index, + bool include_dataset) except + cdef void deserialize(const device_resources& handle, const string& str, @@ -163,7 +176,8 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ cdef void serialize(const device_resources& handle, string& str, - const index[uint8_t, uint32_t]& index) except + + const index[uint8_t, uint32_t]& index, + bool include_dataset) except + cdef void deserialize(const device_resources& handle, const string& str, @@ -171,7 +185,8 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ cdef void serialize(const device_resources& handle, string& str, - const index[int8_t, uint32_t]& index) except + + const index[int8_t, uint32_t]& index, + bool include_dataset) except + cdef void deserialize(const device_resources& handle, const string& str, @@ -179,7 +194,8 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ cdef void serialize_file(const device_resources& handle, const string& filename, - const index[float, uint32_t]& index) except + + const index[float, uint32_t]& index, + bool include_dataset) except + cdef void deserialize_file(const device_resources& handle, const string& filename, @@ -187,7 +203,8 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ cdef void serialize_file(const device_resources& handle, const string& filename, - const index[uint8_t, uint32_t]& index) except + + const index[uint8_t, uint32_t]& index, + bool include_dataset) except + cdef void deserialize_file(const device_resources& handle, const string& filename, @@ -195,7 +212,8 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ cdef void serialize_file(const device_resources& handle, const string& filename, - const index[int8_t, uint32_t]& index) except + + const index[int8_t, uint32_t]& index, + bool include_dataset) except + cdef void deserialize_file(const device_resources& handle, const string& filename, diff --git a/python/pylibraft/pylibraft/test/test_cagra.py b/python/pylibraft/pylibraft/test/test_cagra.py index 435b2878a2..74e9f53b91 100644 --- a/python/pylibraft/pylibraft/test/test_cagra.py +++ b/python/pylibraft/pylibraft/test/test_cagra.py @@ -255,7 +255,8 @@ def test_cagra_search_params(params): @pytest.mark.parametrize("dtype", [np.float32, np.int8, np.ubyte]) -def test_save_load(dtype): +@pytest.mark.parametrize("include_dataset", [True, False]) +def test_save_load(dtype, include_dataset): n_rows = 10000 n_cols = 50 n_queries = 1000 @@ -268,9 +269,14 @@ def test_save_load(dtype): assert index.trained filename = "my_index.bin" - cagra.save(filename, index) + cagra.save(filename, index, include_dataset=include_dataset) loaded_index = cagra.load(filename) + # if we didn't save the dataset with the index, we need to update the + # index with an already loaded copy + if not include_dataset: + loaded_index.update_dataset(dataset) + queries = generate_data((n_queries, n_cols), dtype) queries_device = device_ndarray(queries)