Skip to content

Commit

Permalink
Python API for IVF-Flat serialization (#1516)
Browse files Browse the repository at this point in the history
This PR adds a Python API for IVF-Flat index serialization.

closes #752

Authors:
  - Tamas Bela Feher (https://github.com/tfeher)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1516
  • Loading branch information
tfeher authored May 19, 2023
1 parent dfb3d2c commit af7e067
Show file tree
Hide file tree
Showing 8 changed files with 348 additions and 3 deletions.
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -365,6 +365,7 @@ if(RAFT_COMPILE_LIBRARY)
src/raft_runtime/neighbors/brute_force_knn_int64_t_float.cu
src/raft_runtime/neighbors/ivf_flat_build.cu
src/raft_runtime/neighbors/ivf_flat_search.cu
src/raft_runtime/neighbors/ivf_flat_serialize.cu
src/raft_runtime/neighbors/ivfpq_build.cu
src/raft_runtime/neighbors/ivfpq_deserialize.cu
src/raft_runtime/neighbors/ivfpq_search_float_int64_t.cu
Expand Down
10 changes: 9 additions & 1 deletion cpp/include/raft/neighbors/detail/ivf_flat_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <raft/core/detail/mdspan_numpy_serializer.hpp>
#include <raft/core/mdarray.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/serialize.hpp>
Expand All @@ -33,7 +34,7 @@ namespace raft::neighbors::ivf_flat::detail {
// backward compatibility.
// TODO(hcho3) Implement next-gen serializer for IVF that allows for expansion in a backward
// compatible fashion.
constexpr int serialization_version = 3;
constexpr int serialization_version = 4;

// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error
// message.
Expand Down Expand Up @@ -62,6 +63,10 @@ void serialize(raft::resources const& handle, std::ostream& os, const index<T, I
RAFT_LOG_DEBUG(
"Saving IVF-Flat index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());

std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype<T>().to_string();
dtype_string.resize(4);
os << dtype_string;

serialize_scalar(handle, os, serialization_version);
serialize_scalar(handle, os, index_.size());
serialize_scalar(handle, os, index_.dim());
Expand Down Expand Up @@ -123,6 +128,9 @@ void serialize(raft::resources const& handle,
template <typename T, typename IdxT>
auto deserialize(raft::resources const& handle, std::istream& is) -> index<T, IdxT>
{
char dtype_string[4];
is.read(dtype_string, 4);

auto ver = deserialize_scalar<int>(handle, is);
if (ver != serialization_version) {
RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver);
Expand Down
17 changes: 16 additions & 1 deletion cpp/include/raft_runtime/neighbors/ivf_flat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <raft/neighbors/ivf_flat_types.hpp>
#include <string>

namespace raft::runtime::neighbors::ivf_flat {

Expand All @@ -43,7 +44,21 @@ namespace raft::runtime::neighbors::ivf_flat {
void extend(raft::resources const& handle, \
raft::device_matrix_view<const T, IdxT, row_major> new_vectors, \
std::optional<raft::device_vector_view<const IdxT, IdxT>> new_indices, \
raft::neighbors::ivf_flat::index<T, IdxT>* idx);
raft::neighbors::ivf_flat::index<T, IdxT>* idx); \
\
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::ivf_flat::index<T, IdxT>& index); \
\
void deserialize_file(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::ivf_flat::index<T, IdxT>* index); \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::ivf_flat::index<T, IdxT>& index); \
void deserialize(raft::resources const& handle, \
const std::string& str, \
raft::neighbors::ivf_flat::index<T, IdxT>*);

RAFT_INST_BUILD_EXTEND(float, int64_t)
RAFT_INST_BUILD_EXTEND(int8_t, int64_t)
Expand Down
65 changes: 65 additions & 0 deletions cpp/src/raft_runtime/neighbors/ivf_flat_serialize.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <sstream>
#include <string>

#include <raft/core/device_resources.hpp>
#include <raft/neighbors/ivf_flat_serialize.cuh>
#include <raft/neighbors/ivf_flat_types.hpp>
#include <raft_runtime/neighbors/ivf_flat.hpp>

namespace raft::runtime::neighbors::ivf_flat {

// Defines the non-template runtime entry points for IVF-Flat index (de)serialization for one
// element type DTYPE (the index type is fixed to int64_t):
//   - serialize_file / deserialize_file forward to the raft serialize / deserialize overloads
//     that take a file name;
//   - serialize / deserialize go through an in-memory stream whose content is exchanged with
//     the caller as a std::string.
// The deserializing variants assign the result to `*index` and fail with
// "Invalid index pointer" when `index` is null.
//
// NB: only block comments may appear inside the macro body; a line comment would swallow the
// following lines because of the line continuations.
#define RAFT_IVF_FLAT_SERIALIZE_INST(DTYPE)                                               \
  void serialize_file(raft::resources const& handle,                                      \
                      const std::string& filename,                                        \
                      const raft::neighbors::ivf_flat::index<DTYPE, int64_t>& index)      \
  {                                                                                       \
    raft::neighbors::ivf_flat::serialize(handle, filename, index);                        \
  }                                                                                       \
                                                                                          \
  void deserialize_file(raft::resources const& handle,                                    \
                        const std::string& filename,                                      \
                        raft::neighbors::ivf_flat::index<DTYPE, int64_t>* index)          \
  {                                                                                       \
    if (!index) { RAFT_FAIL("Invalid index pointer"); }                                   \
    *index = raft::neighbors::ivf_flat::deserialize<DTYPE, int64_t>(handle, filename);    \
  }                                                                                       \
                                                                                          \
  void serialize(raft::resources const& handle,                                           \
                 std::string& str,                                                        \
                 const raft::neighbors::ivf_flat::index<DTYPE, int64_t>& index)           \
  {                                                                                       \
    std::stringstream os;                                                                 \
    raft::neighbors::ivf_flat::serialize(handle, os, index);                              \
    str = os.str();                                                                       \
  }                                                                                       \
                                                                                          \
  void deserialize(raft::resources const& handle,                                         \
                   const std::string& str,                                                \
                   raft::neighbors::ivf_flat::index<DTYPE, int64_t>* index)               \
  {                                                                                       \
    /* Validate the output pointer first, so a bad call does not copy `str` for nothing */ \
    if (!index) { RAFT_FAIL("Invalid index pointer"); }                                   \
    std::istringstream is(str);                                                           \
    *index = raft::neighbors::ivf_flat::deserialize<DTYPE, int64_t>(handle, is);          \
  }

// The macro expands to complete function definitions, so no trailing semicolon is needed
// (one would be an empty declaration and trigger extra-semicolon warnings).
RAFT_IVF_FLAT_SERIALIZE_INST(float)
RAFT_IVF_FLAT_SERIALIZE_INST(int8_t)
RAFT_IVF_FLAT_SERIALIZE_INST(uint8_t)

#undef RAFT_IVF_FLAT_SERIALIZE_INST
} // namespace raft::runtime::neighbors::ivf_flat
13 changes: 12 additions & 1 deletion python/pylibraft/pylibraft/neighbors/ivf_flat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@
# limitations under the License.
#

from .ivf_flat import Index, IndexParams, SearchParams, build, extend, search
from .ivf_flat import (
Index,
IndexParams,
SearchParams,
build,
extend,
load,
save,
search,
)

__all__ = [
"Index",
Expand All @@ -22,4 +31,6 @@
"build",
"extend",
"search",
"save",
"load",
]
48 changes: 48 additions & 0 deletions python/pylibraft/pylibraft/neighbors/ivf_flat/cpp/c_ivf_flat.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -133,3 +133,51 @@ cdef extern from "raft_runtime/neighbors/ivf_flat.hpp" \
device_matrix_view[uint8_t, int64_t, row_major] queries,
device_matrix_view[int64_t, int64_t, row_major] neighbors,
device_matrix_view[float, int64_t, row_major] distances) except +

cdef void serialize(const device_resources& handle,
string& str,
const index[float, int64_t]& index) except +

cdef void deserialize(const device_resources& handle,
const string& str,
index[float, int64_t]* index) except +

cdef void serialize(const device_resources& handle,
string& str,
const index[uint8_t, int64_t]& index) except +

cdef void deserialize(const device_resources& handle,
const string& str,
index[uint8_t, int64_t]* index) except +

cdef void serialize(const device_resources& handle,
string& str,
const index[int8_t, int64_t]& index) except +

cdef void deserialize(const device_resources& handle,
const string& str,
index[int8_t, int64_t]* index) except +

cdef void serialize_file(const device_resources& handle,
const string& filename,
const index[float, int64_t]& index) except +

cdef void deserialize_file(const device_resources& handle,
const string& filename,
index[float, int64_t]* index) except +

cdef void serialize_file(const device_resources& handle,
const string& filename,
const index[uint8_t, int64_t]& index) except +

cdef void deserialize_file(const device_resources& handle,
const string& filename,
index[uint8_t, int64_t]* index) except +

cdef void serialize_file(const device_resources& handle,
const string& filename,
const index[int8_t, int64_t]& index) except +

cdef void deserialize_file(const device_resources& handle,
const string& filename,
index[int8_t, int64_t]* index) except +
150 changes: 150 additions & 0 deletions python/pylibraft/pylibraft/neighbors/ivf_flat/ivf_flat.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -708,3 +708,153 @@ def search(SearchParams search_params,
raise ValueError("query dtype %s not supported" % queries_dt)

return (distances, neighbors)


@auto_sync_handle
def save(filename, Index index, handle=None):
"""
Saves the index to a file.

Saving / loading the index is experimental. The serialization format is
subject to change.

Parameters
----------
filename : string
    Name of the file.
index : Index
    Trained IVF-Flat index.
{handle_docstring}

Raises
------
ValueError
    If the index has not been trained (built), or if the dtype of the
    index is not one of "float32", "byte" or "ubyte".

Examples
--------
>>> import cupy as cp
>>> from pylibraft.common import DeviceResources
>>> from pylibraft.neighbors import ivf_flat
>>> n_samples = 50000
>>> n_features = 50
>>> dataset = cp.random.random_sample((n_samples, n_features),
...                                   dtype=cp.float32)
>>> # Build index
>>> handle = DeviceResources()
>>> index = ivf_flat.build(ivf_flat.IndexParams(), dataset, handle=handle)
>>> ivf_flat.save("my_index.bin", index, handle=handle)
"""
if not index.trained:
raise ValueError("Index need to be built before saving it.")

if handle is None:
handle = DeviceResources()
# getHandle() returns the address of the underlying C++ resources object as
# an integer; cast it back to a pointer to pass it to the C++ API.
cdef device_resources* handle_ = \
<device_resources*><size_t>handle.getHandle()

cdef string c_filename = filename.encode('utf-8')

cdef IndexFloat idx_float
cdef IndexInt8 idx_int8
cdef IndexUint8 idx_uint8

# Dispatch on the element type recorded in the index: assign the generic
# Index to the matching typed wrapper, then call the serialize_file overload
# for that element type.
if index.active_index_type == "float32":
idx_float = index
c_ivf_flat.serialize_file(
deref(handle_), c_filename, deref(idx_float.index))
elif index.active_index_type == "byte":
idx_int8 = index
c_ivf_flat.serialize_file(
deref(handle_), c_filename, deref(idx_int8.index))
elif index.active_index_type == "ubyte":
idx_uint8 = index
c_ivf_flat.serialize_file(
deref(handle_), c_filename, deref(idx_uint8.index))
else:
raise ValueError(
"Index dtype %s not supported" % index.active_index_type)


@auto_sync_handle
def load(filename, handle=None):
"""
Loads an index from a file.

Saving / loading the index is experimental. The serialization format is
subject to change, therefore loading an index saved with a previous
version of raft is not guaranteed to work.

Parameters
----------
filename : string
    Name of the file.
{handle_docstring}

Returns
-------
index : Index

Raises
------
ValueError
    If the dtype read from the beginning of the file is not float32,
    byte (int8) or ubyte (uint8).

Examples
--------
>>> import cupy as cp
>>> from pylibraft.common import DeviceResources
>>> from pylibraft.neighbors import ivf_flat
>>> n_samples = 50000
>>> n_features = 50
>>> dataset = cp.random.random_sample((n_samples, n_features),
...                                   dtype=cp.float32)
>>> # Build and save index
>>> handle = DeviceResources()
>>> index = ivf_flat.build(ivf_flat.IndexParams(), dataset, handle=handle)
>>> ivf_flat.save("my_index.bin", index, handle=handle)
>>> del index
>>> n_queries = 100
>>> queries = cp.random.random_sample((n_queries, n_features),
...                                   dtype=cp.float32)
>>> handle = DeviceResources()
>>> index = ivf_flat.load("my_index.bin", handle=handle)
>>> distances, neighbors = ivf_flat.search(ivf_flat.SearchParams(), index,
...                                        queries, k=10, handle=handle)
"""
if handle is None:
handle = DeviceResources()
# getHandle() returns the address of the underlying C++ resources object as
# an integer; cast it back to a pointer to pass it to the C++ API.
cdef device_resources* handle_ = \
<device_resources*><size_t>handle.getHandle()

cdef string c_filename = filename.encode('utf-8')
cdef IndexFloat idx_float
cdef IndexInt8 idx_int8
cdef IndexUint8 idx_uint8

# Peek at the first 3 bytes of the file and interpret them as a NumPy dtype
# string; this decides which typed index is created below. The file is
# re-read from the start by deserialize_file, which receives only the name.
with open(filename, 'rb') as f:
type_str = f.read(3).decode('utf-8')

dataset_dt = np.dtype(type_str)

# Create the typed index wrapper matching the stored dtype, fill it from the
# file, and mark it as trained so that it can be searched or saved.
if dataset_dt == np.float32:
idx_float = IndexFloat(handle)
c_ivf_flat.deserialize_file(
deref(handle_), c_filename, idx_float.index)
idx_float.trained = True
idx_float.active_index_type = 'float32'
return idx_float
elif dataset_dt == np.byte:
idx_int8 = IndexInt8(handle)
c_ivf_flat.deserialize_file(
deref(handle_), c_filename, idx_int8.index)
idx_int8.trained = True
idx_int8.active_index_type = 'byte'
return idx_int8
elif dataset_dt == np.ubyte:
idx_uint8 = IndexUint8(handle)
c_ivf_flat.deserialize_file(
deref(handle_), c_filename, idx_uint8.index)
idx_uint8.trained = True
idx_uint8.active_index_type = 'ubyte'
return idx_uint8
else:
raise ValueError("Index dtype %s not supported" % dataset_dt)
Loading

0 comments on commit af7e067

Please sign in to comment.