diff --git a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh index af2e6ba0f8..b00d308827 100644 --- a/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include #include #include @@ -33,7 +34,7 @@ namespace raft::neighbors::ivf_flat::detail { // backward compatibility. // TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward // compatible fashion. -constexpr int serialization_version = 3; +constexpr int serialization_version = 4; // NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error // message. @@ -62,6 +63,10 @@ void serialize(raft::resources const& handle, std::ostream& os, const index(index_.size()), index_.dim()); + std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); + dtype_string.resize(4); + os << dtype_string; + serialize_scalar(handle, os, serialization_version); serialize_scalar(handle, os, index_.size()); serialize_scalar(handle, os, index_.dim()); @@ -123,6 +128,9 @@ void serialize(raft::resources const& handle, template auto deserialize(raft::resources const& handle, std::istream& is) -> index { + char dtype_string[4]; + is.read(dtype_string, 4); + auto ver = deserialize_scalar(handle, is); if (ver != serialization_version) { RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); diff --git a/cpp/include/raft_runtime/neighbors/ivf_flat.hpp b/cpp/include/raft_runtime/neighbors/ivf_flat.hpp index bc3965f585..5b8918ec7f 100644 --- a/cpp/include/raft_runtime/neighbors/ivf_flat.hpp +++ b/cpp/include/raft_runtime/neighbors/ivf_flat.hpp @@ -46,6 +46,13 @@ namespace raft::runtime::neighbors::ivf_flat { std::optional> new_indices, \ raft::neighbors::ivf_flat::index* idx); \ \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::ivf_flat::index& index); \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::ivf_flat::index* index); \ void serialize(raft::resources const& handle, \ std::string& str, \ const raft::neighbors::ivf_flat::index& index); \ diff --git a/cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu b/cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu index b8548cc867..049b8b00da 100644 --- a/cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu @@ -24,23 +24,37 @@ namespace raft::runtime::neighbors::ivf_flat { -#define RAFT_IVF_FLAT_SERIALIZE_INST(DTYPE) \ - void serialize(raft::resources const& handle, \ - std::string& str, \ - const raft::neighbors::ivf_flat::index& index) \ - { \ - std::stringstream os; \ - raft::neighbors::ivf_flat::serialize(handle, os, index); \ - str = os.str(); \ - } \ - \ - void deserialize(raft::resources const& handle, \ - const std::string& str, \ - raft::neighbors::ivf_flat::index* index) \ - { \ - std::istringstream is(str); \ - if (!index) { RAFT_FAIL("Invalid index pointer"); } \ - *index = raft::neighbors::ivf_flat::deserialize(handle, is); \ +#define RAFT_IVF_FLAT_SERIALIZE_INST(DTYPE) \ + void serialize_file(raft::resources const& handle, \ + const std::string& filename, \ + const raft::neighbors::ivf_flat::index& index) \ + { \ + raft::neighbors::ivf_flat::serialize(handle, filename, index); \ + }; \ + \ + void deserialize_file(raft::resources const& handle, \ + const std::string& filename, \ + raft::neighbors::ivf_flat::index* index) \ + { \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::ivf_flat::deserialize(handle, filename); \ + }; \ + void serialize(raft::resources const& handle, \ + std::string& str, \ + const raft::neighbors::ivf_flat::index& index) \ + { \ + std::stringstream os; \ + raft::neighbors::ivf_flat::serialize(handle, os, index); \ + str = os.str(); \ + } \ + \ + void deserialize(raft::resources const& handle, \ + const std::string& str, \ + raft::neighbors::ivf_flat::index* index) \ + { \ + std::istringstream is(str); \ + if (!index) { RAFT_FAIL("Invalid index pointer"); } \ + *index = raft::neighbors::ivf_flat::deserialize(handle, is); \ } RAFT_IVF_FLAT_SERIALIZE_INST(float); diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd index 45d854e600..a281d33310 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd @@ -157,3 +157,27 @@ cdef extern from "raft_runtime/neighbors/ivf_flat.hpp" \ cdef void deserialize(const device_resources& handle, const string& str, index[int8_t, int64_t]* index) except + + + cdef void serialize_file(const device_resources& handle, + const string& filename, + const index[float, int64_t]& index) except + + + cdef void deserialize_file(const device_resources& handle, + const string& filename, + index[float, int64_t]* index) except + + + cdef void serialize_file(const device_resources& handle, + const string& filename, + const index[uint8_t, int64_t]& index) except + + + cdef void deserialize_file(const device_resources& handle, + const string& filename, + index[uint8_t, int64_t]* index) except + + + cdef void serialize_file(const device_resources& handle, + const string& filename, + const index[int8_t, int64_t]& index) except + + + cdef void deserialize_file(const device_resources& handle, + const string& filename, + index[int8_t, int64_t]* index) except + diff --git a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx index ba8f66a71a..0e550547d3 100644 --- a/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx +++ b/python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx @@ -751,7 +751,7 @@ def save(filename, Index index, handle=None): cdef device_resources* handle_ = \ handle.getHandle() - cdef string c_string + cdef string c_filename = filename.encode('utf-8') cdef IndexFloat idx_float cdef IndexInt8 idx_int8 @@ -759,25 +759,20 @@ def save(filename, Index index, handle=None): if index.active_index_type == "float32": idx_float = index - c_ivf_flat.serialize( - deref(handle_), c_string, deref(idx_float.index)) + c_ivf_flat.serialize_file( + deref(handle_), c_filename, deref(idx_float.index)) elif index.active_index_type == "byte": idx_int8 = index - c_ivf_flat.serialize( - deref(handle_), c_string, deref(idx_int8.index)) + c_ivf_flat.serialize_file( + deref(handle_), c_filename, deref(idx_int8.index)) elif index.active_index_type == "ubyte": idx_uint8 = index - c_ivf_flat.serialize( - deref(handle_), c_string, deref(idx_uint8.index)) + c_ivf_flat.serialize_file( + deref(handle_), c_filename, deref(idx_uint8.index)) else: raise ValueError( "Index dtype %s not supported" % index.active_index_type) - dtype = np.dtype(index.active_index_type) - with open(filename, 'wb') as f: - f.write(bytes(dtype.str, 'utf-8')) - f.write(c_string) - @auto_sync_handle def load(filename, handle=None): @@ -837,26 +832,27 @@ def load(filename, handle=None): with open(filename, 'rb') as f: type_str = f.read(3).decode('utf-8') - serialized_index = f.read() dataset_dt = np.dtype(type_str) - cdef string c_idx_str = serialized_index if dataset_dt == np.float32: idx_float = IndexFloat(handle) - c_ivf_flat.deserialize(deref(handle_), c_idx_str, idx_float.index) + c_ivf_flat.deserialize_file( + deref(handle_), c_filename, idx_float.index) idx_float.trained = True idx_float.active_index_type = 'float32' return idx_float elif dataset_dt == np.byte: idx_int8 = IndexInt8(handle) - c_ivf_flat.deserialize(deref(handle_), c_idx_str, idx_int8.index) + c_ivf_flat.deserialize_file( + deref(handle_), c_filename, idx_int8.index) idx_int8.trained = True idx_int8.active_index_type = 'byte' return idx_int8 elif dataset_dt == np.ubyte: idx_uint8 = IndexUint8(handle) - c_ivf_flat.deserialize(deref(handle_), c_idx_str, idx_uint8.index) + c_ivf_flat.deserialize_file( + deref(handle_), c_filename, idx_uint8.index) idx_uint8.trained = True idx_uint8.active_index_type = 'ubyte' return idx_uint8