From af7e06760fe0ed8ca6c7c59dec5269559394da6d Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Fri, 19 May 2023 14:17:08 +0200 Subject: [PATCH] Python API for IVF-Flat serialization (#1516) This PR adds Python API for IVF-Flat serialization. closes #752 Authors: - Tamas Bela Feher (https://github.com/tfeher) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1516 --- cpp/CMakeLists.txt | 1 + .../neighbors/detail/ivf_flat_serialize.cuh | 10 +- .../raft_runtime/neighbors/ivf_flat.hpp | 17 +- .../neighbors/ivf_flat_serialize.cu | 65 ++++++++ .../pylibraft/neighbors/ivf_flat/__init__.py | 13 +- .../neighbors/ivf_flat/cpp/c_ivf_flat.pxd | 48 ++++++ .../pylibraft/neighbors/ivf_flat/ivf_flat.pyx | 150 ++++++++++++++++++ .../pylibraft/pylibraft/test/test_ivf_flat.py | 47 ++++++ 8 files changed, 348 insertions(+), 3 deletions(-) create mode 100644 cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 68ff4f3bb6..eb35554768 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -365,6 +365,7 @@ if(RAFT_COMPILE_LIBRARY) src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu src/raft_runtime/neighbors/ivf_flat_build.cu src/raft_runtime/neighbors/ivf_flat_search.cu + src/raft_runtime/neighbors/ivf_flat_serialize.cu src/raft_runtime/neighbors/ivfpq_build.cu src/raft_runtime/neighbors/ivfpq_deserialize.cu src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh index af2e6ba0f8..b00d308827 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -33,7 +34,7 @@ namespace raft::neighbors::ivf_flat::detail { // backward compatibility. 
// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward // compatible fashion. -constexpr int serialization_version = 3; +constexpr int serialization_version = 4; // NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error // message. @@ -62,6 +63,10 @@ void serialize(raft::resources const& handle, std::ostream& os, const index(index_.size()), index_.dim()); + std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); + dtype_string.resize(4); + os << dtype_string; + serialize_scalar(handle, os, serialization_version); serialize_scalar(handle, os, index_.size()); serialize_scalar(handle, os, index_.dim()); @@ -123,6 +128,9 @@ void serialize(raft::resources const& handle, template auto deserialize(raft::resources const& handle, std::istream& is) -> index { + char dtype_string[4]; + is.read(dtype_string, 4); + auto ver = deserialize_scalar(handle, is); if (ver != serialization_version) { RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); diff --git a/cpp/include/raft_runtime/neighbors/ivf_flat.hpp b/cpp/include/raft_runtime/neighbors/ivf_flat.hpp index 37a9d39ae3..5b8918ec7f 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_flat.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_flat.hpp @@ -17,6 +17,7 @@ #pragma once #include +#include namespace raft::runtime::neighbors::ivf_flat { @@ -43,7 +44,21 @@ namespace raft::runtime::neighbors::ivf_flat { void extend(raft::resources const& handle, \ raft::device_matrix_view new_vectors, \ std::optional> new_indices, \ - raft::neighbors::ivf_flat::index* idx); + raft::neighbors::ivf_flat::index* idx); \ + \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::ivf_flat::index& index); \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::ivf_flat::index* index); 
\ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::ivf_flat::index& index); \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + raft::neighbors::ivf_flat::index*); RAFT_INST_BUILD_EXTEND(float, int64_t) RAFT_INST_BUILD_EXTEND(int8_t, int64_t) diff --git a/cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu new file mode 100644 index 0000000000..049b8b00da --- /dev/null +++ b/cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu @@ -0,0 +1,65 @@ +/* + * Copyright (c) 2023, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include +#include + +#include +#include +#include +#include + +namespace raft::runtime::neighbors::ivf_flat { + +#define RAFT_IVF_FLAT_SERIALIZE_INST(DTYPE) \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::ivf_flat::index& index) \ + { \ + raft::neighbors::ivf_flat::serialize(handle, filename, index); \ + }; \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::ivf_flat::index* index) \ + { \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::ivf_flat::deserialize(handle, filename); \ + }; \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::ivf_flat::index& index) \ + { \ + std::stringstream os; \ + raft::neighbors::ivf_flat::serialize(handle, os, index); \ + str = os.str(); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + raft::neighbors::ivf_flat::index* index) \ + { \ + std::istringstream is(str); \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::ivf_flat::deserialize(handle, is); \ + } + +RAFT_IVF_FLAT_SERIALIZE_INST(float); +RAFT_IVF_FLAT_SERIALIZE_INST(int8_t); +RAFT_IVF_FLAT_SERIALIZE_INST(uint8_t); + +#undef RAFT_IVF_FLAT_SERIALIZE_INST +} // namespace raft::runtime::neighbors::ivf_flat diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/__init__.py b/python/pylibraft/pylibraft/neighbors/ivf_flat/__init__.py index 58fd88b873..057cb98f17 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/__init__.py +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/__init__.py @@ -13,7 +13,16 @@ # limitations under the License. 
# -from .ivf_flat import Index, IndexParams, SearchParams, build, extend, search +from .ivf_flat import ( + Index, + IndexParams, + SearchParams, + build, + extend, + load, + save, + search, +) __all__ = [ "Index", @@ -22,4 +31,6 @@ "build", "extend", "search", + "save", + "load", ] diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd index 31a251e7c2..a281d33310 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd @@ -133,3 +133,51 @@ cdef extern from "raft_runtime/neighbors/ivf_flat.hpp" \ device_matrix_view[uint8_t, int64_t, row_major] queries, device_matrix_view[int64_t, int64_t, row_major] neighbors, device_matrix_view[float, int64_t, row_major] distances) except + + + cdef void serialize(const device_resources& handle, + string& str, + const index[float, int64_t]& index) except + + + cdef void deserialize(const device_resources& handle, + const string& str, + index[float, int64_t]* index) except + + + cdef void serialize(const device_resources& handle, + string& str, + const index[uint8_t, int64_t]& index) except + + + cdef void deserialize(const device_resources& handle, + const string& str, + index[uint8_t, int64_t]* index) except + + + cdef void serialize(const device_resources& handle, + string& str, + const index[int8_t, int64_t]& index) except + + + cdef void deserialize(const device_resources& handle, + const string& str, + index[int8_t, int64_t]* index) except + + + cdef void serialize_file(const device_resources& handle, + const string& filename, + const index[float, int64_t]& index) except + + + cdef void deserialize_file(const device_resources& handle, + const string& filename, + index[float, int64_t]* index) except + + + cdef void serialize_file(const device_resources& handle, + const string& filename, + const index[uint8_t, int64_t]& index) except + + + cdef 
void deserialize_file(const device_resources& handle, + const string& filename, + index[uint8_t, int64_t]* index) except + + + cdef void serialize_file(const device_resources& handle, + const string& filename, + const index[int8_t, int64_t]& index) except + + + cdef void deserialize_file(const device_resources& handle, + const string& filename, + index[int8_t, int64_t]* index) except + diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx index 352376fe17..0e550547d3 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx @@ -708,3 +708,153 @@ def search(SearchParams search_params, raise ValueError("query dtype %s not supported" % queries_dt) return (distances, neighbors) + + +@auto_sync_handle +def save(filename, Index index, handle=None): + """ + Saves the index to file. + + Saving / loading the index is experimental. The serialization format is + subject to change. + + Parameters + ---------- + filename : string + Name of the file. + index : Index + Trained IVF-Flat index. + {handle_docstring} + + Examples + -------- + >>> import cupy as cp + + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import ivf_flat + + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... 
dtype=cp.float32) + + >>> # Build index + >>> handle = DeviceResources() + >>> index = ivf_flat.build(ivf_flat.IndexParams(), dataset, handle=handle) + >>> ivf_flat.save("my_index.bin", index, handle=handle) + """ + if not index.trained: + raise ValueError("Index need to be built before saving it.") + + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + cdef string c_filename = filename.encode('utf-8') + + cdef IndexFloat idx_float + cdef IndexInt8 idx_int8 + cdef IndexUint8 idx_uint8 + + if index.active_index_type == "float32": + idx_float = index + c_ivf_flat.serialize_file( + deref(handle_), c_filename, deref(idx_float.index)) + elif index.active_index_type == "byte": + idx_int8 = index + c_ivf_flat.serialize_file( + deref(handle_), c_filename, deref(idx_int8.index)) + elif index.active_index_type == "ubyte": + idx_uint8 = index + c_ivf_flat.serialize_file( + deref(handle_), c_filename, deref(idx_uint8.index)) + else: + raise ValueError( + "Index dtype %s not supported" % index.active_index_type) + + +@auto_sync_handle +def load(filename, handle=None): + """ + Loads index from file. + + Saving / loading the index is experimental. The serialization format is + subject to change, therefore loading an index saved with a previous + version of raft is not guaranteed to work. + + Parameters + ---------- + filename : string + Name of the file. + {handle_docstring} + + Returns + ------- + index : Index + + Examples + -------- + >>> import cupy as cp + + >>> from pylibraft.common import DeviceResources + >>> from pylibraft.neighbors import ivf_flat + + >>> n_samples = 50000 + >>> n_features = 50 + >>> dataset = cp.random.random_sample((n_samples, n_features), + ... 
dtype=cp.float32) + + >>> # Build and save index + >>> handle = DeviceResources() + >>> index = ivf_flat.build(ivf_flat.IndexParams(), dataset, handle=handle) + >>> ivf_flat.save("my_index.bin", index, handle=handle) + >>> del index + + >>> n_queries = 100 + >>> queries = cp.random.random_sample((n_queries, n_features), + ... dtype=cp.float32) + >>> handle = DeviceResources() + >>> index = ivf_flat.load("my_index.bin", handle=handle) + + >>> distances, neighbors = ivf_flat.search(ivf_flat.SearchParams(), index, + ... queries, k=10, handle=handle) + """ + if handle is None: + handle = DeviceResources() + cdef device_resources* handle_ = \ + handle.getHandle() + + cdef string c_filename = filename.encode('utf-8') + cdef IndexFloat idx_float + cdef IndexInt8 idx_int8 + cdef IndexUint8 idx_uint8 + + with open(filename, 'rb') as f: + type_str = f.read(3).decode('utf-8') + + dataset_dt = np.dtype(type_str) + + if dataset_dt == np.float32: + idx_float = IndexFloat(handle) + c_ivf_flat.deserialize_file( + deref(handle_), c_filename, idx_float.index) + idx_float.trained = True + idx_float.active_index_type = 'float32' + return idx_float + elif dataset_dt == np.byte: + idx_int8 = IndexInt8(handle) + c_ivf_flat.deserialize_file( + deref(handle_), c_filename, idx_int8.index) + idx_int8.trained = True + idx_int8.active_index_type = 'byte' + return idx_int8 + elif dataset_dt == np.ubyte: + idx_uint8 = IndexUint8(handle) + c_ivf_flat.deserialize_file( + deref(handle_), c_filename, idx_uint8.index) + idx_uint8.trained = True + idx_uint8.active_index_type = 'ubyte' + return idx_uint8 + else: + raise ValueError("Index dtype %s not supported" % dataset_dt) diff --git a/python/pylibraft/pylibraft/test/test_ivf_flat.py b/python/pylibraft/pylibraft/test/test_ivf_flat.py index 593980f7c8..23140073f1 100644 --- a/python/pylibraft/pylibraft/test/test_ivf_flat.py +++ b/python/pylibraft/pylibraft/test/test_ivf_flat.py @@ -461,3 +461,50 @@ def test_search_inputs(params): out_idx_device, 
out_dist_device, ) + + +@pytest.mark.parametrize("dtype", [np.float32, np.int8, np.ubyte]) +def test_save_load(dtype): + n_rows = 10000 + n_cols = 50 + n_queries = 1000 + + dataset = generate_data((n_rows, n_cols), dtype) + dataset_device = device_ndarray(dataset) + + build_params = ivf_flat.IndexParams(n_lists=100, metric="sqeuclidean") + index = ivf_flat.build(build_params, dataset_device) + + assert index.trained + filename = "my_index.bin" + ivf_flat.save(filename, index) + loaded_index = ivf_flat.load(filename) + + assert index.metric == loaded_index.metric + assert index.n_lists == loaded_index.n_lists + assert index.dim == loaded_index.dim + assert index.adaptive_centers == loaded_index.adaptive_centers + + queries = generate_data((n_queries, n_cols), dtype) + + queries_device = device_ndarray(queries) + search_params = ivf_flat.SearchParams(n_probes=100) + k = 10 + + distance_dev, neighbors_dev = ivf_flat.search( + search_params, index, queries_device, k + ) + + neighbors = neighbors_dev.copy_to_host() + dist = distance_dev.copy_to_host() + del index + + distance_dev, neighbors_dev = ivf_flat.search( + search_params, loaded_index, queries_device, k + ) + + neighbors2 = neighbors_dev.copy_to_host() + dist2 = distance_dev.copy_to_host() + + assert np.all(neighbors == neighbors2) + assert np.allclose(dist, dist2, rtol=1e-6)