Update CAGRA serialization
This changes the serialization format of saved CAGRA indices in two ways:

* The dtype is now written in the first 4 bytes of the index file, matching
the IVF methods and making it easier to deduce the dtype from Python (rapidsai#1729).
* Writing out the dataset with the index is now optional. Since many use cases
already have the dataset stored separately, this gives us the option to save disk
space by not writing out an extra copy of the input dataset. If the
include_dataset=false option is given, you will have to call `index.update_dataset`
to set the dataset yourself after loading (see the usage sketch below).
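
A minimal end-to-end sketch of the new workflow from pylibraft. Only `save(..., include_dataset=...)`, `load(...)` and `Index.update_dataset(...)` are part of this change; the `cagra.build`/`IndexParams` calls and the file name are illustrative assumptions:

```python
import numpy as np
from pylibraft.common import DeviceResources
from pylibraft.neighbors import cagra

handle = DeviceResources()
dataset = np.random.random_sample((10000, 64)).astype(np.float32)

# Build as usual (assumed API, not part of this change).
index = cagra.build(cagra.IndexParams(), dataset, handle=handle)

# Skip writing an extra copy of the dataset next to the graph.
cagra.save("cagra_index.bin", index, include_dataset=False, handle=handle)

# The loaded index has no dataset attached, so supply one before searching.
index = cagra.load("cagra_index.bin", handle=handle)
index.update_dataset(dataset, handle=handle)
```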
benfred committed Aug 18, 2023
1 parent 7aded12 commit bfa22a7
Showing 8 changed files with 182 additions and 48 deletions.
16 changes: 11 additions & 5 deletions cpp/include/raft/neighbors/cagra_serialize.cuh
@@ -47,12 +47,16 @@ namespace raft::neighbors::cagra {
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index CAGRA index
* @param[in] include_dataset Whether or not to write out the dataset to the file.
*
*/
template <typename T, typename IdxT>
void serialize(raft::resources const& handle, std::ostream& os, const index<T, IdxT>& index)
void serialize(raft::resources const& handle,
std::ostream& os,
const index<T, IdxT>& index,
bool include_dataset = true)
{
detail::serialize(handle, os, index);
detail::serialize(handle, os, index, include_dataset);
}

/**
@@ -77,14 +81,16 @@ void serialize(raft::resources const& handle, std::ostream& os, const index<T, I
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index CAGRA index
* @param[in] include_dataset Whether or not to write out the dataset to the file.
*
*/
template <typename T, typename IdxT>
void serialize(raft::resources const& handle,
const std::string& filename,
const index<T, IdxT>& index)
const index<T, IdxT>& index,
bool include_dataset = true)
{
detail::serialize(handle, filename, index);
detail::serialize(handle, filename, index, include_dataset);
}

/**
@@ -158,4 +164,4 @@ namespace raft::neighbors::experimental::cagra {
using raft::neighbors::cagra::deserialize;
using raft::neighbors::cagra::serialize;

} // namespace raft::neighbors::experimental::cagra
} // namespace raft::neighbors::experimental::cagra
15 changes: 15 additions & 0 deletions cpp/include/raft/neighbors/cagra_types.hpp
@@ -258,6 +258,13 @@ struct index : ann::index {
dataset.data_handle(), dataset.extent(0), dataset.extent(1), dataset.extent(1));
}
}
void update_dataset(raft::resources const& res,
raft::device_matrix_view<T, int64_t, row_major> dataset)
{
update_dataset(res,
make_device_matrix_view<const T, int64_t>(
dataset.data_handle(), dataset.extent(0), dataset.extent(1)));
}

/**
* Replace the dataset with a new dataset.
@@ -271,6 +278,14 @@ struct index : ann::index {
copy_padded(res, dataset);
}

void update_dataset(raft::resources const& res,
raft::host_matrix_view<T, int64_t, row_major> dataset)
{
update_dataset(res,
make_host_matrix_view<const T, int64_t>(
dataset.data_handle(), dataset.extent(0), dataset.extent(1)));
}

/**
* Replace the graph with a new graph.
*
54 changes: 35 additions & 19 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
@@ -24,8 +24,7 @@

namespace raft::neighbors::cagra::detail {

// Serialization version 1.
constexpr int serialization_version = 2;
constexpr int serialization_version = 3;

// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error
// message.
@@ -50,41 +49,53 @@ template struct check_index_layout<sizeof(index<double, std::uint64_t>), expecte
*
*/
template <typename T, typename IdxT>
void serialize(raft::resources const& res, std::ostream& os, const index<T, IdxT>& index_)
void serialize(raft::resources const& res,
std::ostream& os,
const index<T, IdxT>& index_,
bool include_dataset)
{
RAFT_LOG_DEBUG(
"Saving CAGRA index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());

std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype<T>().to_string();
dtype_string.resize(4);
os << dtype_string;

serialize_scalar(res, os, serialization_version);
serialize_scalar(res, os, index_.size());
serialize_scalar(res, os, index_.dim());
serialize_scalar(res, os, index_.graph_degree());
serialize_scalar(res, os, index_.metric());
auto dataset = index_.dataset();
// Remove padding before saving the dataset
auto host_dataset = make_host_matrix<T, int64_t>(dataset.extent(0), dataset.extent(1));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(),
sizeof(T) * host_dataset.extent(1),
dataset.data_handle(),
sizeof(T) * dataset.stride(0),
sizeof(T) * host_dataset.extent(1),
dataset.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
resource::sync_stream(res);
serialize_mdspan(res, os, host_dataset.view());
serialize_mdspan(res, os, index_.graph());

serialize_scalar(res, os, include_dataset);
if (include_dataset) {
auto dataset = index_.dataset();
// Remove padding before saving the dataset
auto host_dataset = make_host_matrix<T, int64_t>(dataset.extent(0), dataset.extent(1));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(),
sizeof(T) * host_dataset.extent(1),
dataset.data_handle(),
sizeof(T) * dataset.stride(0),
sizeof(T) * host_dataset.extent(1),
dataset.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
resource::sync_stream(res);
serialize_mdspan(res, os, host_dataset.view());
}
}

template <typename T, typename IdxT>
void serialize(raft::resources const& res,
const std::string& filename,
const index<T, IdxT>& index_)
const index<T, IdxT>& index_,
bool include_dataset)
{
std::ofstream of(filename, std::ios::out | std::ios::binary);
if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

detail::serialize(res, of, index_);
detail::serialize(res, of, index_, include_dataset);

of.close();
if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); }
@@ -102,6 +113,9 @@ void serialize(raft::resources const& res,
template <typename T, typename IdxT>
auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
{
char dtype_string[4];
is.read(dtype_string, 4);

auto ver = deserialize_scalar<int>(res, is);
if (ver != serialization_version) {
RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver);
@@ -113,9 +127,11 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>

auto dataset = raft::make_host_matrix<T, int64_t>(n_rows, dim);
auto graph = raft::make_host_matrix<IdxT, int64_t>(n_rows, graph_degree);
deserialize_mdspan(res, is, dataset.view());
deserialize_mdspan(res, is, graph.view());

bool has_dataset = deserialize_scalar<bool>(res, is);
if (has_dataset) { deserialize_mdspan(res, is, dataset.view()); }

return index<T, IdxT>(
res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view()));
}
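
With this change detail::serialize writes, in order: the 4-byte dtype string, the serialization version, size, dim, graph_degree, metric, the graph, an include_dataset flag, and the dataset only when that flag is set. Since the dtype string leads the file, the element type can be sniffed without deserializing anything else. A minimal sketch of such a check; the helper name and file name are illustrative, and the dtype string is assumed to be NUL-padded to 4 bytes by the std::string::resize call above:

```python
import numpy as np

def cagra_file_dtype(filename):
    """Peek at the dtype stored in the first 4 bytes of a serialized CAGRA index."""
    with open(filename, "rb") as f:
        header = f.read(4)
    # The serializer writes a NumPy dtype string (e.g. b"<f4") resized to 4 bytes,
    # so strip any trailing NUL padding before parsing.
    return np.dtype(header.rstrip(b"\x00").decode("utf-8"))

print(cagra_file_dtype("cagra_index.bin"))  # e.g. dtype('float32')
```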
6 changes: 4 additions & 2 deletions cpp/include/raft_runtime/neighbors/cagra.hpp
@@ -56,14 +56,16 @@ namespace raft::runtime::neighbors::cagra {
raft::device_matrix_view<float, int64_t, row_major> distances); \
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<T, IdxT>& index); \
const raft::neighbors::cagra::index<T, IdxT>& index, \
bool include_dataset = true); \
\
void deserialize_file(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::cagra::index<T, IdxT>* index); \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<T, IdxT>& index); \
const raft::neighbors::cagra::index<T, IdxT>& index, \
bool include_dataset = true); \
\
void deserialize(raft::resources const& handle, \
const std::string& str, \
10 changes: 6 additions & 4 deletions cpp/src/raft_runtime/neighbors/cagra_serialize.cu
@@ -27,9 +27,10 @@ namespace raft::runtime::neighbors::cagra {
#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index) \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
raft::neighbors::cagra::serialize(handle, filename, index); \
raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \
}; \
\
void deserialize_file(raft::resources const& handle, \
@@ -41,10 +42,11 @@ namespace raft::runtime::neighbors::cagra {
}; \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index) \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
std::stringstream os; \
raft::neighbors::cagra::serialize(handle, os, index); \
raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \
str = os.str(); \
} \
\
96 changes: 86 additions & 10 deletions python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx
@@ -162,6 +162,31 @@ cdef class IndexFloat(Index):
attr_str = [m_str] + attr_str
return "Index(type=CAGRA, " + (", ".join(attr_str)) + ")"

@auto_sync_handle
def update_dataset(self, dataset, handle=None):
""" Replace the dataset with a new dataset.
Parameters
----------
dataset : array interface compliant matrix shape (n_samples, dim)
{handle_docstring}
"""
cdef device_resources* handle_ = \
<device_resources*><size_t>handle.getHandle()

dataset_ai = wrap_array(dataset)
dataset_dt = dataset_ai.dtype
_check_input_array(dataset_ai, [np.dtype("float32")])

if dataset_ai.from_cai:
self.index[0].update_dataset(deref(handle_),
get_dmv_float(dataset_ai,
check_shape=True))
else:
self.index[0].update_dataset(deref(handle_),
get_hmv_float(dataset_ai,
check_shape=True))

@property
def metric(self):
return self.index[0].metric()
@@ -195,6 +220,31 @@ cdef class IndexInt8(Index):
self.index = new c_cagra.index[int8_t, uint32_t](
deref(handle_))

@auto_sync_handle
def update_dataset(self, dataset, handle=None):
""" Replace the dataset with a new dataset.
Parameters
----------
dataset : array interface compliant matrix shape (n_samples, dim)
{handle_docstring}
"""
cdef device_resources* handle_ = \
<device_resources*><size_t>handle.getHandle()

dataset_ai = wrap_array(dataset)
dataset_dt = dataset_ai.dtype
_check_input_array(dataset_ai, [np.dtype("byte")])

if dataset_ai.from_cai:
self.index[0].update_dataset(deref(handle_),
get_dmv_int8(dataset_ai,
check_shape=True))
else:
self.index[0].update_dataset(deref(handle_),
get_hmv_int8(dataset_ai,
check_shape=True))

def __repr__(self):
m_str = "metric=" + _get_metric_string(self.index.metric())
attr_str = [attr + "=" + str(getattr(self, attr))
@@ -235,6 +285,31 @@ cdef class IndexUint8(Index):
self.index = new c_cagra.index[uint8_t, uint32_t](
deref(handle_))

@auto_sync_handle
def update_dataset(self, dataset, handle=None):
""" Replace the dataset with a new dataset.
Parameters
----------
dataset : array interface compliant matrix shape (n_samples, dim)
{handle_docstring}
"""
cdef device_resources* handle_ = \
<device_resources*><size_t>handle.getHandle()

dataset_ai = wrap_array(dataset)
dataset_dt = dataset_ai.dtype
_check_input_array(dataset_ai, [np.dtype("ubyte")])

if dataset_ai.from_cai:
self.index[0].update_dataset(deref(handle_),
get_dmv_uint8(dataset_ai,
check_shape=True))
else:
self.index[0].update_dataset(deref(handle_),
get_hmv_uint8(dataset_ai,
check_shape=True))

def __repr__(self):
m_str = "metric=" + _get_metric_string(self.index.metric())
attr_str = [attr + "=" + str(getattr(self, attr))
@@ -693,7 +768,7 @@ def search(SearchParams search_params,


@auto_sync_handle
def save(filename, Index index, handle=None):
def save(filename, Index index, bool include_dataset=True, handle=None):
"""
Saves the index to a file.
@@ -706,6 +781,8 @@ def save(filename, Index index, handle=None):
Name of the file.
index : Index
Trained CAGRA index.
include_dataset : bool
Whether or not to write out the dataset along with the index
{handle_docstring}
Examples
@@ -741,15 +818,17 @@ def save(filename, Index index, handle=None):
if index.active_index_type == "float32":
idx_float = index
c_cagra.serialize_file(
deref(handle_), c_filename, deref(idx_float.index))
deref(handle_), c_filename, deref(idx_float.index),
include_dataset)
elif index.active_index_type == "byte":
idx_int8 = index
c_cagra.serialize_file(
deref(handle_), c_filename, deref(idx_int8.index))
deref(handle_), c_filename, deref(idx_int8.index), include_dataset)
elif index.active_index_type == "ubyte":
idx_uint8 = index
c_cagra.serialize_file(
deref(handle_), c_filename, deref(idx_uint8.index))
deref(handle_), c_filename, deref(idx_uint8.index),
include_dataset)
else:
raise ValueError(
"Index dtype %s not supported" % index.active_index_type)
@@ -785,12 +864,9 @@ def load(filename, handle=None):
cdef IndexInt8 idx_int8
cdef IndexUint8 idx_uint8

# we extract the dtype from the array interfaces in the file
with open(filename, 'rb') as f:
type_str = f.read(700).decode("utf-8", errors='ignore')

# Read description of the 6th element to get the datatype
dataset_dt = np.dtype(type_str.split('descr')[6][5:7])
with open(filename, "rb") as f:
type_str = f.read(3).decode("utf8")
dataset_dt = np.dtype(type_str)

if dataset_dt == np.float32:
idx_float = IndexFloat(handle)
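
The new `update_dataset` methods accept either host or device arrays, dispatching on the CUDA array interface as in the branches above. A short sketch reusing the index from the first example; the CuPy usage is an assumption for the device-memory case:

```python
import numpy as np
import cupy as cp

new_dataset = np.random.random_sample((10000, 64)).astype(np.float32)

# Host (NumPy) arrays take the host_matrix_view overload ...
index.update_dataset(new_dataset)

# ... while objects exposing __cuda_array_interface__, such as CuPy arrays,
# take the device_matrix_view overload.
index.update_dataset(cp.asarray(new_dataset))
```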