Skip to content

Commit

Permalink
Save dtype before index string
Browse files Browse the repository at this point in the history
  • Loading branch information
tfeher committed May 18, 2023
1 parent 76a8920 commit cbcbadd
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 39 deletions.
8 changes: 4 additions & 4 deletions cpp/include/raft_runtime/neighbors/ivf_flat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <raft/neighbors/ivf_flat_types.hpp>
#include <string>

namespace raft::runtime::neighbors::ivf_flat {

Expand Down Expand Up @@ -46,12 +47,11 @@ namespace raft::runtime::neighbors::ivf_flat {
raft::neighbors::ivf_flat::index<T, IdxT>* idx); \
\
void serialize(raft::resources const& handle, \
const std::string& filename, \
std::string& str, \
const raft::neighbors::ivf_flat::index<T, IdxT>& index); \
\
void deserialize(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::ivf_flat::index<T, IdxT>* index);
const std::string& str, \
raft::neighbors::ivf_flat::index<T, IdxT>*);

RAFT_INST_BUILD_EXTEND(float, int64_t)
RAFT_INST_BUILD_EXTEND(int8_t, int64_t)
Expand Down
34 changes: 19 additions & 15 deletions cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* limitations under the License.
*/

#include <sstream>
#include <string>

#include <raft/core/device_resources.hpp>
Expand All @@ -23,21 +24,24 @@

namespace raft::runtime::neighbors::ivf_flat {

#define RAFT_IVF_FLAT_SERIALIZE_INST(DTYPE) \
void serialize(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::ivf_flat::index<DTYPE, int64_t>& index) \
{ \
raft::neighbors::ivf_flat::serialize(handle, filename, index); \
}; \
\
void deserialize(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::ivf_flat::index<DTYPE, int64_t>* index) \
{ \
if (!index) { RAFT_FAIL("Invalid index pointer"); } \
*index = raft::neighbors::ivf_flat::deserialize<DTYPE, int64_t>(handle, filename); \
};
#define RAFT_IVF_FLAT_SERIALIZE_INST(DTYPE) \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::ivf_flat::index<DTYPE, int64_t>& index) \
{ \
std::stringstream os; \
raft::neighbors::ivf_flat::serialize(handle, os, index); \
str = os.str(); \
} \
\
void deserialize(raft::resources const& handle, \
const std::string& str, \
raft::neighbors::ivf_flat::index<DTYPE, int64_t>* index) \
{ \
std::istringstream is(str); \
if (!index) { RAFT_FAIL("Invalid index pointer"); } \
*index = raft::neighbors::ivf_flat::deserialize<DTYPE, int64_t>(handle, is); \
}

RAFT_IVF_FLAT_SERIALIZE_INST(float);
RAFT_IVF_FLAT_SERIALIZE_INST(int8_t);
Expand Down
12 changes: 6 additions & 6 deletions python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -135,25 +135,25 @@ cdef extern from "raft_runtime/neighbors/ivf_flat.hpp" \
device_matrix_view[float, int64_t, row_major] distances) except +

cdef void serialize(const device_resources& handle,
const string& filename,
string& str,
const index[float, int64_t]& index) except +

cdef void deserialize(const device_resources& handle,
const string& filename,
const string& str,
index[float, int64_t]* index) except +

cdef void serialize(const device_resources& handle,
const string& filename,
string& str,
const index[uint8_t, int64_t]& index) except +

cdef void deserialize(const device_resources& handle,
const string& filename,
const string& str,
index[uint8_t, int64_t]* index) except +

cdef void serialize(const device_resources& handle,
const string& filename,
string& str,
const index[int8_t, int64_t]& index) except +

cdef void deserialize(const device_resources& handle,
const string& filename,
const string& str,
index[int8_t, int64_t]* index) except +
35 changes: 22 additions & 13 deletions python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -751,7 +751,7 @@ def save(filename, Index index, handle=None):
cdef device_resources* handle_ = \
<device_resources*><size_t>handle.getHandle()

cdef string c_filename = filename.encode('utf-8')
cdef string c_string

