Skip to content

Commit

Permalink
Serialization of IVF Flat and IVF PQ (#919)
Browse files Browse the repository at this point in the history
This PR implements serialization to file for the `ivf_pq::index` and `ivf_flat::index` structures.

Index building takes time, so downstream projects (like cuML) want to save the index (rapidsai/cuml#4743). However, downstream projects should not depend on the implementation details of the index; therefore RAFT provides methods to serialize and deserialize it.

This is still experimental:
- ideally we want to use a general serialization method for mdspan #770,
- instead of directly saving to file, raft should provide a byte string and let the downstream project decide how to save it (e.g. pickle for cuML).

Python wrappers are provided for IVF-PQ to save/load the index.

Authors:
  - Tamas Bela Feher (https://github.com/tfeher)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #919
  • Loading branch information
tfeher authored Jan 3, 2023
1 parent ef95988 commit 96578a1
Show file tree
Hide file tree
Showing 11 changed files with 577 additions and 4 deletions.
140 changes: 140 additions & 0 deletions cpp/include/raft/spatial/knn/detail/ann_serialization.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
/*
* Copyright (c) 2022, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cassert>
#include <fstream>
#include <iostream>
#include <memory>
#include <raft/core/device_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/distance/detail/distance.cuh>
#include <raft/distance/distance_type.hpp>
#include <raft/util/cudart_utils.hpp>
#include <rmm/device_uvector.hpp>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

namespace raft::spatial::knn::detail {

/**
 * Write a single value to a binary output stream.
 *
 * The raw object representation of `value` (`sizeof(T)` bytes) is written verbatim, so `T`
 * should be a type for which a byte-wise copy is meaningful (integers, enums, bool, ...).
 *
 * @tparam T type of the value to write
 * @param[in] of output file stream
 * @param[in] value the value whose bytes are written
 *
 * Calls RAFT_FAIL if the stream is not in a good state after the write.
 */
template <typename T>
void write_scalar(std::ofstream& of, const T& value)
{
  of.write(reinterpret_cast<const char*>(&value), sizeof(value));
  if (of.good()) {
    // "%zu" is the printf conversion for size_t; a bare "%z" is not a valid specifier.
    RAFT_LOG_DEBUG("Written %zu bytes", sizeof(value));
  } else {
    RAFT_FAIL("error writing value to file");
  }
}

/**
 * Read a single value from a binary input stream.
 *
 * Exactly `sizeof(T)` bytes are read and interpreted as the object representation of `T`;
 * this is the counterpart of `write_scalar`.
 *
 * @tparam T type of the value to read
 * @param[in] file input file stream
 *
 * @return the value read from the stream
 *
 * Calls RAFT_FAIL if the stream is not in a good state after the read.
 */
template <typename T>
T read_scalar(std::ifstream& file)
{
  // Value-initialize so that `value` never holds indeterminate bytes, even on a short read.
  T value{};
  file.read(reinterpret_cast<char*>(&value), sizeof(value));
  if (file.good()) {
    // "%zu" is the printf conversion for size_t; a bare "%z" is not a valid specifier.
    RAFT_LOG_DEBUG("Read %zu bytes", sizeof(value));
  } else {
    RAFT_FAIL("error reading value from file");
  }
  return value;
}

/**
 * Serialize a device mdspan to a binary output stream.
 *
 * Written to the stream, in order: the rank, the total element count and, only when the element
 * count is non-zero, every extent followed by the raw element bytes. The elements are copied
 * into a temporary host buffer on the handle's stream before being written.
 *
 * Only views that are both exhaustive and unique can be serialized; otherwise RAFT_FAIL is
 * called. RAFT_FAIL is also called when the stream reports an error after writing the data.
 *
 * @param[in] handle the raft handle (provides the stream used for the copy)
 * @param[in] of output file stream
 * @param[in] obj the view whose contents are written
 */
template <typename ElementType, typename Extents, typename LayoutPolicy, typename AccessorPolicy>
void write_mdspan(
  const raft::handle_t& handle,
  std::ofstream& of,
  const raft::device_mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>& obj)
{
  using obj_t   = raft::device_mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>;
  using value_t = typename obj_t::value_type;

  write_scalar(of, obj.rank());
  if (obj.is_exhaustive() && obj.is_unique()) {
    write_scalar(of, obj.size());
  } else {
    RAFT_FAIL("Cannot serialize non exhaustive mdarray");
  }
  if (obj.size() > 0) {
    for (typename obj_t::rank_type i = 0; i < obj.rank(); i++) {
      write_scalar(of, obj.extent(i));
    }
    cudaStream_t stream = handle.get_stream();
    // Stage the elements in a host buffer; the copy must finish before the buffer is written.
    std::vector<value_t> tmp(obj.size());
    raft::update_host(tmp.data(), obj.data_handle(), obj.size(), stream);
    handle.sync_stream(stream);
    // Single source of truth for the byte count used by both the write and the log message.
    const size_t n_bytes = tmp.size() * sizeof(value_t);
    of.write(reinterpret_cast<const char*>(tmp.data()), n_bytes);
    if (of.good()) {
      RAFT_LOG_DEBUG("Written %zu bytes", n_bytes);
    } else {
      RAFT_FAIL("Error writing mdarray to file");
    }
  } else {
    RAFT_LOG_DEBUG("Skipping mdspan with zero size");
  }
}

/**
 * Deserialize the contents of a device mdspan from a binary input stream.
 *
 * Counterpart of `write_mdspan`: reads the rank, the total element count and, when the element
 * count is non-zero, every extent followed by the raw element bytes. The view `obj` must already
 * refer to allocated memory of matching shape; the stored rank, size and extents are validated
 * against it and RAFT_FAIL is called on any mismatch.
 *
 * The element bytes are read into a temporary host buffer and copied to `obj` only after the
 * read succeeded, so the destination is not overwritten with partially read data.
 *
 * @param[in] handle the raft handle (provides the stream used for the copy)
 * @param[in] file input file stream
 * @param[inout] obj the view that receives the data
 */
template <typename ElementType, typename Extents, typename LayoutPolicy, typename AccessorPolicy>
void read_mdspan(const raft::handle_t& handle,
                 std::ifstream& file,
                 raft::device_mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>& obj)
{
  using obj_t   = raft::device_mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>;
  using value_t = typename obj_t::value_type;

  auto rank = read_scalar<typename obj_t::rank_type>(file);
  if (obj.rank() != rank) { RAFT_FAIL("Incorrect rank while reading mdarray"); }
  auto size = read_scalar<typename obj_t::size_type>(file);
  if (obj.size() != size) {
    RAFT_FAIL("Incorrect size while reading mdarray %zu vs %zu",
              static_cast<size_t>(size),
              static_cast<size_t>(obj.size()));
  }
  if (obj.size() > 0) {
    for (typename obj_t::rank_type i = 0; i < obj.rank(); i++) {
      auto ex = read_scalar<typename obj_t::index_type>(file);
      if (obj.extent(i) != ex) {
        // Report through size_t so that large extents are not truncated in the message.
        RAFT_FAIL("Incorrect extent while reading mdarray %zu vs %zu at %zu",
                  static_cast<size_t>(ex),
                  static_cast<size_t>(obj.extent(i)),
                  static_cast<size_t>(i));
      }
    }
    std::vector<value_t> tmp(obj.size());
    const size_t n_bytes = tmp.size() * sizeof(value_t);
    file.read(reinterpret_cast<char*>(tmp.data()), n_bytes);
    if (file.good()) {
      // Upload only after the read is known to be complete.
      cudaStream_t stream = handle.get_stream();
      raft::update_device(obj.data_handle(), tmp.data(), tmp.size(), stream);
      // The host buffer goes out of scope below, so wait for the copy to finish.
      handle.sync_stream(stream);
      RAFT_LOG_DEBUG("Read %zu bytes", n_bytes);
    } else {
      RAFT_FAIL("error reading mdarray from file");
    }
  } else {
    RAFT_LOG_DEBUG("Skipping mdspan with zero size");
  }
}

/**
 * Overload of `read_mdspan` that accepts a temporary view, e.g. the value returned directly by
 * an accessor. It forwards to the overload taking the view by lvalue reference.
 *
 * @param[in] handle the raft handle
 * @param[in] file input file stream
 * @param[in] obj temporary view referring to the memory that receives the data
 */
template <typename ElementType, typename Extents, typename LayoutPolicy, typename AccessorPolicy>
void read_mdspan(const raft::handle_t& handle,
                 std::ifstream& file,
                 raft::device_mdspan<ElementType, Extents, LayoutPolicy, AccessorPolicy>&& obj)
{
  // A named rvalue reference is an lvalue, so this binds to the lvalue-reference overload.
  auto& view = obj;
  read_mdspan(handle, file, view);
}
} // namespace raft::spatial::knn::detail
95 changes: 95 additions & 0 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

#include "../ivf_flat_types.hpp"
#include "ann_kmeans_balanced.cuh"
#include "ann_serialization.h"
#include "ann_utils.cuh"

#include <raft/core/handle.hpp>
Expand Down Expand Up @@ -378,4 +379,98 @@ inline void fill_refinement_index(const handle_t& handle,
refinement_index->veclen());
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

// Format version of the serialized IVF-Flat index. `save` writes it as the first value of the
// file and `load` rejects files whose stored version differs from it.
static const int serialization_version = 1;

/**
 * Save the index to file.
 *
 * Experimental, both the API and the serialization format are subject to change.
 *
 * The file starts with the serialization version, followed by the scalar index parameters, the
 * index arrays, and a boolean flag telling whether the center norms follow.
 *
 * @param[in] handle the raft handle
 * @param[in] filename the file name for saving the index
 * @param[in] index_ IVF-Flat index
 *
 */
template <typename T, typename IdxT>
void save(const handle_t& handle, const std::string& filename, const index<T, IdxT>& index_)
{
  std::ofstream of(filename, std::ios::out | std::ios::binary);
  if (!of) { RAFT_FAIL("Cannot open %s", filename.c_str()); }

  RAFT_LOG_DEBUG(
    "Saving IVF-Flat index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());

  // Scalar header; `load` reads these back in the same order.
  write_scalar(of, serialization_version);
  write_scalar(of, index_.size());
  write_scalar(of, index_.dim());
  write_scalar(of, index_.n_lists());
  write_scalar(of, index_.metric());
  write_scalar(of, index_.veclen());
  write_scalar(of, index_.adaptive_centers());

  // Index arrays.
  write_mdspan(handle, of, index_.data());
  write_mdspan(handle, of, index_.indices());
  write_mdspan(handle, of, index_.list_sizes());
  write_mdspan(handle, of, index_.list_offsets());
  write_mdspan(handle, of, index_.centers());

  // The center norms are optional: store a flag first, then the norms only if present.
  const bool has_norms = static_cast<bool>(index_.center_norms());
  write_scalar(of, has_norms);
  if (has_norms) { write_mdspan(handle, of, *index_.center_norms()); }

  // Close explicitly so that errors while flushing are detected here.
  of.close();
  if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); }
}

/**
 * Load an index from file.
 *
 * Experimental, both the API and the serialization format are subject to change.
 *
 * @param[in] handle the raft handle
 * @param[in] filename the name of the file that stores the index
 *
 * @return the IVF-Flat index reconstructed from the file
 */
template <typename T, typename IdxT>
auto load(const handle_t& handle, const std::string& filename) -> index<T, IdxT>
{
  std::ifstream infile(filename, std::ios::in | std::ios::binary);
  if (!infile) { RAFT_FAIL("Cannot open %s", filename.c_str()); }

  // The file starts with the format version; refuse anything written with another version.
  const auto stored_version = read_scalar<int>(infile);
  if (stored_version != serialization_version) {
    RAFT_FAIL("serialization version mismatch, expected %d, got %d ",
              serialization_version,
              stored_version);
  }

  // Scalar header, in the order `save` writes it.
  const auto n_rows           = read_scalar<IdxT>(infile);
  const auto dim              = read_scalar<uint32_t>(infile);
  const auto n_lists          = read_scalar<uint32_t>(infile);
  const auto metric           = read_scalar<raft::distance::DistanceType>(infile);
  const auto veclen           = read_scalar<uint32_t>(infile);
  const bool adaptive_centers = read_scalar<bool>(infile);

  auto index_ =
    raft::spatial::knn::ivf_flat::index<T, IdxT>(handle, metric, n_lists, adaptive_centers, dim);
  index_.allocate(handle, n_rows, metric == raft::distance::DistanceType::L2Expanded);

  // Fill the freshly allocated arrays; the order matches the order used by `save`.
  read_mdspan(handle, infile, index_.data());
  read_mdspan(handle, infile, index_.indices());
  read_mdspan(handle, infile, index_.list_sizes());
  read_mdspan(handle, infile, index_.list_offsets());
  read_mdspan(handle, infile, index_.centers());

  // A boolean flag tells whether the center norms were stored.
  // NOTE(review): when the flag is false nothing is read into the norms here, even if
  // `allocate` created them above - confirm that this combination cannot occur.
  if (read_scalar<bool>(infile)) {
    if (index_.center_norms()) {
      auto norms_view = *index_.center_norms();
      read_mdspan(handle, infile, norms_view);
    } else {
      RAFT_FAIL("Error inconsistent center norms");
    }
  }

  infile.close();
  return index_;
}
} // namespace raft::spatial::knn::ivf_flat::detail
106 changes: 106 additions & 0 deletions cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include "ann_kmeans_balanced.cuh"
#include "ann_serialization.h"
#include "ann_utils.cuh"

