From b266c4aa68106ca636872063f2c0ffca137c13f2 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 6 Mar 2023 10:04:58 -0800 Subject: [PATCH 1/4] add overloads --- .../neighbors/detail/ivf_pq_serialize.cuh | 148 +++++++++++------- 1 file changed, 93 insertions(+), 55 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh index 33d9b363ba..0410b5465a 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh @@ -48,93 +48,108 @@ struct check_index_layout { template struct check_index_layout), 536>; /** - * Save the index to file. + * Write the index to an output stream * * Experimental, both the API and the serialization format are subject to change. * * @param[in] handle the raft handle - * @param[in] filename the file name for saving the index + * @param[in] os output stream * @param[in] index IVF-PQ index * */ template void serialize(raft::device_resources const& handle_, - const std::string& filename, - const index& index) -{ - std::ofstream of(filename, std::ios::out | std::ios::binary); - if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + std::ostream& os, + const index& index) { RAFT_LOG_DEBUG("Size %zu, dim %d, pq_dim %d, pq_bits %d", - static_cast(index.size()), - static_cast(index.dim()), - static_cast(index.pq_dim()), - static_cast(index.pq_bits())); - - serialize_scalar(handle_, of, kSerializationVersion); - serialize_scalar(handle_, of, index.size()); - serialize_scalar(handle_, of, index.dim()); - serialize_scalar(handle_, of, index.pq_bits()); - serialize_scalar(handle_, of, index.pq_dim()); - serialize_scalar(handle_, of, index.conservative_memory_allocation()); - - serialize_scalar(handle_, of, index.metric()); - serialize_scalar(handle_, of, index.codebook_kind()); - serialize_scalar(handle_, of, index.n_lists()); - - serialize_mdspan(handle_, of, index.pq_centers()); - serialize_mdspan(handle_, of, index.centers()); - serialize_mdspan(handle_, of, index.centers_rot()); - serialize_mdspan(handle_, of, index.rotation_matrix()); + static_cast(index.size()), + static_cast(index.dim()), + static_cast(index.pq_dim()), + static_cast(index.pq_bits())); + + serialize_scalar(handle_, os, kSerializationVersion); + serialize_scalar(handle_, os, index.size()); + serialize_scalar(handle_, os, index.dim()); + serialize_scalar(handle_, os, index.pq_bits()); + serialize_scalar(handle_, os, index.pq_dim()); + serialize_scalar(handle_, os, index.conservative_memory_allocation()); + + serialize_scalar(handle_, os, index.metric()); + serialize_scalar(handle_, os, index.codebook_kind()); + serialize_scalar(handle_, os, index.n_lists()); + + serialize_mdspan(handle_, os, index.pq_centers()); + serialize_mdspan(handle_, os, index.centers()); + serialize_mdspan(handle_, os, index.centers_rot()); + serialize_mdspan(handle_, os, index.rotation_matrix()); auto sizes_host = make_host_mdarray(index.list_sizes().extents()); copy(sizes_host.data_handle(), - index.list_sizes().data_handle(), - sizes_host.size(), - handle_.get_stream()); + index.list_sizes().data_handle(), + sizes_host.size(), + handle_.get_stream()); handle_.sync_stream(); - serialize_mdspan(handle_, of, sizes_host.view()); + serialize_mdspan(handle_, os, sizes_host.view()); auto list_store_spec = list_spec{index.pq_bits(), index.pq_dim(), true}; for (uint32_t label = 0; label < index.n_lists(); label++) { ivf::serialize_list( - handle_, of, index.lists()[label], list_store_spec, sizes_host(label)); + handle_, os, index.lists()[label], list_store_spec, sizes_host(label)); } +} + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index IVF-PQ index + * + */ +template +void serialize(raft::device_resources const& handle_, + const std::string& filename, + const index& index) +{ + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + serialize(handle_, of, index); + of.close(); if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } return; } /** - * Load index from file. + * Load index from input stream * * Experimental, both the API and the serialization format are subject to change. * * @param[in] handle the raft handle - * @param[in] filename the name of the file that stores the index + * @param[in] is input stream * @param[in] index IVF-PQ index * */ template -auto deserialize(raft::device_resources const& handle_, const std::string& filename) -> index +auto deserialize(raft::device_resources const& handle_, std::istream& is) -> index { - std::ifstream infile(filename, std::ios::in | std::ios::binary); - - if (!infile) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - - auto ver = deserialize_scalar(handle_, infile); + auto ver = deserialize_scalar(handle_, is); if (ver != kSerializationVersion) { RAFT_FAIL("serialization version mismatch %d vs. %d", ver, kSerializationVersion); } - auto n_rows = deserialize_scalar(handle_, infile); - auto dim = deserialize_scalar(handle_, infile); - auto pq_bits = deserialize_scalar(handle_, infile); - auto pq_dim = deserialize_scalar(handle_, infile); - auto cma = deserialize_scalar(handle_, infile); + auto n_rows = deserialize_scalar(handle_, is); + auto dim = deserialize_scalar(handle_, is); + auto pq_bits = deserialize_scalar(handle_, is); + auto pq_dim = deserialize_scalar(handle_, is); + auto cma = deserialize_scalar(handle_, is); - auto metric = deserialize_scalar(handle_, infile); - auto codebook_kind = deserialize_scalar(handle_, infile); - auto n_lists = deserialize_scalar(handle_, infile); + auto metric = deserialize_scalar(handle_, is); + auto codebook_kind = deserialize_scalar(handle_, is); + auto n_lists = deserialize_scalar(handle_, is); RAFT_LOG_DEBUG("n_rows %zu, dim %d, pq_dim %d, pq_bits %d, n_lists %d", static_cast(n_rows), @@ -146,24 +161,47 @@ auto deserialize(raft::device_resources const& handle_, const std::string& filen auto index = raft::neighbors::ivf_pq::index( handle_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, cma); - deserialize_mdspan(handle_, infile, index.pq_centers()); - deserialize_mdspan(handle_, infile, index.centers()); - deserialize_mdspan(handle_, infile, index.centers_rot()); - deserialize_mdspan(handle_, infile, index.rotation_matrix()); - deserialize_mdspan(handle_, infile, index.list_sizes()); + deserialize_mdspan(handle_, is, index.pq_centers()); + deserialize_mdspan(handle_, is, index.centers()); + deserialize_mdspan(handle_, is, index.centers_rot()); + deserialize_mdspan(handle_, is, index.rotation_matrix()); + deserialize_mdspan(handle_, is, index.list_sizes()); auto list_device_spec = list_spec{pq_bits, pq_dim, cma}; auto list_store_spec = list_spec{pq_bits, pq_dim, true}; for (auto& list : index.lists()) { ivf::deserialize_list( - handle_, infile, list, list_store_spec, list_device_spec); + handle_, is, list, list_store_spec, list_device_spec); } handle_.sync_stream(); - infile.close(); recompute_internal_state(handle_, index); return index; } +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[in] index IVF-PQ index + * + */ +template +auto deserialize(raft::device_resources const& handle_, const std::string& filename) -> index +{ + std::ifstream infile(filename, std::ios::in | std::ios::binary); + + if (!infile) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + auto index = deserialize(handle_, infile); + + infile.close(); + + return index; +} + } // namespace raft::neighbors::ivf_pq::detail From 0b8994b3bf87f57c2532cd1edde2651d07e20603 Mon Sep 17 00:00:00 2001 From: divyegala Date: Mon, 6 Mar 2023 10:09:29 -0800 Subject: [PATCH 2/4] style format --- .../neighbors/detail/ivf_pq_serialize.cuh | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh index 0410b5465a..a15fdc30da 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh @@ -58,15 +58,13 @@ template struct check_index_layout), 536>; * */ template -void serialize(raft::device_resources const& handle_, - std::ostream& os, - const index& index) { - +void serialize(raft::device_resources const& handle_, std::ostream& os, const index& index) +{ RAFT_LOG_DEBUG("Size %zu, dim %d, pq_dim %d, pq_bits %d", - static_cast(index.size()), - static_cast(index.dim()), - static_cast(index.pq_dim()), - static_cast(index.pq_bits())); + static_cast(index.size()), + static_cast(index.dim()), + static_cast(index.pq_dim()), + static_cast(index.pq_bits())); serialize_scalar(handle_, os, kSerializationVersion); serialize_scalar(handle_, os, index.size()); @@ -86,9 +84,9 @@ void serialize(raft::device_resources const& handle_, auto sizes_host = make_host_mdarray(index.list_sizes().extents()); copy(sizes_host.data_handle(), - index.list_sizes().data_handle(), - sizes_host.size(), - handle_.get_stream()); + index.list_sizes().data_handle(), + sizes_host.size(), + handle_.get_stream()); handle_.sync_stream(); serialize_mdspan(handle_, os, sizes_host.view()); auto list_store_spec = list_spec{index.pq_bits(), index.pq_dim(), true}; @@ -96,7 +94,6 @@ void serialize(raft::device_resources const& handle_, ivf::serialize_list( handle_, os, index.lists()[label], list_store_spec, sizes_host(label)); } - } /** From 49e783c49c34c38e8c83513aeb23083fbe240591 Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 7 Mar 2023 05:06:38 -0800 Subject: [PATCH 3/4] adding public API --- .../neighbors/detail/ivf_pq_serialize.cuh | 4 +- cpp/include/raft/neighbors/ivf_pq.cuh | 2 +- .../raft/neighbors/ivf_pq_serialize.cuh | 89 +++++++++++++++++++ .../distance/neighbors/ivfpq_deserialize.cu | 2 +- cpp/src/distance/neighbors/ivfpq_serialize.cu | 2 +- cpp/test/neighbors/ann_ivf_pq.cuh | 4 +- 6 files changed, 96 insertions(+), 7 deletions(-) create mode 100644 cpp/include/raft/neighbors/ivf_pq_serialize.cuh diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh index a15fdc30da..77e32d1b1c 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh @@ -114,7 +114,7 @@ void serialize(raft::device_resources const& handle_, std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - serialize(handle_, of, index); + detail::serialize(handle_, of, index); of.close(); if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } @@ -194,7 +194,7 @@ auto deserialize(raft::device_resources const& handle_, const std::string& filen if (!infile) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - auto index = deserialize(handle_, infile); + auto index = detail::deserialize(handle_, infile); infile.close(); diff --git a/cpp/include/raft/neighbors/ivf_pq.cuh b/cpp/include/raft/neighbors/ivf_pq.cuh index e2cc3c4728..053fe634da 100644 --- a/cpp/include/raft/neighbors/ivf_pq.cuh +++ b/cpp/include/raft/neighbors/ivf_pq.cuh @@ -18,7 +18,7 @@ #include #include -#include +#include #include #include diff --git a/cpp/include/raft/neighbors/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/ivf_pq_serialize.cuh new file mode 100644 index 0000000000..790683b5a8 --- /dev/null +++ b/cpp/include/raft/neighbors/ivf_pq_serialize.cuh @@ -0,0 +1,89 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "detail/ivf_pq_serialize.cuh" + +namespace raft::neighbors::ivf_pq { + +/** + * Write the index to an output stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] os output stream + * @param[in] index IVF-PQ index + * + */ +template +void serialize(raft::device_resources const& handle_, std::ostream& os, const index& index) +{ + detail::serialize(handle_, os, index); +} + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index IVF-PQ index + * + */ +template +void serialize(raft::device_resources const& handle_, + const std::string& filename, + const index& index) +{ + detail::serialize(handle_, filename, index); +} + +/** + * Load index from input stream + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] is input stream + * @param[in] index IVF-PQ index + * + */ +template +index deserialize(raft::device_resources const& handle_, std::istream& is) +{ + return detail::deserialize(handle_, is); +} + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[in] index IVF-PQ index + * + */ +template +index deserialize(raft::device_resources const& handle_, const std::string& filename) +{ + return detail::deserialize(handle_, filename); +} + +} // namespace raft::neighbors::ivf_pq diff --git a/cpp/src/distance/neighbors/ivfpq_deserialize.cu b/cpp/src/distance/neighbors/ivfpq_deserialize.cu index 403f80c9fc..8f71e5622b 100644 --- a/cpp/src/distance/neighbors/ivfpq_deserialize.cu +++ b/cpp/src/distance/neighbors/ivfpq_deserialize.cu @@ -24,6 +24,6 @@ void deserialize(raft::device_resources const& handle, raft::neighbors::ivf_pq::index* index) { if (!index) { RAFT_FAIL("Invalid index pointer"); } - *index = raft::neighbors::ivf_pq::detail::deserialize(handle, filename); + *index = raft::neighbors::ivf_pq::deserialize(handle, filename); }; } // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/distance/neighbors/ivfpq_serialize.cu b/cpp/src/distance/neighbors/ivfpq_serialize.cu index f6fd70be82..b7ceb9150a 100644 --- a/cpp/src/distance/neighbors/ivfpq_serialize.cu +++ b/cpp/src/distance/neighbors/ivfpq_serialize.cu @@ -23,7 +23,7 @@ void serialize(raft::device_resources const& handle, const std::string& filename, const raft::neighbors::ivf_pq::index& index) { - raft::neighbors::ivf_pq::detail::serialize(handle, filename, index); + raft::neighbors::ivf_pq::serialize(handle, filename, index); }; } // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index df295b8bcb..91294a859a 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -212,8 +212,8 @@ class ivf_pq_test : public ::testing::TestWithParam { auto build_serialize() { - ivf_pq::detail::serialize(handle_, "ivf_pq_index", build_only()); - return ivf_pq::detail::deserialize(handle_, "ivf_pq_index"); + ivf_pq::serialize(handle_, "ivf_pq_index", build_only()); + return ivf_pq::deserialize(handle_, "ivf_pq_index"); } template From 93c3899e78c84b1b003a0428ed841f6b5bbccd9b Mon Sep 17 00:00:00 2001 From: divyegala Date: Tue, 7 Mar 2023 05:48:21 -0800 Subject: [PATCH 4/4] add docs --- .../neighbors/detail/ivf_pq_serialize.cuh | 2 - .../raft/neighbors/ivf_pq_serialize.cuh | 81 ++++++++++++++++--- docs/source/cpp_api/neighbors_ivf_pq.rst | 14 ++++ 3 files changed, 85 insertions(+), 12 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh index 77e32d1b1c..0701b0feb5 100644 --- a/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_pq_serialize.cuh @@ -128,7 +128,6 @@ void serialize(raft::device_resources const& handle_, * * @param[in] handle the raft handle * @param[in] is input stream - * @param[in] index IVF-PQ index * */ template @@ -184,7 +183,6 @@ auto deserialize(raft::device_resources const& handle_, std::istream& is) -> ind * * @param[in] handle the raft handle * @param[in] filename the name of the file that stores the index - * @param[in] index IVF-PQ index * */ template diff --git a/cpp/include/raft/neighbors/ivf_pq_serialize.cuh b/cpp/include/raft/neighbors/ivf_pq_serialize.cuh index 790683b5a8..9a719c69d4 100644 --- a/cpp/include/raft/neighbors/ivf_pq_serialize.cuh +++ b/cpp/include/raft/neighbors/ivf_pq_serialize.cuh @@ -20,20 +20,39 @@ namespace raft::neighbors::ivf_pq { +/** + * \defgroup ivf_pq_serialize IVF-PQ Serialize + * @{ + */ + /** * Write the index to an output stream * * Experimental, both the API and the serialization format are subject to change. * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create an output stream + * std::ostream os(std::cout.rdbuf()); + * // create an index with `auto index = ivf_pq::build(...);` + * raft::serailize(handle, os, index); + * @endcode + * + * @tparam IdxT type of the index + * * @param[in] handle the raft handle * @param[in] os output stream * @param[in] index IVF-PQ index * + * @return raft::neighbors::ivf_pq::index */ template -void serialize(raft::device_resources const& handle_, std::ostream& os, const index& index) +void serialize(raft::device_resources const& handle, std::ostream& os, const index& index) { - detail::serialize(handle_, os, index); + detail::serialize(handle, os, index); } /** @@ -41,17 +60,31 @@ void serialize(raft::device_resources const& handle_, std::ostream& os, const in * * Experimental, both the API and the serialization format are subject to change. * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * // create an index with `auto index = ivf_pq::build(...);` + * raft::serailize(handle, filename, index); + * @endcode + * + * @tparam IdxT type of the index + * * @param[in] handle the raft handle * @param[in] filename the file name for saving the index * @param[in] index IVF-PQ index * + * @return raft::neighbors::ivf_pq::index */ template -void serialize(raft::device_resources const& handle_, +void serialize(raft::device_resources const& handle, const std::string& filename, const index& index) { - detail::serialize(handle_, filename, index); + detail::serialize(handle, filename, index); } /** @@ -59,15 +92,28 @@ void serialize(raft::device_resources const& handle_, * * Experimental, both the API and the serialization format are subject to change. * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create an input stream + * std::istream is(std::cin.rdbuf()); + * using IdxT = int; // type of the index + * auto index = raft::deserialize(handle, is); + * @endcode + * + * @tparam IdxT type of the index + * * @param[in] handle the raft handle * @param[in] is input stream - * @param[in] index IVF-PQ index * + * @return raft::neighbors::ivf_pq::index */ template -index deserialize(raft::device_resources const& handle_, std::istream& is) +index deserialize(raft::device_resources const& handle, std::istream& is) { - return detail::deserialize(handle_, is); + return detail::deserialize(handle, is); } /** @@ -75,15 +121,30 @@ index deserialize(raft::device_resources const& handle_, std::istream& is) * * Experimental, both the API and the serialization format are subject to change. * + * @code{.cpp} + * #include + * + * raft::device_resources handle; + * + * // create a string with a filepath + * std::string filename("/path/to/index"); + * using IdxT = int; // type of the index + * auto index = raft::deserialize(handle, filename); + * @endcode + * + * @tparam IdxT type of the index + * * @param[in] handle the raft handle * @param[in] filename the name of the file that stores the index - * @param[in] index IVF-PQ index * + * @return raft::neighbors::ivf_pq::index */ template -index deserialize(raft::device_resources const& handle_, const std::string& filename) +index deserialize(raft::device_resources const& handle, const std::string& filename) { - return detail::deserialize(handle_, filename); + return detail::deserialize(handle, filename); } +/**@}*/ + } // namespace raft::neighbors::ivf_pq diff --git a/docs/source/cpp_api/neighbors_ivf_pq.rst b/docs/source/cpp_api/neighbors_ivf_pq.rst index d22ea6231f..228833983c 100644 --- a/docs/source/cpp_api/neighbors_ivf_pq.rst +++ b/docs/source/cpp_api/neighbors_ivf_pq.rst @@ -14,4 +14,18 @@ namespace *raft::neighbors::ivf_pq* :members: :content-only: +Serializer Methods +------------------ +``#include `` +.. doxygenfunction:: serialize(raft::device_resources const& handle, std::ostream& os, const index& index) + :project: RAFT + +.. doxygenfunction:: serialize(raft::device_resources const& handle, const std::string& filename, const index& index) + :project: RAFT + +.. doxygenfunction:: deserialize(raft::device_resources const& handle, std::istream& is) + :project: RAFT + +.. doxygenfunction:: deserialize(raft::device_resources const& handle, const std::string& filename) + :project: RAFT \ No newline at end of file