cdef IndexFloat idx_float
cdef IndexInt8 idx_int8
Expand All @@ -760,22 +760,27 @@ def save(filename, Index index, handle=None):
if index.active_index_type == "float32":
idx_float = index
c_ivf_flat.serialize(
deref(handle_), c_filename, deref(idx_float.index))
deref(handle_), c_string, deref(idx_float.index))
elif index.active_index_type == "byte":
idx_int8 = index
c_ivf_flat.serialize(
deref(handle_), c_filename, deref(idx_int8.index))
deref(handle_), c_string, deref(idx_int8.index))
elif index.active_index_type == "ubyte":
idx_uint8 = index
c_ivf_flat.serialize(
deref(handle_), c_filename, deref(idx_uint8.index))
deref(handle_), c_string, deref(idx_uint8.index))
else:
raise ValueError(
"Index dtype %s not supported" % index.active_index_type)

dtype = np.dtype(index.active_index_type)
with open(filename, 'wb') as f:
f.write(bytes(dtype.str, 'utf-8'))
f.write(c_string)


@auto_sync_handle
def load(filename, dtype, handle=None):
def load(filename, handle=None):
"""
Loads index from file.
Expand All @@ -787,8 +792,6 @@ def load(filename, dtype, handle=None):
----------
filename : string
Name of the file.
dtype : data type object
dataset type, supported values [np.float32, np.byte, np.ubyte]
{handle_docstring}
Returns
Expand Down Expand Up @@ -817,7 +820,7 @@ def load(filename, dtype, handle=None):
>>> queries = cp.random.random_sample((n_queries, n_features),
... dtype=cp.float32)
>>> handle = DeviceResources()
>>> index = ivf_flat.load("my_index.bin", dtype=cp.float32, handle=handle)
>>> index = ivf_flat.load("my_index.bin", handle=handle)
>>> distances, neighbors = ivf_flat.search(ivf_pq.SearchParams(), index,
... queries, k=10, handle=handle)
Expand All @@ -832,24 +835,30 @@ def load(filename, dtype, handle=None):
cdef IndexInt8 idx_int8
cdef IndexUint8 idx_uint8

dataset_dt = np.dtype(dtype)
with open(filename, 'rb') as f:
type_str = f.read(3).decode('utf-8')
serialized_index = f.read()

dataset_dt = np.dtype(type_str)
cdef string c_idx_str = serialized_index

if dataset_dt == np.float32:
idx_float = IndexFloat(handle)
c_ivf_flat.deserialize(deref(handle_), c_filename, idx_float.index)
c_ivf_flat.deserialize(deref(handle_), c_idx_str, idx_float.index)
idx_float.trained = True
idx_float.active_index_type = 'float32'
return idx_float
elif dataset_dt == np.byte:
idx_int8 = IndexInt8(handle)
c_ivf_flat.deserialize(deref(handle_), c_filename, idx_int8.index)
c_ivf_flat.deserialize(deref(handle_), c_idx_str, idx_int8.index)
idx_int8.trained = True
idx_int8.active_index_type = 'byte'
return idx_int8
elif dataset_dt == np.ubyte:
idx_uint8 = IndexUint8(handle)
c_ivf_flat.deserialize(deref(handle_), c_filename, idx_uint8.index)
c_ivf_flat.deserialize(deref(handle_), c_idx_str, idx_uint8.index)
idx_uint8.trained = True
idx_uint8.active_index_type = 'ubyte'
return idx_uint8
else:
raise ValueError("Index dtype %s not supported" % dtype)
raise ValueError("Index dtype %s not supported" % dataset_dt)
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/test/test_ivf_flat.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ def test_save_load(dtype):
assert index.trained
filename = "my_index.bin"
ivf_flat.save(filename, index)
loaded_index = ivf_flat.load(filename, dtype)
loaded_index = ivf_flat.load(filename)

assert index.metric == loaded_index.metric
assert index.n_lists == loaded_index.n_lists
Expand Down

0 comments on commit cbcbadd

Please sign in to comment.