#include <raft/neighbors/ivf_pq_types.hpp>
Expand Down Expand Up @@ -1263,4 +1264,109 @@ inline auto build(
}}();
}

// Format version of the serialized IVF-PQ index. `save` writes it as the first value of the
// file and `load` rejects files whose stored version differs from it.
static const int serialization_version = 1;

/**
 * Save the index to file.
 *
 * Experimental, both the API and the serialization format are subject to change.
 *
 * The file consists of the serialization version, the scalar index parameters and the index
 * arrays, written in the order in which `load` reads them back.
 *
 * @param[in] handle_ the raft handle
 * @param[in] filename the file name for saving the index
 * @param[in] index_ IVF-PQ index
 *
 */
template <typename IdxT>
void save(const handle_t& handle_, const std::string& filename, const index<IdxT>& index_)
{
  std::ofstream of(filename, std::ios::out | std::ios::binary);
  if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

  RAFT_LOG_DEBUG("Size %zu, dim %d, pq_dim %d, pq_bits %d",
                 static_cast<size_t>(index_.size()),
                 static_cast<int>(index_.dim()),
                 static_cast<int>(index_.pq_dim()),
                 static_cast<int>(index_.pq_bits()));

  // Small helpers that bind the output stream (and the handle) once.
  const auto put_scalar = [&of](const auto& item) { write_scalar(of, item); };
  const auto put_array  = [&handle_, &of](const auto& view) { write_mdspan(handle_, of, view); };

  // Scalar header.
  put_scalar(serialization_version);
  put_scalar(index_.size());
  put_scalar(index_.dim());
  put_scalar(index_.pq_bits());
  put_scalar(index_.pq_dim());
  put_scalar(index_.metric());
  put_scalar(index_.codebook_kind());
  put_scalar(index_.n_lists());
  put_scalar(index_.n_nonempty_lists());

  // Index arrays.
  put_array(index_.pq_centers());
  put_array(index_.pq_dataset());
  put_array(index_.indices());
  put_array(index_.rotation_matrix());
  put_array(index_.list_offsets());
  put_array(index_.list_sizes());
  put_array(index_.centers());
  put_array(index_.centers_rot());

  // Close explicitly so that errors while flushing are detected here.
  of.close();
  if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); }
}

