From 38b3147af76adfc7a1fa9de0bbd61d6b9ec6e128 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 5 Jun 2024 12:11:27 -0700 Subject: [PATCH] Allow serialization on streams (#173) This change allows serializing to a std::ostream and deserializaing from a std::istream. This also fixes some minor docstring issues in the C++ serialization api's. Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Divye Gala (https://github.com/divyegala) URL: https://github.com/rapidsai/cuvs/pull/173 --- cpp/include/cuvs/neighbors/cagra.hpp | 322 ++++++++++++++++-- cpp/include/cuvs/neighbors/ivf_flat.hpp | 122 +++---- cpp/include/cuvs/neighbors/ivf_pq.hpp | 38 +-- cpp/src/neighbors/cagra_c.cpp | 5 +- cpp/src/neighbors/cagra_serialize.cuh | 172 ++-------- cpp/src/neighbors/cagra_serialize_float.cu | 61 +--- cpp/src/neighbors/cagra_serialize_int8.cu | 61 +--- cpp/src/neighbors/cagra_serialize_uint8.cu | 61 +--- .../detail/cagra/cagra_serialize.cuh | 15 +- .../neighbors/ivf_flat/generate_ivf_flat.py | 35 +- .../neighbors/ivf_flat/ivf_flat_serialize.cuh | 28 ++ .../ivf_flat_serialize_float_int64_t.cu | 31 -- .../ivf_flat_serialize_int8_t_int64_t.cu | 31 -- .../ivf_flat_serialize_uint8_t_int64_t.cu | 31 -- .../neighbors/ivf_pq/ivf_pq_deserialize.cu | 11 +- cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cu | 12 +- cpp/test/neighbors/ann_cagra.cuh | 4 +- cpp/test/neighbors/ann_ivf_flat.cuh | 4 +- cpp/test/neighbors/ann_ivf_pq.cuh | 4 +- 19 files changed, 461 insertions(+), 587 deletions(-) diff --git a/cpp/include/cuvs/neighbors/cagra.hpp b/cpp/include/cuvs/neighbors/cagra.hpp index c5f2ab87d9..b0668eeb0c 100644 --- a/cpp/include/cuvs/neighbors/cagra.hpp +++ b/cpp/include/cuvs/neighbors/cagra.hpp @@ -727,55 +727,321 @@ void search(raft::resources const& res, * @defgroup cagra_cpp_serialize CAGRA serialize functions * @{ */ -void serialize_file(raft::resources const& handle, - const std::string& filename, - const cuvs::neighbors::cagra::index& index, - bool include_dataset = true); - -void deserialize_file(raft::resources const& handle, - const std::string& filename, - cuvs::neighbors::cagra::index* index); + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = cuvs::neighbors::cagra::build(...);` + * cuvs::neighbors::cagra::serialize(handle, filename, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index CAGRA index + * @param[in] include_dataset Whether or not to write out the dataset to the file. + * + */ +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::cagra::index& index, + bool include_dataset = true); + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + + * cuvs::neighbors::cagra::index index; + * cuvs::neighbors::cagra::deserialize(handle, filename, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[out] index the cagra index + */ +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::cagra::index* index); + +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = cuvs::neighbors::cagra::build(...);` + * cuvs::neighbors::cagra::serialize(handle, os, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index CAGRA index + * @param[in] include_dataset Whether or not to write out the dataset to the file. + */ void serialize(raft::resources const& handle, - std::string& str, + std::ostream& os, const cuvs::neighbors::cagra::index& index, bool include_dataset = true); +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * cuvs::neighbors::cagra::index index; + * cuvs::neighbors::cagra::deserialize(handle, is, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] is input stream + * @param[out] index the cagra index + */ void deserialize(raft::resources const& handle, - const std::string& str, + std::istream& is, cuvs::neighbors::cagra::index* index); -void serialize_file(raft::resources const& handle, - const std::string& filename, - const cuvs::neighbors::cagra::index& index, - bool include_dataset = true); +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = cuvs::neighbors::cagra::build(...);` + * cuvs::neighbors::cagra::serialize(handle, filename, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index CAGRA index + * @param[in] include_dataset Whether or not to write out the dataset to the file. + */ +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::cagra::index& index, + bool include_dataset = true); + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + + * cuvs::neighbors::cagra::index index; + * cuvs::neighbors::cagra::deserialize(handle, filename, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[out] index the cagra index + */ +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::cagra::index* index); -void deserialize_file(raft::resources const& handle, - const std::string& filename, - cuvs::neighbors::cagra::index* index); +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = cuvs::neighbors::cagra::build(...);` + * cuvs::neighbors::cagra::serialize(handle, os, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index CAGRA index + * @param[in] include_dataset Whether or not to write out the dataset to the file. + */ void serialize(raft::resources const& handle, - std::string& str, + std::ostream& os, const cuvs::neighbors::cagra::index& index, bool include_dataset = true); +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * cuvs::neighbors::cagra::index index; + * cuvs::neighbors::cagra::deserialize(handle, is, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] is input stream + * @param[out] index the cagra index + */ void deserialize(raft::resources const& handle, - const std::string& str, + std::istream& is, cuvs::neighbors::cagra::index* index); -void serialize_file(raft::resources const& handle, - const std::string& filename, - const cuvs::neighbors::cagra::index& index, - bool include_dataset = true); +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = cuvs::neighbors::cagra::build(...);` + * cuvs::neighbors::cagra::serialize(handle, filename, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index CAGRA index + * @param[in] include_dataset Whether or not to write out the dataset to the file. + */ +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::cagra::index& index, + bool include_dataset = true); + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); -void deserialize_file(raft::resources const& handle, - const std::string& filename, - cuvs::neighbors::cagra::index* index); + * cuvs::neighbors::cagra::index index; + * cuvs::neighbors::cagra::deserialize(handle, filename, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[out] index the cagra index + */ +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::cagra::index* index); + +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = cuvs::neighbors::cagra::build(...);` + * cuvs::neighbors::cagra::serialize(handle, os, index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index CAGRA index + * @param[in] include_dataset Whether or not to write out the dataset to the file. + */ void serialize(raft::resources const& handle, - std::string& str, + std::ostream& os, const cuvs::neighbors::cagra::index& index, bool include_dataset = true); +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @code{.cpp} + * #include + * #include + * + * raft::resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * cuvs::neighbors::cagra::index index; + * cuvs::neighbors::cagra::deserialize(handle, is, &index); + * @endcode + * + * @param[in] handle the raft handle + * @param[in] is input stream + * @param[out] index the cagra index + */ void deserialize(raft::resources const& handle, - const std::string& str, + std::istream& is, cuvs::neighbors::cagra::index* index); /** * @} diff --git a/cpp/include/cuvs/neighbors/ivf_flat.hpp b/cpp/include/cuvs/neighbors/ivf_flat.hpp index 38c7be612d..b5f3b54fac 100644 --- a/cpp/include/cuvs/neighbors/ivf_flat.hpp +++ b/cpp/include/cuvs/neighbors/ivf_flat.hpp @@ -1197,7 +1197,7 @@ void search_with_filtering( * // create a string with a filepath * std::string filename("/path/to/index"); * // create an index with `auto index = ivf_flat::build(...);` - * cuvs::serialize_file(handle, filename, index); + * cuvs::neighbors::ivf_flat::serialize(handle, filename, index); * @endcode * * @param[in] handle the raft handle @@ -1205,9 +1205,9 @@ void search_with_filtering( * @param[in] index IVF-Flat index * */ -void serialize_file(raft::resources const& handle, - const std::string& filename, - const cuvs::neighbors::ivf_flat::index& index); +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::ivf_flat::index& index); /** * Load index from file. @@ -1225,7 +1225,7 @@ void serialize_file(raft::resources const& handle, * using T = float; // data element type * using IdxT = int64_t; // type of the index * // create an empty index with `ivf_flat::index index(handle, index_params, dim);` - * cuvs::deserialize_file(handle, filename, &index); + * cuvs::neighbors::ivf_flat::deserialize(handle, filename, &index); * @endcode * * @param[in] handle the raft handle @@ -1233,12 +1233,12 @@ void serialize_file(raft::resources const& handle, * @param[in] index IVF-Flat index * */ -void deserialize_file(raft::resources const& handle, - const std::string& filename, - cuvs::neighbors::ivf_flat::index* index); +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::ivf_flat::index* index); /** - * Write the index to an output string + * Write the index to an output stream * * Experimental, both the API and the serialization format are subject to change. * @@ -1248,23 +1248,23 @@ void deserialize_file(raft::resources const& handle, * * raft::resources handle; * - * // create an output string - * std::string str; + * // create an output stream + * std::ostream os(std::cout.rdbuf()); * // create an index with `auto index = ivf_flat::build(...);` - * cuvs::serialize(handle, str, index); + * cuvs::neighbors::ivf_flat::serialize(handle, os, index); * @endcode * * @param[in] handle the raft handle - * @param[out] str output string + * @param[in] os output stream * @param[in] index IVF-Flat index * */ void serialize(raft::resources const& handle, - std::string& str, + std::ostream& os, const cuvs::neighbors::ivf_flat::index& index); /** - * Load index from input string + * Load index from input stream * * Experimental, both the API and the serialization format are subject to change. * @@ -1274,21 +1274,21 @@ void serialize(raft::resources const& handle, * * raft::resources handle; * - * // create an input string - * std::string str; + * // create an input stream + * std::istream is(std::cin.rdbuf()); * using T = float; // data element type * using IdxT = int64_t; // type of the index * // create an empty index with `ivf_flat::index index(handle, index_params, dim);` - * auto index = cuvs::deserialize(handle, str, &index); + * cuvs::neighbors::ivf_flat::deserialize(handle, is, &index); * @endcode * * @param[in] handle the raft handle - * @param[in] str output string + * @param[in] is input stream * @param[in] index IVF-Flat index * */ void deserialize(raft::resources const& handle, - const std::string& str, + std::istream& is, cuvs::neighbors::ivf_flat::index* index); /** @@ -1305,7 +1305,7 @@ void deserialize(raft::resources const& handle, * // create a string with a filepath * std::string filename("/path/to/index"); * // create an index with `auto index = ivf_flat::build(...);` - * cuvs::serialize_file(handle, filename, index); + * cuvs::neighbors::ivf_flat::serialize(handle, filename, index); * @endcode * * @param[in] handle the raft handle @@ -1313,9 +1313,9 @@ void deserialize(raft::resources const& handle, * @param[in] index IVF-Flat index * */ -void serialize_file(raft::resources const& handle, - const std::string& filename, - const cuvs::neighbors::ivf_flat::index& index); +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::ivf_flat::index& index); /** * Load index from file. @@ -1333,7 +1333,7 @@ void serialize_file(raft::resources const& handle, * using T = float; // data element type * using IdxT = int64_t; // type of the index * // create an empty index with `ivf_flat::index index(handle, index_params, dim);` - * cuvs::deserialize_file(handle, filename, &index); + * cuvs::neighbors::ivf_flat::deserialize(handle, filename, &index); * @endcode * * @param[in] handle the raft handle @@ -1341,12 +1341,12 @@ void serialize_file(raft::resources const& handle, * @param[in] index IVF-Flat index * */ -void deserialize_file(raft::resources const& handle, - const std::string& filename, - cuvs::neighbors::ivf_flat::index* index); +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::ivf_flat::index* index); /** - * Write the index to an output string + * Write the index to an output stream * * Experimental, both the API and the serialization format are subject to change. * @@ -1356,23 +1356,23 @@ void deserialize_file(raft::resources const& handle, * * raft::resources handle; * - * // create an output string - * std::string str; + * // create an output stream + * std::ostream os(std::cout.rdbuf()); * // create an index with `auto index = ivf_flat::build(...);` - * cuvs::serialize(handle, str, index); + * cuvs::neighbors::ivf_flat::serialize(handle, os, index); * @endcode * * @param[in] handle the raft handle - * @param[out] str output string + * @param[in] os output stream * @param[in] index IVF-Flat index * */ void serialize(raft::resources const& handle, - std::string& str, + std::ostream& os, const cuvs::neighbors::ivf_flat::index& index); /** - * Load index from input string + * Load index from input stream * * Experimental, both the API and the serialization format are subject to change. * @@ -1382,21 +1382,21 @@ void serialize(raft::resources const& handle, * * raft::resources handle; * - * // create an input string - * std::string str; + * // create an input stream + * std::istream is(std::cin.rdbuf()); * using T = float; // data element type * using IdxT = int64_t; // type of the index * // create an empty index with `ivf_flat::index index(handle, index_params, dim);` - * auto index = cuvs::deserialize(handle, str, &index); + * cuvs::neighbors::ivf_flat::deserialize(handle, is, &index); * @endcode * * @param[in] handle the raft handle - * @param[in] str output string + * @param[in] is input stream * @param[in] index IVF-Flat index * */ void deserialize(raft::resources const& handle, - const std::string& str, + std::istream& is, cuvs::neighbors::ivf_flat::index* index); /** @@ -1413,7 +1413,7 @@ void deserialize(raft::resources const& handle, * // create a string with a filepath * std::string filename("/path/to/index"); * // create an index with `auto index = ivf_flat::build(...);` - * cuvs::serialize_file(handle, filename, index); + * cuvs::neighbors::ivf_flat::serialize(handle, filename, index); * @endcode * * @param[in] handle the raft handle @@ -1421,9 +1421,9 @@ void deserialize(raft::resources const& handle, * @param[in] index IVF-Flat index * */ -void serialize_file(raft::resources const& handle, - const std::string& filename, - const cuvs::neighbors::ivf_flat::index& index); +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::ivf_flat::index& index); /** * Load index from file. @@ -1441,7 +1441,7 @@ void serialize_file(raft::resources const& handle, * using T = float; // data element type * using IdxT = int64_t; // type of the index * // create an empty index with ivf_flat::index index(handle, index_params, dim);` - * cuvs::deserialize_file(handle, filename, &index); + * cuvs::neighbors::ivf_flat::deserialize(handle, filename, &index); * @endcode * * @param[in] handle the raft handle @@ -1449,12 +1449,12 @@ void serialize_file(raft::resources const& handle, * @param[in] index IVF-Flat index * */ -void deserialize_file(raft::resources const& handle, - const std::string& filename, - cuvs::neighbors::ivf_flat::index* index); +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::ivf_flat::index* index); /** - * Write the index to an output string + * Write the index to an output stream * * Experimental, both the API and the serialization format are subject to change. * @@ -1464,23 +1464,23 @@ void deserialize_file(raft::resources const& handle, * * raft::resources handle; * - * // create an output string - * std::string str; + * // create an output stream + * std::ostream os(std::cout.rdbuf()); * // create an index with `auto index = ivf_flat::build(...);` - * cuvs::serialize(handle, str, index); + * cuvs::neighbors::ivf_flat::serialize(handle, os, index); * @endcode * * @param[in] handle the raft handle - * @param[out] str output string + * @param[in] os output stream * @param[in] index IVF-Flat index * */ void serialize(raft::resources const& handle, - std::string& str, + std::ostream& os, const cuvs::neighbors::ivf_flat::index& index); /** - * Load index from input string + * Load index from input stream * * Experimental, both the API and the serialization format are subject to change. * @@ -1490,21 +1490,21 @@ void serialize(raft::resources const& handle, * * raft::resources handle; * - * // create an input string - * std::string str; + * // create an input stream + * std::istream is(std::cin.rdbuf()); * using T = float; // data element type * using IdxT = int64_t; // type of the index * // create an empty index with `ivf_flat::index index(handle, index_params, dim);` - * auto index = cuvs::deserialize(handle, str, &index); + * cuvs::neighbors::ivf_flat::deserialize(handle, is, &index); * @endcode * * @param[in] handle the raft handle - * @param[in] str output string + * @param[in] is input stream * @param[in] index IVF-Flat index * */ void deserialize(raft::resources const& handle, - const std::string& str, + std::istream& is, cuvs::neighbors::ivf_flat::index* index); /** @@ -1866,4 +1866,4 @@ void reset_index(const raft::resources& res, index* index); */ } // namespace helpers -} // namespace cuvs::neighbors::ivf_flat \ No newline at end of file +} // namespace cuvs::neighbors::ivf_flat diff --git a/cpp/include/cuvs/neighbors/ivf_pq.hpp b/cpp/include/cuvs/neighbors/ivf_pq.hpp index d7696a37b9..ea9e8b0aee 100644 --- a/cpp/include/cuvs/neighbors/ivf_pq.hpp +++ b/cpp/include/cuvs/neighbors/ivf_pq.hpp @@ -1308,26 +1308,26 @@ void search_with_filtering( * @{ */ /** - * Serialize the index to an output string. + * Write the index to an output stream * * @code{.cpp} * #include * * raft::resources handle; * - * // create an input string - * std::string str + * // create an output stream + * std::ostream os(std::cout.rdbuf()); * // create an index with `auto index = ivf_pq::build(...);` - * cuvs::serialize(handle, str, index); + * cuvs::neighbors::ivf_pq::serialize(handle, os, index); * @endcode * * @param[in] handle the raft handle - * @param[out] str output string + * @param[in] os output stream * @param[in] index IVF-PQ index * */ void serialize(raft::resources const& handle, - std::string& str, + std::ostream& os, const cuvs::neighbors::ivf_pq::index& index); /** @@ -1341,7 +1341,7 @@ void serialize(raft::resources const& handle, * // create a string with a filepath * std::string filename("/path/to/index"); * // create an index with `auto index = ivf_pq::build(...);` - * cuvs::serialize(handle, filename, index); + * cuvs::neighbors::ivf_pq::serialize(handle, filename, index); * @endcode * * @param[in] handle the raft handle @@ -1349,25 +1349,26 @@ void serialize(raft::resources const& handle, * @param[in] index IVF-PQ index * */ -void serialize_file(raft::resources const& handle, - const std::string& filename, - const cuvs::neighbors::ivf_pq::index& index); +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::ivf_pq::index& index); /** - * Load index from input string. + * Load index from input stream * * @code{.cpp} * #include * * raft::resources handle; * - * std::string str = ... + * // create an input stream + * std::istream is(std::cin.rdbuf()); * * using IdxT = int64_t; // type of the index * // create an empty index * cuvs::neighbors::ivf_pq::index index(handl, index_params, dim); * - * cuvs::deserialize(handle, filename, &index); + * cuvs::neighbors::ivf_pq::deserialize(handle, is, index); * @endcode * * @param[in] handle the raft handle @@ -1375,9 +1376,8 @@ void serialize_file(raft::resources const& handle, * @param[out] index IVF-PQ index * */ - void deserialize(raft::resources const& handle, - const std::string& str, + std::istream& str, cuvs::neighbors::ivf_pq::index* index); /** * Load index from file. @@ -1393,7 +1393,7 @@ void deserialize(raft::resources const& handle, * // create an empty index with * ivf_pq::index index(handle, index_params, dim); * - * cuvs::deserialize(handle, filename, &index); + * cuvs::neighbors::ivf_pq::deserialize(handle, filename, &index); * @endcode * * @param[in] handle the raft handle @@ -1401,9 +1401,9 @@ void deserialize(raft::resources const& handle, * @param[out] index IVF-PQ index * */ -void deserialize_file(raft::resources const& handle, - const std::string& filename, - cuvs::neighbors::ivf_pq::index* index); +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::ivf_pq::index* index); /** * @} */ diff --git a/cpp/src/neighbors/cagra_c.cpp b/cpp/src/neighbors/cagra_c.cpp index df0c4d9f25..868b3dec02 100644 --- a/cpp/src/neighbors/cagra_c.cpp +++ b/cpp/src/neighbors/cagra_c.cpp @@ -127,8 +127,7 @@ void _serialize(cuvsResources_t res, { auto res_ptr = reinterpret_cast(res); auto index_ptr = reinterpret_cast*>(index->addr); - cuvs::neighbors::cagra::serialize_file( - *res_ptr, std::string(filename), *index_ptr, include_dataset); + cuvs::neighbors::cagra::serialize(*res_ptr, std::string(filename), *index_ptr, include_dataset); } template @@ -136,7 +135,7 @@ void* _deserialize(cuvsResources_t res, const char* filename) { auto res_ptr = reinterpret_cast(res); auto index = new cuvs::neighbors::cagra::index(*res_ptr); - cuvs::neighbors::cagra::deserialize_file(*res_ptr, std::string(filename), index); + cuvs::neighbors::cagra::deserialize(*res_ptr, std::string(filename), index); return index; } diff --git a/cpp/src/neighbors/cagra_serialize.cuh b/cpp/src/neighbors/cagra_serialize.cuh index 100453b203..03f128cb99 100644 --- a/cpp/src/neighbors/cagra_serialize.cuh +++ b/cpp/src/neighbors/cagra_serialize.cuh @@ -19,82 +19,6 @@ #include "detail/cagra/cagra_serialize.cuh" namespace cuvs::neighbors::cagra { - -/** - * \defgroup cagra_serialize CAGRA Serialize - * @{ - */ - -/** - * Write the index to an output stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create an output stream - * std::ostream os(std::cout.rdbuf()); - * // create an index with `auto index = raft::cagra::build(...);` - * raft::cagra::serialize(handle, os, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] os output stream - * @param[in] index CAGRA index - * @param[in] include_dataset Whether or not to write out the dataset to the file. - * - */ -template -void serialize(raft::resources const& handle, - std::ostream& os, - const index& index, - bool include_dataset = true) -{ - detail::serialize(handle, os, index, include_dataset); -} - -/** - * Save the index to file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * // create an index with `auto index = raft::cagra::build(...);` - * raft::cagra::serialize(handle, filename, index); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index - * @param[in] index CAGRA index - * @param[in] include_dataset Whether or not to write out the dataset to the file. - * - */ -template -void serialize(raft::resources const& handle, - const std::string& filename, - const index& index, - bool include_dataset = true) -{ - detail::serialize(handle, filename, index, include_dataset); -} - /** * Write the CAGRA built index as a base layer HNSW index to an output stream * @@ -161,70 +85,36 @@ void serialize_to_hnswlib(raft::resources const& handle, detail::serialize_to_hnswlib(handle, filename, index); } -/** - * Load index from input stream - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create an input stream - * std::istream is(std::cin.rdbuf()); - * using T = float; // data element type - * using IdxT = int; // type of the index - * auto index = raft::cagra::deserialize(handle, is); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] is input stream - * - * @return cuvs::neighbors::experimental::cagra::index - */ -template -index deserialize(raft::resources const& handle, std::istream& is) -{ - return detail::deserialize(handle, is); -} - -/** - * Load index from file. - * - * Experimental, both the API and the serialization format are subject to change. - * - * @code{.cpp} - * #include - * #include - * - * raft::resources handle; - * - * // create a string with a filepath - * std::string filename("/path/to/index"); - * using T = float; // data element type - * using IdxT = int; // type of the index - * auto index = raft::cagra::deserialize(handle, filename); - * @endcode - * - * @tparam T data element type - * @tparam IdxT type of the indices - * - * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index - * - * @return cuvs::neighbors::experimental::cagra::index - */ -template -index deserialize(raft::resources const& handle, const std::string& filename) -{ - return detail::deserialize(handle, filename); -} - -/**@}*/ +#define CUVS_INST_CAGRA_SERIALIZE(DTYPE) \ + void serialize(raft::resources const& handle, \ + const std::string& filename, \ + const cuvs::neighbors::cagra::index& index, \ + bool include_dataset) \ + { \ + cuvs::neighbors::cagra::detail::serialize( \ + handle, filename, index, include_dataset); \ + }; \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& filename, \ + cuvs::neighbors::cagra::index* index) \ + { \ + cuvs::neighbors::cagra::detail::deserialize(handle, filename, index); \ + }; \ + void serialize(raft::resources const& handle, \ + std::ostream& os, \ + const cuvs::neighbors::cagra::index& index, \ + bool include_dataset) \ + { \ + cuvs::neighbors::cagra::detail::serialize( \ + handle, os, index, include_dataset); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + std::istream& is, \ + cuvs::neighbors::cagra::index* index) \ + { \ + cuvs::neighbors::cagra::detail::deserialize(handle, is, index); \ + } } // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/cagra_serialize_float.cu b/cpp/src/neighbors/cagra_serialize_float.cu index edaae0a0a9..b04fb51f36 100644 --- a/cpp/src/neighbors/cagra_serialize_float.cu +++ b/cpp/src/neighbors/cagra_serialize_float.cu @@ -15,68 +15,9 @@ */ #include "cagra_serialize.cuh" -#include - -#include - -#include - -#include -#include namespace cuvs::neighbors::cagra { -#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const cuvs::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - cuvs::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ - }; \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - cuvs::neighbors::cagra::index* index) \ - { \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = cuvs::neighbors::cagra::deserialize(handle, filename); \ - }; \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const cuvs::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - std::stringstream os; \ - cuvs::neighbors::cagra::serialize(handle, os, index, include_dataset); \ - str = os.str(); \ - } \ - \ - void serialize_to_hnswlib_file(raft::resources const& handle, \ - const std::string& filename, \ - const cuvs::neighbors::cagra::index& index) \ - { \ - cuvs::neighbors::cagra::serialize_to_hnswlib(handle, filename, index); \ - }; \ - void serialize_to_hnswlib(raft::resources const& handle, \ - std::string& str, \ - const cuvs::neighbors::cagra::index& index) \ - { \ - std::stringstream os; \ - cuvs::neighbors::cagra::serialize_to_hnswlib(handle, os, index); \ - str = os.str(); \ - } \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ - cuvs::neighbors::cagra::index* index) \ - { \ - std::istringstream is(str); \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = cuvs::neighbors::cagra::deserialize(handle, is); \ - } - -RAFT_INST_CAGRA_SERIALIZE(float); +CUVS_INST_CAGRA_SERIALIZE(float); -#undef RAFT_INST_CAGRA_SERIALIZE } // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/cagra_serialize_int8.cu b/cpp/src/neighbors/cagra_serialize_int8.cu index bb23e0ae55..5c290bb675 100644 --- a/cpp/src/neighbors/cagra_serialize_int8.cu +++ b/cpp/src/neighbors/cagra_serialize_int8.cu @@ -15,68 +15,9 @@ */ #include "cagra_serialize.cuh" -#include - -#include - -#include - -#include -#include namespace cuvs::neighbors::cagra { -#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const cuvs::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - cuvs::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ - }; \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - cuvs::neighbors::cagra::index* index) \ - { \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = cuvs::neighbors::cagra::deserialize(handle, filename); \ - }; \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const cuvs::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - std::stringstream os; \ - cuvs::neighbors::cagra::serialize(handle, os, index, include_dataset); \ - str = os.str(); \ - } \ - \ - void serialize_to_hnswlib_file(raft::resources const& handle, \ - const std::string& filename, \ - const cuvs::neighbors::cagra::index& index) \ - { \ - cuvs::neighbors::cagra::serialize_to_hnswlib(handle, filename, index); \ - }; \ - void serialize_to_hnswlib(raft::resources const& handle, \ - std::string& str, \ - const cuvs::neighbors::cagra::index& index) \ - { \ - std::stringstream os; \ - cuvs::neighbors::cagra::serialize_to_hnswlib(handle, os, index); \ - str = os.str(); \ - } \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ - cuvs::neighbors::cagra::index* index) \ - { \ - std::istringstream is(str); \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = cuvs::neighbors::cagra::deserialize(handle, is); \ - } - -RAFT_INST_CAGRA_SERIALIZE(int8_t); +CUVS_INST_CAGRA_SERIALIZE(int8_t); -#undef RAFT_INST_CAGRA_SERIALIZE } // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/cagra_serialize_uint8.cu b/cpp/src/neighbors/cagra_serialize_uint8.cu index b59b2cbe88..4494c9fbba 100644 --- a/cpp/src/neighbors/cagra_serialize_uint8.cu +++ b/cpp/src/neighbors/cagra_serialize_uint8.cu @@ -15,68 +15,9 @@ */ #include "cagra_serialize.cuh" -#include - -#include - -#include - -#include -#include namespace cuvs::neighbors::cagra { -#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const cuvs::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - cuvs::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ - }; \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - cuvs::neighbors::cagra::index* index) \ - { \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = cuvs::neighbors::cagra::deserialize(handle, filename); \ - }; \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const cuvs::neighbors::cagra::index& index, \ - bool include_dataset) \ - { \ - std::stringstream os; \ - cuvs::neighbors::cagra::serialize(handle, os, index, include_dataset); \ - str = os.str(); \ - } \ - \ - void serialize_to_hnswlib_file(raft::resources const& handle, \ - const std::string& filename, \ - const cuvs::neighbors::cagra::index& index) \ - { \ - cuvs::neighbors::cagra::serialize_to_hnswlib(handle, filename, index); \ - }; \ - void serialize_to_hnswlib(raft::resources const& handle, \ - std::string& str, \ - const cuvs::neighbors::cagra::index& index) \ - { \ - std::stringstream os; \ - cuvs::neighbors::cagra::serialize_to_hnswlib(handle, os, index); \ - str = os.str(); \ - } \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ - cuvs::neighbors::cagra::index* index) \ - { \ - std::istringstream is(str); \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = cuvs::neighbors::cagra::deserialize(handle, is); \ - } - -RAFT_INST_CAGRA_SERIALIZE(uint8_t); +CUVS_INST_CAGRA_SERIALIZE(uint8_t); -#undef RAFT_INST_CAGRA_SERIALIZE } // namespace cuvs::neighbors::cagra diff --git a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh index 4947d3148b..41329975b3 100644 --- a/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/src/neighbors/detail/cagra/cagra_serialize.cuh @@ -230,7 +230,7 @@ void serialize_to_hnswlib(raft::resources const& res, * */ template -auto deserialize(raft::resources const& res, std::istream& is) -> index +void deserialize(raft::resources const& res, std::istream& is, index* index_) { raft::common::nvtx::range fun_scope("cagra::deserialize"); @@ -249,26 +249,23 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index auto graph = raft::make_host_matrix(n_rows, graph_degree); deserialize_mdspan(res, is, graph.view()); - index idx(res, metric); - idx.update_graph(res, raft::make_const_mdspan(graph.view())); + *index_ = index(res, metric); + index_->update_graph(res, raft::make_const_mdspan(graph.view())); bool has_dataset = raft::deserialize_scalar(res, is); if (has_dataset) { - idx.update_dataset(res, cuvs::neighbors::detail::deserialize_dataset(res, is)); + index_->update_dataset(res, cuvs::neighbors::detail::deserialize_dataset(res, is)); } - return idx; } template -auto deserialize(raft::resources const& res, const std::string& filename) -> index +void deserialize(raft::resources const& res, const std::string& filename, index* index_) { std::ifstream is(filename, std::ios::in | std::ios::binary); if (!is) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - auto index = detail::deserialize(res, is); + detail::deserialize(res, is, index_); is.close(); - - return index; } } // namespace cuvs::neighbors::cagra::detail diff --git a/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py b/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py index 1733ca8b25..e739bddd47 100644 --- a/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py +++ b/cpp/src/neighbors/ivf_flat/generate_ivf_flat.py @@ -165,40 +165,7 @@ } """ -serialize_macro = """ -#define CUVS_INST_IVF_FLAT_SERIALIZE(T, IdxT) \\ - void serialize_file(raft::resources const& handle, \\ - const std::string& filename, \\ - const cuvs::neighbors::ivf_flat::index& index) \\ - { \\ - cuvs::neighbors::ivf_flat::detail::serialize(handle, filename, index); \\ - } \\ - \\ - void serialize(raft::resources const& handle, \\ - std::string& str, \\ - const cuvs::neighbors::ivf_flat::index& index) \\ - { \\ - std::ostringstream os; \\ - cuvs::neighbors::ivf_flat::detail::serialize(handle, os, index); \\ - str = os.str(); \\ - } \\ - \\ - void deserialize_file(raft::resources const& handle, \\ - const std::string& filename, \\ - cuvs::neighbors::ivf_flat::index* index) \\ - { \\ - * index = cuvs::neighbors::ivf_flat::detail::deserialize( \\ - handle, filename); \\ - } \\ - void deserialize(raft::resources const& handle, \\ - const std::string& str, \\ - cuvs::neighbors::ivf_flat::index* index) \\ - { \\ - std::istringstream is(str); \\ - * index = cuvs::neighbors::ivf_flat::detail::deserialize( \\ - handle, is); \\ - } -""" +serialize_macro = "" macros = dict( build_extend=dict( diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh index d88e173230..fdf06b286d 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize.cuh @@ -174,3 +174,31 @@ auto deserialize(raft::resources const& handle, const std::string& filename) -> return index; } } // namespace cuvs::neighbors::ivf_flat::detail + +#define CUVS_INST_IVF_FLAT_SERIALIZE(T, IdxT) \ + void serialize(raft::resources const& handle, \ + const std::string& filename, \ + const cuvs::neighbors::ivf_flat::index& index) \ + { \ + cuvs::neighbors::ivf_flat::detail::serialize(handle, filename, index); \ + } \ + \ + void serialize(raft::resources const& handle, \ + std::ostream& os, \ + const cuvs::neighbors::ivf_flat::index& index) \ + { \ + cuvs::neighbors::ivf_flat::detail::serialize(handle, os, index); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& filename, \ + cuvs::neighbors::ivf_flat::index* index) \ + { \ + *index = cuvs::neighbors::ivf_flat::detail::deserialize(handle, filename); \ + } \ + void deserialize(raft::resources const& handle, \ + std::istream& is, \ + cuvs::neighbors::ivf_flat::index* index) \ + { \ + *index = cuvs::neighbors::ivf_flat::detail::deserialize(handle, is); \ + } diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu index 9ab00623ad..95288d1198 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_float_int64_t.cu @@ -28,37 +28,6 @@ #include "ivf_flat_serialize.cuh" namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_SERIALIZE(T, IdxT) \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const cuvs::neighbors::ivf_flat::index& index) \ - { \ - cuvs::neighbors::ivf_flat::detail::serialize(handle, filename, index); \ - } \ - \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const cuvs::neighbors::ivf_flat::index& index) \ - { \ - std::ostringstream os; \ - cuvs::neighbors::ivf_flat::detail::serialize(handle, os, index); \ - str = os.str(); \ - } \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - cuvs::neighbors::ivf_flat::index* index) \ - { \ - *index = cuvs::neighbors::ivf_flat::detail::deserialize(handle, filename); \ - } \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ - cuvs::neighbors::ivf_flat::index* index) \ - { \ - std::istringstream is(str); \ - *index = cuvs::neighbors::ivf_flat::detail::deserialize(handle, is); \ - } CUVS_INST_IVF_FLAT_SERIALIZE(float, int64_t); #undef CUVS_INST_IVF_FLAT_SERIALIZE diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu index 18d6a02873..74e0c2653b 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_int8_t_int64_t.cu @@ -28,37 +28,6 @@ #include "ivf_flat_serialize.cuh" namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_SERIALIZE(T, IdxT) \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const cuvs::neighbors::ivf_flat::index& index) \ - { \ - cuvs::neighbors::ivf_flat::detail::serialize(handle, filename, index); \ - } \ - \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const cuvs::neighbors::ivf_flat::index& index) \ - { \ - std::ostringstream os; \ - cuvs::neighbors::ivf_flat::detail::serialize(handle, os, index); \ - str = os.str(); \ - } \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - cuvs::neighbors::ivf_flat::index* index) \ - { \ - *index = cuvs::neighbors::ivf_flat::detail::deserialize(handle, filename); \ - } \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ - cuvs::neighbors::ivf_flat::index* index) \ - { \ - std::istringstream is(str); \ - *index = cuvs::neighbors::ivf_flat::detail::deserialize(handle, is); \ - } CUVS_INST_IVF_FLAT_SERIALIZE(int8_t, int64_t); #undef CUVS_INST_IVF_FLAT_SERIALIZE diff --git a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cu b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cu index c5ab7d5c12..d33be5968d 100644 --- a/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cu +++ b/cpp/src/neighbors/ivf_flat/ivf_flat_serialize_uint8_t_int64_t.cu @@ -28,37 +28,6 @@ #include "ivf_flat_serialize.cuh" namespace cuvs::neighbors::ivf_flat { - -#define CUVS_INST_IVF_FLAT_SERIALIZE(T, IdxT) \ - void serialize_file(raft::resources const& handle, \ - const std::string& filename, \ - const cuvs::neighbors::ivf_flat::index& index) \ - { \ - cuvs::neighbors::ivf_flat::detail::serialize(handle, filename, index); \ - } \ - \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const cuvs::neighbors::ivf_flat::index& index) \ - { \ - std::ostringstream os; \ - cuvs::neighbors::ivf_flat::detail::serialize(handle, os, index); \ - str = os.str(); \ - } \ - \ - void deserialize_file(raft::resources const& handle, \ - const std::string& filename, \ - cuvs::neighbors::ivf_flat::index* index) \ - { \ - *index = cuvs::neighbors::ivf_flat::detail::deserialize(handle, filename); \ - } \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ - cuvs::neighbors::ivf_flat::index* index) \ - { \ - std::istringstream is(str); \ - *index = cuvs::neighbors::ivf_flat::detail::deserialize(handle, is); \ - } CUVS_INST_IVF_FLAT_SERIALIZE(uint8_t, int64_t); #undef CUVS_INST_IVF_FLAT_SERIALIZE diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_deserialize.cu b/cpp/src/neighbors/ivf_pq/ivf_pq_deserialize.cu index 7827e7892a..3cde5d11c6 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_deserialize.cu +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_deserialize.cu @@ -20,20 +20,19 @@ namespace cuvs::neighbors::ivf_pq { -void deserialize_file(raft::resources const& handle, - const std::string& filename, - cuvs::neighbors::ivf_pq::index* index) +void deserialize(raft::resources const& handle, + const std::string& filename, + cuvs::neighbors::ivf_pq::index* index) { if (!index) { RAFT_FAIL("Invalid index pointer"); } *index = cuvs::neighbors::ivf_pq::detail::deserialize(handle, filename); } void deserialize(raft::resources const& handle, - const std::string& str, + std::istream& is, cuvs::neighbors::ivf_pq::index* index) { if (!index) { RAFT_FAIL("Invalid index pointer"); } - std::istringstream is(str); *index = cuvs::neighbors::ivf_pq::detail::deserialize(handle, is); } -} // namespace cuvs::neighbors::ivf_pq \ No newline at end of file +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cu b/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cu index f0214f4bb3..7b647fd24e 100644 --- a/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cu +++ b/cpp/src/neighbors/ivf_pq/ivf_pq_serialize.cu @@ -20,19 +20,17 @@ namespace cuvs::neighbors::ivf_pq { -void serialize_file(raft::resources const& handle, - const std::string& filename, - const cuvs::neighbors::ivf_pq::index& index) +void serialize(raft::resources const& handle, + const std::string& filename, + const cuvs::neighbors::ivf_pq::index& index) { cuvs::neighbors::ivf_pq::detail::serialize(handle, filename, index); } void serialize(raft::resources const& handle, - std::string& str, + std::ostream& os, const cuvs::neighbors::ivf_pq::index& index) { - std::ostringstream os; cuvs::neighbors::ivf_pq::detail::serialize(handle, os, index); - str = os.str(); } -} // namespace cuvs::neighbors::ivf_pq \ No newline at end of file +} // namespace cuvs::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index ad7f327859..63ac06911d 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -282,11 +282,11 @@ class AnnCagraTest : public ::testing::TestWithParam { index = cagra::build(handle_, index_params, database_view); }; - cagra::serialize_file(handle_, "cagra_index", index, ps.include_serialized_dataset); + cagra::serialize(handle_, "cagra_index", index, ps.include_serialized_dataset); } cagra::index index(handle_); - cagra::deserialize_file(handle_, "cagra_index", &index); + cagra::deserialize(handle_, "cagra_index", &index); if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } diff --git a/cpp/test/neighbors/ann_ivf_flat.cuh b/cpp/test/neighbors/ann_ivf_flat.cuh index 9ce23cc7df..49cb3ec2a1 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cuh +++ b/cpp/test/neighbors/ann_ivf_flat.cuh @@ -184,9 +184,9 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { auto dists_out_view = raft::make_device_matrix_view( distances_ivfflat_dev.data(), ps.num_queries, ps.k); const std::string filename = "ivf_flat_index"; - cuvs::neighbors::ivf_flat::serialize_file(handle_, filename, index_2); + cuvs::neighbors::ivf_flat::serialize(handle_, filename, index_2); cuvs::neighbors::ivf_flat::index index_loaded(handle_, index_params, ps.dim); - cuvs::neighbors::ivf_flat::deserialize_file(handle_, filename, &index_loaded); + cuvs::neighbors::ivf_flat::deserialize(handle_, filename, &index_loaded); ASSERT_EQ(index_2.size(), index_loaded.size()); cuvs::neighbors::ivf_flat::search(handle_, diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index f716a8efed..662b701af9 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -241,9 +241,9 @@ class ivf_pq_test : public ::testing::TestWithParam { auto build_serialize() { std::string filename = "ivf_pq_index"; - cuvs::neighbors::ivf_pq::serialize_file(handle_, filename, build_only()); + cuvs::neighbors::ivf_pq::serialize(handle_, filename, build_only()); cuvs::neighbors::ivf_pq::index index(handle_, ps.index_params, ps.dim); - cuvs::neighbors::ivf_pq::deserialize_file(handle_, filename, &index); + cuvs::neighbors::ivf_pq::deserialize(handle_, filename, &index); return index; }