From 96578a1718c2b2269380b5fd0076b8e9003d6eb5 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Tue, 3 Jan 2023 21:36:50 +0100 Subject: [PATCH] Serialization of IVF Flat and IVF PQ (#919) This PR implements serialization to file for `ivf_pq::index` and `ivf_flat::index` structures. Index building takes time, therefore downstream projects (like cuML) want to save the index (https://github.com/rapidsai/cuml/issues/4743). But downstream project should not depend on the implementation details of the index, therefore RAFT provides methods to serialize and deserialize the index. This is still experimental: - ideally we want to use a general serialization method for mdspan https://github.com/rapidsai/raft/pull/770, - instead of directly saving to file, raft should provide a byte string and let the downstream project decide how to save it (e.g. pickle for cuML). Python wrappers are provided for IVF-PQ to save/load the index. Authors: - Tamas Bela Feher (https://github.com/tfeher) - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/919 --- .../spatial/knn/detail/ann_serialization.h | 140 ++++++++++++++++++ .../spatial/knn/detail/ivf_flat_build.cuh | 95 ++++++++++++ .../raft/spatial/knn/detail/ivf_pq_build.cuh | 106 +++++++++++++ cpp/include/raft_runtime/neighbors/ivf_pq.hpp | 28 ++++ cpp/src/distance/neighbors/ivfpq_build.cu | 14 ++ cpp/test/neighbors/ann_ivf_flat.cu | 8 +- cpp/test/neighbors/ann_ivf_pq.cuh | 6 +- .../pylibraft/neighbors/ivf_pq/__init__.py | 22 ++- .../neighbors/ivf_pq/cpp/c_ivf_pq.pxd | 9 ++ .../pylibraft/neighbors/ivf_pq/ivf_pq.pyx | 105 +++++++++++++ .../pylibraft/pylibraft/test/test_ivf_pq.py | 48 ++++++ 11 files changed, 577 insertions(+), 4 deletions(-) create mode 100644 cpp/include/raft/spatial/knn/detail/ann_serialization.h diff --git a/cpp/include/raft/spatial/knn/detail/ann_serialization.h b/cpp/include/raft/spatial/knn/detail/ann_serialization.h new file mode 100644 index 0000000000..cf2aeedcfc --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ann_serialization.h @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2022, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace raft::spatial::knn::detail { + +template +void write_scalar(std::ofstream& of, const T& value) +{ + of.write((char*)&value, sizeof value); + if (of.good()) { + RAFT_LOG_DEBUG("Written %z bytes", (sizeof value)); + } else { + RAFT_FAIL("error writing value to file"); + } +} + +template +T read_scalar(std::ifstream& file) +{ + T value; + file.read((char*)&value, sizeof value); + if (file.good()) { + RAFT_LOG_DEBUG("Read %z bytes", (sizeof value)); + } else { + RAFT_FAIL("error reading value from file"); + } + return value; +} + +template +void write_mdspan( + const raft::handle_t& handle, + std::ofstream& of, + const raft::device_mdspan& obj) +{ + using obj_t = raft::device_mdspan; + write_scalar(of, obj.rank()); + if (obj.is_exhaustive() && obj.is_unique()) { + write_scalar(of, obj.size()); + } else { + RAFT_FAIL("Cannot serialize non exhaustive mdarray"); + } + if (obj.size() > 0) { + for (typename obj_t::rank_type i = 0; i < obj.rank(); i++) + write_scalar(of, obj.extent(i)); + cudaStream_t stream = handle.get_stream(); + std::vector< + typename raft::device_mdspan::value_type> + tmp(obj.size()); + raft::update_host(tmp.data(), obj.data_handle(), obj.size(), stream); + handle.sync_stream(stream); + of.write(reinterpret_cast(tmp.data()), tmp.size() * sizeof(ElementType)); + if (of.good()) { + RAFT_LOG_DEBUG("Written %zu bytes", + static_cast(obj.size() * sizeof(obj.data_handle()[0]))); + } else { + RAFT_FAIL("Error writing mdarray to file"); + } + } else { + RAFT_LOG_DEBUG("Skipping mdspand with zero size"); + } +} + +template +void read_mdspan(const raft::handle_t& handle, + std::ifstream& file, + raft::device_mdspan& obj) +{ + using obj_t = raft::device_mdspan; + auto rank = read_scalar(file); + if (obj.rank() != rank) { RAFT_FAIL("Incorrect rank while reading mdarray"); } + auto size = read_scalar(file); + if (obj.size() != size) { + RAFT_FAIL("Incorrect rank while reading mdarray %zu vs %zu", + static_cast(size), + static_cast(obj.size())); + } + if (obj.size() > 0) { + for (typename obj_t::rank_type i = 0; i < obj.rank(); i++) { + auto ex = read_scalar(file); + if (obj.extent(i) != ex) { + RAFT_FAIL("Incorrect extent while reading mdarray %d vs %d at %d", + static_cast(ex), + static_cast(obj.extent(i)), + static_cast(i)); + } + } + cudaStream_t stream = handle.get_stream(); + std::vector tmp(obj.size()); + file.read(reinterpret_cast(tmp.data()), tmp.size() * sizeof(ElementType)); + raft::update_device(obj.data_handle(), tmp.data(), tmp.size(), stream); + handle.sync_stream(stream); + if (file.good()) { + RAFT_LOG_DEBUG("Read %zu bytes", + static_cast(obj.size() * sizeof(obj.data_handle()[0]))); + } else { + RAFT_FAIL("error reading mdarray from file"); + } + } else { + RAFT_LOG_DEBUG("Skipping mdspand with zero size"); + } +} + +template +void read_mdspan(const raft::handle_t& handle, + std::ifstream& file, + raft::device_mdspan&& obj) +{ + read_mdspan(handle, file, obj); +} +} // namespace raft::spatial::knn::detail diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index f08a97e0f7..e951d8fe5d 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -18,6 +18,7 @@ #include "../ivf_flat_types.hpp" #include "ann_kmeans_balanced.cuh" +#include "ann_serialization.h" #include "ann_utils.cuh" #include @@ -378,4 +379,98 @@ inline void fill_refinement_index(const handle_t& handle, refinement_index->veclen()); RAFT_CUDA_TRY(cudaPeekAtLastError()); } + +static const int serialization_version = 1; + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index_ IVF-Flat index + * + */ +template +void save(const handle_t& handle, const std::string& filename, const index& index_) +{ + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open %s", filename.c_str()); } + + RAFT_LOG_DEBUG( + "Saving IVF-PQ index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); + write_scalar(of, serialization_version); + write_scalar(of, index_.size()); + write_scalar(of, index_.dim()); + write_scalar(of, index_.n_lists()); + write_scalar(of, index_.metric()); + write_scalar(of, index_.veclen()); + write_scalar(of, index_.adaptive_centers()); + write_mdspan(handle, of, index_.data()); + write_mdspan(handle, of, index_.indices()); + write_mdspan(handle, of, index_.list_sizes()); + write_mdspan(handle, of, index_.list_offsets()); + write_mdspan(handle, of, index_.centers()); + if (index_.center_norms()) { + bool has_norms = true; + write_scalar(of, has_norms); + write_mdspan(handle, of, *index_.center_norms()); + } else { + bool has_norms = false; + write_scalar(of, has_norms); + } + of.close(); + if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } +} + +/** Load an index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[in] index_ IVF-Flat index + * + */ +template +auto load(const handle_t& handle, const std::string& filename) -> index +{ + std::ifstream infile(filename, std::ios::in | std::ios::binary); + + if (!infile) { RAFT_FAIL("Cannot open %s", filename.c_str()); } + + auto ver = read_scalar(infile); + if (ver != serialization_version) { + RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); + } + auto n_rows = read_scalar(infile); + auto dim = read_scalar(infile); + auto n_lists = read_scalar(infile); + auto metric = read_scalar(infile); + auto veclen = read_scalar(infile); + bool adaptive_centers = read_scalar(infile); + + index index_ = + raft::spatial::knn::ivf_flat::index(handle, metric, n_lists, adaptive_centers, dim); + + index_.allocate(handle, n_rows, metric == raft::distance::DistanceType::L2Expanded); + auto data = index_.data(); + read_mdspan(handle, infile, data); + read_mdspan(handle, infile, index_.indices()); + read_mdspan(handle, infile, index_.list_sizes()); + read_mdspan(handle, infile, index_.list_offsets()); + read_mdspan(handle, infile, index_.centers()); + bool has_norms = read_scalar(infile); + if (has_norms) { + if (!index_.center_norms()) { + RAFT_FAIL("Error inconsistent center norms"); + } else { + auto center_norms = *index_.center_norms(); + read_mdspan(handle, infile, center_norms); + } + } + infile.close(); + return index_; +} } // namespace raft::spatial::knn::ivf_flat::detail diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh index 2a2fc1f10b..d718deeb57 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh @@ -17,6 +17,7 @@ #pragma once #include "ann_kmeans_balanced.cuh" +#include "ann_serialization.h" #include "ann_utils.cuh" #include @@ -1263,4 +1264,109 @@ inline auto build( }}(); } +static const int serialization_version = 1; + +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] filename the file name for saving the index + * @param[in] index_ IVF-PQ index + * + */ +template +void save(const handle_t& handle_, const std::string& filename, const index& index_) +{ + std::ofstream of(filename, std::ios::out | std::ios::binary); + if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + RAFT_LOG_DEBUG("Size %zu, dim %d, pq_dim %d, pq_bits %d", + static_cast(index_.size()), + static_cast(index_.dim()), + static_cast(index_.pq_dim()), + static_cast(index_.pq_bits())); + + write_scalar(of, serialization_version); + write_scalar(of, index_.size()); + write_scalar(of, index_.dim()); + write_scalar(of, index_.pq_bits()); + write_scalar(of, index_.pq_dim()); + + write_scalar(of, index_.metric()); + write_scalar(of, index_.codebook_kind()); + write_scalar(of, index_.n_lists()); + write_scalar(of, index_.n_nonempty_lists()); + + write_mdspan(handle_, of, index_.pq_centers()); + write_mdspan(handle_, of, index_.pq_dataset()); + write_mdspan(handle_, of, index_.indices()); + write_mdspan(handle_, of, index_.rotation_matrix()); + write_mdspan(handle_, of, index_.list_offsets()); + write_mdspan(handle_, of, index_.list_sizes()); + write_mdspan(handle_, of, index_.centers()); + write_mdspan(handle_, of, index_.centers_rot()); + + of.close(); + if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } + return; +} + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[in] index_ IVF-PQ index + * + */ +template +auto load(const handle_t& handle_, const std::string& filename) -> index +{ + std::ifstream infile(filename, std::ios::in | std::ios::binary); + + if (!infile) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } + + auto ver = read_scalar(infile); + if (ver != serialization_version) { + RAFT_FAIL("serialization version mismatch %d vs. %d", ver, serialization_version); + } + auto n_rows = read_scalar(infile); + auto dim = read_scalar(infile); + auto pq_bits = read_scalar(infile); + auto pq_dim = read_scalar(infile); + + auto metric = read_scalar(infile); + auto codebook_kind = read_scalar(infile); + auto n_lists = read_scalar(infile); + auto n_nonempty_lists = read_scalar(infile); + + RAFT_LOG_DEBUG("n_rows %zu, dim %d, pq_dim %d, pq_bits %d, n_lists %d", + static_cast(n_rows), + static_cast(dim), + static_cast(pq_dim), + static_cast(pq_bits), + static_cast(n_lists)); + + auto index_ = raft::neighbors::ivf_pq::index( + handle_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, n_nonempty_lists); + index_.allocate(handle_, n_rows); + + read_mdspan(handle_, infile, index_.pq_centers()); + read_mdspan(handle_, infile, index_.pq_dataset()); + read_mdspan(handle_, infile, index_.indices()); + read_mdspan(handle_, infile, index_.rotation_matrix()); + read_mdspan(handle_, infile, index_.list_offsets()); + read_mdspan(handle_, infile, index_.list_sizes()); + read_mdspan(handle_, infile, index_.centers()); + read_mdspan(handle_, infile, index_.centers_rot()); + + infile.close(); + + return index_; +} + } // namespace raft::spatial::knn::ivf_pq::detail diff --git a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp index 13e3eac61f..7956e41497 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_pq.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_pq.hpp @@ -74,4 +74,32 @@ RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t) #undef RAFT_INST_BUILD_EXTEND +/** + * Save the index to file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] filename the filename for saving the index + * @param[in] index IVF-PQ index + * + */ +void save(const handle_t& handle, + const std::string& filename, + const raft::neighbors::ivf_pq::index& index); + +/** + * Load index from file. + * + * Experimental, both the API and the serialization format are subject to change. + * + * @param[in] handle the raft handle + * @param[in] filename the name of the file that stores the index + * @param[in] index IVF-PQ index + * + */ +void load(const handle_t& handle, + const std::string& filename, + raft::neighbors::ivf_pq::index* index); + } // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/src/distance/neighbors/ivfpq_build.cu b/cpp/src/distance/neighbors/ivfpq_build.cu index 20119b5178..bbcd0533f8 100644 --- a/cpp/src/distance/neighbors/ivfpq_build.cu +++ b/cpp/src/distance/neighbors/ivfpq_build.cu @@ -64,4 +64,18 @@ RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t); #undef RAFT_INST_BUILD_EXTEND +void save(const handle_t& handle, + const std::string& filename, + const raft::neighbors::ivf_pq::index& index) +{ + raft::spatial::knn::ivf_pq::detail::save(handle, filename, index); +}; + +void load(const handle_t& handle, + const std::string& filename, + raft::neighbors::ivf_pq::index* index) +{ + if (!index) { RAFT_FAIL("Invalid index pointer"); } + *index = raft::spatial::knn::ivf_pq::detail::load(handle, filename); +}; } // namespace raft::runtime::neighbors::ivf_pq diff --git a/cpp/test/neighbors/ann_ivf_flat.cu b/cpp/test/neighbors/ann_ivf_flat.cu index 1207b75a4a..3285bc3496 100644 --- a/cpp/test/neighbors/ann_ivf_flat.cu +++ b/cpp/test/neighbors/ann_ivf_flat.cu @@ -118,6 +118,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { database.data(), ps.num_db_vecs, ps.dim); + handle_.sync_stream(stream_); approx_knn_search(handle_, distances_ivfflat_dev.data(), @@ -187,8 +188,13 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { indices_ivfflat_dev.data(), ps.num_queries, ps.k); auto dists_out_view = raft::make_device_matrix_view( distances_ivfflat_dev.data(), ps.num_queries, ps.k); + raft::spatial::knn::ivf_flat::detail::save(handle_, "ivf_flat_index", index_2); + + auto index_loaded = + raft::spatial::knn::ivf_flat::detail::load(handle_, "ivf_flat_index"); + ivf_flat::search(handle_, - index_2, + index_loaded, search_queries_view, indices_out_view, dists_out_view, diff --git a/cpp/test/neighbors/ann_ivf_pq.cuh b/cpp/test/neighbors/ann_ivf_pq.cuh index 7c3ec044b1..353e8b65e5 100644 --- a/cpp/test/neighbors/ann_ivf_pq.cuh +++ b/cpp/test/neighbors/ann_ivf_pq.cuh @@ -205,7 +205,11 @@ class ivf_pq_test : public ::testing::TestWithParam { template void run(BuildIndex build_index) { - auto index = build_index(); + { + auto index = build_index(); + raft::spatial::knn::ivf_pq::detail::save(handle_, "ivf_pq_index", index); + } + auto index = raft::spatial::knn::ivf_pq::detail::load(handle_, "ivf_pq_index"); size_t queries_size = ps.num_queries * ps.k; std::vector indices_ivf_pq(queries_size); diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/__init__.py b/python/pylibraft/pylibraft/neighbors/ivf_pq/__init__.py index 559eb21fdf..3d604f829d 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/__init__.py @@ -13,6 +13,24 @@ # limitations under the License. # -from .ivf_pq import Index, IndexParams, SearchParams, build, extend, search +from .ivf_pq import ( + Index, + IndexParams, + SearchParams, + build, + extend, + load, + save, + search, +) -__all__ = ["Index", "IndexParams", "SearchParams", "build", "extend", "search"] +__all__ = [ + "Index", + "IndexParams", + "SearchParams", + "build", + "extend", + "load", + "save", + "search", +] diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd index 2f1354c475..1b8076487d 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/cpp/c_ivf_pq.pxd @@ -32,6 +32,7 @@ from libc.stdint cimport ( uintptr_t, ) from libcpp cimport bool, nullptr +from libcpp.string cimport string from rmm._lib.memory_resource cimport device_memory_resource @@ -175,3 +176,11 @@ cdef extern from "raft_runtime/neighbors/ivf_pq.hpp" \ uint64_t* neighbors, float* distances, device_memory_resource* mr) except + + + cdef void save(const handle_t& handle, + const string& filename, + const index[uint64_t]& index) except + + + cdef void load(const handle_t& handle, + const string& filename, + index[uint64_t]* index) except + diff --git a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx index 6ad9b753b3..a7137e4d08 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_pq/ivf_pq.pyx @@ -30,6 +30,7 @@ from libc.stdint cimport ( uintptr_t, ) from libcpp cimport bool, nullptr +from libcpp.string cimport string from pylibraft.distance.distance_type cimport DistanceType @@ -730,3 +731,107 @@ def search(SearchParams search_params, raise ValueError("query dtype %s not supported" % queries_dt) return (distances, neighbors) + + +@auto_sync_handle +def save(filename, Index index, handle=None): + """ + Saves the index to file. + + Saving / loading the index is experimental. The serialization format is + subject to change. + + Parameters + ---------- + filename : string + Name of the file. + index : Index + Trained IVF-PQ index. + {handle_docstring} + + Examples + -------- + >>> import cupy as cp + + >>> from pylibraft.common import Handle + >>> from pylibraft.neighbors import ivf_pq + + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + + >>> # Build index + >>> handle = Handle() + >>> index = ivf_pq.build(ivf_pq.IndexParams(), dataset, handle=handle) + >>> ivf_pq.save("my_index.bin", index, handle=handle) + """ + if not index.trained: + raise ValueError("Index need to be built before saving it.") + + if handle is None: + handle = Handle() + cdef handle_t* handle_ = handle.getHandle() + + cdef string c_filename = filename.encode('utf-8') + + c_ivf_pq.save(deref(handle_), c_filename, deref(index.index)) + + +@auto_sync_handle +def load(filename, handle=None): + """ + Loads index from file. + + Saving / loading the index is experimental. The serialization format is + subject to change, therefore loading an index saved with a previous + version of raft is not guaranteed to work. + + Parameters + ---------- + filename : string + Name of the file. + {handle_docstring} + + Returns + ------- + index : Index + + Examples + -------- + >>> import cupy as cp + + >>> from pylibraft.common import Handle + >>> from pylibraft.neighbors import ivf_pq + + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... dtype=cp.float32) + + >>> # Build and save index + >>> handle = Handle() + >>> index = ivf_pq.build(ivf_pq.IndexParams(), dataset, handle=handle) + >>> ivf_pq.save("my_index.bin", index, handle=handle) + >>> del index + + >>> n_queries = 100 + >>> queries = cp.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> handle = Handle() + >>> index = ivf_pq.load("my_index.bin", handle=handle) + + >>> distances, neighbors = ivf_pq.search(ivf_pq.SearchParams(), index, + ... queries, k=10, handle=handle) + """ + if handle is None: + handle = Handle() + cdef handle_t* handle_ = handle.getHandle() + + cdef string c_filename = filename.encode('utf-8') + index = Index() + + c_ivf_pq.load(deref(handle_), c_filename, index.index) + index.trained = True + + return index diff --git a/python/pylibraft/pylibraft/test/test_ivf_pq.py b/python/pylibraft/pylibraft/test/test_ivf_pq.py index 4c102873d1..2c6e0dd14c 100644 --- a/python/pylibraft/pylibraft/test/test_ivf_pq.py +++ b/python/pylibraft/pylibraft/test/test_ivf_pq.py @@ -483,3 +483,51 @@ def test_search_inputs(params): out_idx_device, out_dist_device, ) + + +def test_save_load(): + n_rows = 10000 + n_cols = 50 + n_queries = 1000 + dtype = np.float32 + + dataset = generate_data((n_rows, n_cols), dtype) + dataset_device = device_ndarray(dataset) + + build_params = ivf_pq.IndexParams(n_lists=100, metric="l2_expanded") + index = ivf_pq.build(build_params, dataset_device) + + assert index.trained + filename = "my_index.bin" + ivf_pq.save(filename, index) + loaded_index = ivf_pq.load(filename) + + assert index.pq_dim == loaded_index.pq_dim + assert index.pq_bits == loaded_index.pq_bits + assert index.metric == loaded_index.metric + assert index.n_lists == loaded_index.n_lists + assert index.size == loaded_index.size + + queries = generate_data((n_queries, n_cols), dtype) + + queries_device = device_ndarray(queries) + search_params = ivf_pq.SearchParams(n_probes=100) + k = 10 + + distance_dev, neighbors_dev = ivf_pq.search( + search_params, index, queries_device, k + ) + + neighbors = neighbors_dev.copy_to_host() + dist = distance_dev.copy_to_host() + del index + + distance_dev, neighbors_dev = ivf_pq.search( + search_params, loaded_index, queries_device, k + ) + + neighbors2 = neighbors_dev.copy_to_host() + dist2 = distance_dev.copy_to_host() + + assert np.all(neighbors == neighbors2) + assert np.allclose(dist, dist2, rtol=1e-6)