diff --git a/cpp/include/raft_runtime/neighbors/ivf_flat.hpp b/cpp/include/raft_runtime/neighbors/ivf_flat.hpp index 68ec3ab278..bc3965f585 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_flat.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_flat.hpp @@ -17,6 +17,7 @@ #pragma once #include +#include namespace raft::runtime::neighbors::ivf_flat { @@ -46,12 +47,11 @@ namespace raft::runtime::neighbors::ivf_flat { raft::neighbors::ivf_flat::index* idx); \ \ void serialize(raft::resources const& handle, \ - const std::string& filename, \ + std::string& str, \ const raft::neighbors::ivf_flat::index& index); \ - \ void deserialize(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::ivf_flat::index* index); + const std::string& str, \ + raft::neighbors::ivf_flat::index*); RAFT_INST_BUILD_EXTEND(float, int64_t) RAFT_INST_BUILD_EXTEND(int8_t, int64_t) diff --git a/cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu index 6278bf0066..b8548cc867 100644 --- a/cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu @@ -14,6 +14,7 @@ * limitations under the License. */ +#include #include #include @@ -23,21 +24,24 @@ namespace raft::runtime::neighbors::ivf_flat { -#define RAFT_IVF_FLAT_SERIALIZE_INST(DTYPE) \ - void serialize(raft::resources const& handle, \ - const std::string& filename, \ - const raft::neighbors::ivf_flat::index& index) \ - { \ - raft::neighbors::ivf_flat::serialize(handle, filename, index); \ - }; \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& filename, \ - raft::neighbors::ivf_flat::index* index) \ - { \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = raft::neighbors::ivf_flat::deserialize(handle, filename); \ - }; +#define RAFT_IVF_FLAT_SERIALIZE_INST(DTYPE) \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::ivf_flat::index& index) \ + { \ + std::stringstream os; \ + raft::neighbors::ivf_flat::serialize(handle, os, index); \ + str = os.str(); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + raft::neighbors::ivf_flat::index* index) \ + { \ + std::istringstream is(str); \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::ivf_flat::deserialize(handle, is); \ + } RAFT_IVF_FLAT_SERIALIZE_INST(float); RAFT_IVF_FLAT_SERIALIZE_INST(int8_t); diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd index 66ce7f7e64..45d854e600 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd @@ -135,25 +135,25 @@ cdef extern from "raft_runtime/neighbors/ivf_flat.hpp" \ device_matrix_view[float, int64_t, row_major] distances) except + cdef void serialize(const device_resources& handle, - const string& filename, + string& str, const index[float, int64_t]& index) except + cdef void deserialize(const device_resources& handle, - const string& filename, + const string& str, index[float, int64_t]* index) except + cdef void serialize(const device_resources& handle, - const string& filename, + string& str, const index[uint8_t, int64_t]& index) except + cdef void deserialize(const device_resources& handle, - const string& filename, + const string& str, index[uint8_t, int64_t]* index) except + cdef void serialize(const device_resources& handle, - const string& filename, + string& str, const index[int8_t, int64_t]& index) except + cdef void deserialize(const device_resources& handle, - const string& filename, + const string& str, index[int8_t, int64_t]* index) except + diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx index 6085c2f540..ba8f66a71a 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx @@ -751,7 +751,7 @@ def save(filename, Index index, handle=None): cdef device_resources* handle_ = \ handle.getHandle() - cdef string c_filename = filename.encode('utf-8') + cdef string c_string cdef IndexFloat idx_float cdef IndexInt8 idx_int8 @@ -760,22 +760,27 @@ def save(filename, Index index, handle=None): if index.active_index_type == "float32": idx_float = index c_ivf_flat.serialize( - deref(handle_), c_filename, deref(idx_float.index)) + deref(handle_), c_string, deref(idx_float.index)) elif index.active_index_type == "byte": idx_int8 = index c_ivf_flat.serialize( - deref(handle_), c_filename, deref(idx_int8.index)) + deref(handle_), c_string, deref(idx_int8.index)) elif index.active_index_type == "ubyte": idx_uint8 = index c_ivf_flat.serialize( - deref(handle_), c_filename, deref(idx_uint8.index)) + deref(handle_), c_string, deref(idx_uint8.index)) else: raise ValueError( "Index dtype %s not supported" % index.active_index_type) + dtype = np.dtype(index.active_index_type) + with open(filename, 'wb') as f: + f.write(bytes(dtype.str, 'utf-8')) + f.write(c_string) + @auto_sync_handle -def load(filename, dtype, handle=None): +def load(filename, handle=None): """ Loads index from file. @@ -787,8 +792,6 @@ def load(filename, dtype, handle=None): ---------- filename : string Name of the file. - dtype : data type object - dataset type, supported values [np.float32, np.byte, np.ubyte] {handle_docstring} Returns @@ -817,7 +820,7 @@ def load(filename, dtype, handle=None): >>> queries = cp.random.random_sample((n_queries, n_features), ... dtype=cp.float32) >>> handle = DeviceResources() - >>> index = ivf_flat.load("my_index.bin", dtype=cp.float32, handle=handle) + >>> index = ivf_flat.load("my_index.bin", handle=handle) >>> distances, neighbors = ivf_flat.search(ivf_pq.SearchParams(), index, ... queries, k=10, handle=handle) @@ -832,24 +835,30 @@ def load(filename, dtype, handle=None): cdef IndexInt8 idx_int8 cdef IndexUint8 idx_uint8 - dataset_dt = np.dtype(dtype) + with open(filename, 'rb') as f: + type_str = f.read(3).decode('utf-8') + serialized_index = f.read() + + dataset_dt = np.dtype(type_str) + cdef string c_idx_str = serialized_index + if dataset_dt == np.float32: idx_float = IndexFloat(handle) - c_ivf_flat.deserialize(deref(handle_), c_filename, idx_float.index) + c_ivf_flat.deserialize(deref(handle_), c_idx_str, idx_float.index) idx_float.trained = True idx_float.active_index_type = 'float32' return idx_float elif dataset_dt == np.byte: idx_int8 = IndexInt8(handle) - c_ivf_flat.deserialize(deref(handle_), c_filename, idx_int8.index) + c_ivf_flat.deserialize(deref(handle_), c_idx_str, idx_int8.index) idx_int8.trained = True idx_int8.active_index_type = 'byte' return idx_int8 elif dataset_dt == np.ubyte: idx_uint8 = IndexUint8(handle) - c_ivf_flat.deserialize(deref(handle_), c_filename, idx_uint8.index) + c_ivf_flat.deserialize(deref(handle_), c_idx_str, idx_uint8.index) idx_uint8.trained = True idx_uint8.active_index_type = 'ubyte' return idx_uint8 else: - raise ValueError("Index dtype %s not supported" % dtype) + raise ValueError("Index dtype %s not supported" % dataset_dt) diff --git a/python/pylibraft/pylibraft/test/test_ivf_flat.py b/python/pylibraft/pylibraft/test/test_ivf_flat.py index 46a421d3e8..23140073f1 100644 --- a/python/pylibraft/pylibraft/test/test_ivf_flat.py +++ b/python/pylibraft/pylibraft/test/test_ivf_flat.py @@ -478,7 +478,7 @@ def test_save_load(dtype): assert index.trained filename = "my_index.bin" ivf_flat.save(filename, index) - loaded_index = ivf_flat.load(filename, dtype) + loaded_index = ivf_flat.load(filename) assert index.metric == loaded_index.metric assert index.n_lists == loaded_index.n_lists