Skip to content

Commit

Permalink
Save dtype string by the C++ serializer
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed May 19, 2023
1 parent cbcbadd commit acf589b
Show file tree
Hide file tree
Showing 5 changed files with 84 additions and 35 deletions.
10 changes: 9 additions & 1 deletion cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <raft/core/detail/mdspan_numpy_serializer.hpp>
#include <raft/core/mdarray.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/serialize.hpp>
Expand All @@ -33,7 +34,7 @@ namespace raft::neighbors::ivf_flat::detail {
// backward compatibility.
// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward
// compatible fashion.
constexpr int serialization_version = 3;
constexpr int serialization_version = 4;

// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error
// message.
Expand Down Expand Up @@ -62,6 +63,10 @@ void serialize(raft::resources const& handle, std::ostream& os, const index<T, I
RAFT_LOG_DEBUG(
"Saving IVF-Flat index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());

std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype<T>().to_string();
dtype_string.resize(4);
os << dtype_string;

serialize_scalar(handle, os, serialization_version);
serialize_scalar(handle, os, index_.size());
serialize_scalar(handle, os, index_.dim());
Expand Down Expand Up @@ -123,6 +128,9 @@ void serialize(raft::resources const& handle,
template <typename T, typename IdxT>
auto deserialize(raft::resources const& handle, std::istream& is) -> index<T, IdxT>
{
char dtype_string[4];
is.read(dtype_string, 4);

auto ver = deserialize_scalar<int>(handle, is);
if (ver != serialization_version) {
RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver);
Expand Down
7 changes: 7 additions & 0 deletions cpp/include/raft_runtime/neighbors/ivf_flat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,13 @@ namespace raft::runtime::neighbors::ivf_flat {
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
raft::neighbors::ivf_flat::index<T, IdxT>* idx); \
\
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::ivf_flat::index<T, IdxT>& index); \
\
void deserialize_file(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::ivf_flat::index<T, IdxT>* index); \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::ivf_flat::index<T, IdxT>& index); \
Expand Down
48 changes: 31 additions & 17 deletions cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,23 +24,37 @@

namespace raft::runtime::neighbors::ivf_flat {

#define RAFT_IVF_FLAT_SERIALIZE_INST(DTYPE) \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::ivf_flat::index<DTYPE, int64_t>& index) \
{ \
std::stringstream os; \
raft::neighbors::ivf_flat::serialize(handle, os, index); \
str = os.str(); \
} \
\
void deserialize(raft::resources const& handle, \
const std::string& str, \
raft::neighbors::ivf_flat::index<DTYPE, int64_t>* index) \
{ \
std::istringstream is(str); \
if (!index) { RAFT_FAIL("Invalid index pointer"); } \
*index = raft::neighbors::ivf_flat::deserialize<DTYPE, int64_t>(handle, is); \
#define RAFT_IVF_FLAT_SERIALIZE_INST(DTYPE) \
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::ivf_flat::index<DTYPE, int64_t>& index) \
{ \
raft::neighbors::ivf_flat::serialize(handle, filename, index); \
}; \
\
void deserialize_file(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::ivf_flat::index<DTYPE, int64_t>* index) \
{ \
if (!index) { RAFT_FAIL("Invalid index pointer"); } \
*index = raft::neighbors::ivf_flat::deserialize<DTYPE, int64_t>(handle, filename); \
}; \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::ivf_flat::index<DTYPE, int64_t>& index) \
{ \
std::stringstream os; \
raft::neighbors::ivf_flat::serialize(handle, os, index); \
str = os.str(); \
} \
\
void deserialize(raft::resources const& handle, \
const std::string& str, \
raft::neighbors::ivf_flat::index<DTYPE, int64_t>* index) \
{ \
std::istringstream is(str); \
if (!index) { RAFT_FAIL("Invalid index pointer"); } \
*index = raft::neighbors::ivf_flat::deserialize<DTYPE, int64_t>(handle, is); \
}

RAFT_IVF_FLAT_SERIALIZE_INST(float);
Expand Down
24 changes: 24 additions & 0 deletions python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,27 @@ cdef extern from "raft_runtime/neighbors/ivf_flat.hpp" \
cdef void deserialize(const device_resources& handle,
const string& str,
index[int8_t, int64_t]* index) except +

cdef void serialize_file(const device_resources& handle,
const string& filename,
const index[float, int64_t]& index) except +

cdef void deserialize_file(const device_resources& handle,
const string& filename,
index[float, int64_t]* index) except +

cdef void serialize_file(const device_resources& handle,
const string& filename,
const index[uint8_t, int64_t]& index) except +

cdef void deserialize_file(const device_resources& handle,
const string& filename,
index[uint8_t, int64_t]* index) except +

cdef void serialize_file(const device_resources& handle,
const string& filename,
const index[int8_t, int64_t]& index) except +

cdef void deserialize_file(const device_resources& handle,
const string& filename,
index[int8_t, int64_t]* index) except +
30 changes: 13 additions & 17 deletions python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -751,33 +751,28 @@ def save(filename, Index index, handle=None):
cdef device_resources* handle_ = \
<device_resources*><size_t>handle.getHandle()

cdef string c_string
cdef string c_filename = filename.encode('utf-8')

cdef IndexFloat idx_float
cdef IndexInt8 idx_int8
cdef IndexUint8 idx_uint8

if index.active_index_type == "float32":
idx_float = index
c_ivf_flat.serialize(
deref(handle_), c_string, deref(idx_float.index))
c_ivf_flat.serialize_file(
deref(handle_), c_filename, deref(idx_float.index))
elif index.active_index_type == "byte":
idx_int8 = index
c_ivf_flat.serialize(
deref(handle_), c_string, deref(idx_int8.index))
c_ivf_flat.serialize_file(
deref(handle_), c_filename, deref(idx_int8.index))
elif index.active_index_type == "ubyte":
idx_uint8 = index
c_ivf_flat.serialize(
deref(handle_), c_string, deref(idx_uint8.index))
c_ivf_flat.serialize_file(
deref(handle_), c_filename, deref(idx_uint8.index))
else:
raise ValueError(
"Index dtype %s not supported" % index.active_index_type)

dtype = np.dtype(index.active_index_type)
with open(filename, 'wb') as f:
f.write(bytes(dtype.str, 'utf-8'))
f.write(c_string)


@auto_sync_handle
def load(filename, handle=None):
Expand Down Expand Up @@ -837,26 +832,27 @@ def load(filename, handle=None):

with open(filename, 'rb') as f:
type_str = f.read(3).decode('utf-8')
serialized_index = f.read()

dataset_dt = np.dtype(type_str)
cdef string c_idx_str = serialized_index

if dataset_dt == np.float32:
idx_float = IndexFloat(handle)
c_ivf_flat.deserialize(deref(handle_), c_idx_str, idx_float.index)
c_ivf_flat.deserialize_file(
deref(handle_), c_filename, idx_float.index)
idx_float.trained = True
idx_float.active_index_type = 'float32'
return idx_float
elif dataset_dt == np.byte:
idx_int8 = IndexInt8(handle)
c_ivf_flat.deserialize(deref(handle_), c_idx_str, idx_int8.index)
c_ivf_flat.deserialize_file(
deref(handle_), c_filename, idx_int8.index)
idx_int8.trained = True
idx_int8.active_index_type = 'byte'
return idx_int8
elif dataset_dt == np.ubyte:
idx_uint8 = IndexUint8(handle)
c_ivf_flat.deserialize(deref(handle_), c_idx_str, idx_uint8.index)
c_ivf_flat.deserialize_file(
deref(handle_), c_filename, idx_uint8.index)
idx_uint8.trained = True
idx_uint8.active_index_type = 'ubyte'
return idx_uint8
Expand Down

0 comments on commit acf589b

Please sign in to comment.