/**
 * Load index from file.
 *
 * Experimental, both the API and the serialization format are subject to change.
 *
 * @param[in] handle_ the raft handle
 * @param[in] filename the name of the file that stores the index
 *
 * @return the IVF-PQ index reconstructed from the file
 */
template <typename IdxT>
auto load(const handle_t& handle_, const std::string& filename) -> index<IdxT>
{
  std::ifstream infile(filename, std::ios::in | std::ios::binary);
  if (!infile) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

  // The file starts with the format version; refuse anything written with another version.
  const auto ver = read_scalar<int>(infile);
  if (ver != serialization_version) {
    RAFT_FAIL("serialization version mismatch %d vs. %d", ver, serialization_version);
  }

  // Scalar header, in the order `save` writes it.
  const auto n_rows           = read_scalar<IdxT>(infile);
  const auto dim              = read_scalar<uint32_t>(infile);
  const auto pq_bits          = read_scalar<uint32_t>(infile);
  const auto pq_dim           = read_scalar<uint32_t>(infile);
  const auto metric           = read_scalar<raft::distance::DistanceType>(infile);
  const auto codebook_kind    = read_scalar<raft::neighbors::ivf_pq::codebook_gen>(infile);
  const auto n_lists          = read_scalar<uint32_t>(infile);
  const auto n_nonempty_lists = read_scalar<uint32_t>(infile);

  RAFT_LOG_DEBUG("n_rows %zu, dim %d, pq_dim %d, pq_bits %d, n_lists %d",
                 static_cast<size_t>(n_rows),
                 static_cast<int>(dim),
                 static_cast<int>(pq_dim),
                 static_cast<int>(pq_bits),
                 static_cast<int>(n_lists));

  raft::neighbors::ivf_pq::index<IdxT> index_(
    handle_, metric, codebook_kind, n_lists, dim, pq_bits, pq_dim, n_nonempty_lists);
  index_.allocate(handle_, n_rows);

  // Fill the freshly allocated arrays; the order matches the order used by `save`.
  const auto fetch = [&handle_, &infile](auto view) { read_mdspan(handle_, infile, view); };
  fetch(index_.pq_centers());
  fetch(index_.pq_dataset());
  fetch(index_.indices());
  fetch(index_.rotation_matrix());
  fetch(index_.list_offsets());
  fetch(index_.list_sizes());
  fetch(index_.centers());
  fetch(index_.centers_rot());

  infile.close();

  return index_;
}

} // namespace raft::spatial::knn::ivf_pq::detail
28 changes: 28 additions & 0 deletions cpp/include/raft_runtime/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,32 @@ RAFT_INST_BUILD_EXTEND(uint8_t, uint64_t)

#undef RAFT_INST_BUILD_EXTEND

/**
 * Save the index to file.
 *
 * Experimental, both the API and the serialization format are subject to change.
 *
 * This declaration covers IVF-PQ indices that use `uint64_t` as the index type.
 *
 * @param[in] handle the raft handle
 * @param[in] filename the filename for saving the index
 * @param[in] index IVF-PQ index
 *
 */
void save(const handle_t& handle,
const std::string& filename,
const raft::neighbors::ivf_pq::index<uint64_t>& index);

/**
 * Load index from file.
 *
 * Experimental, both the API and the serialization format are subject to change.
 *
 * This declaration covers IVF-PQ indices that use `uint64_t` as the index type.
 *
 * @param[in] handle the raft handle
 * @param[in] filename the name of the file that stores the index
 * @param[out] index pointer to the IVF-PQ index object that receives the loaded index
 *
 */
void load(const handle_t& handle,
const std::string& filename,
raft::neighbors::ivf_pq::index<uint64_t>* index);

} // namespace raft::runtime::neighbors::ivf_pq
Loading

0 comments on commit 96578a1

Please sign in to comment.