Update CAGRA serialization #1755

Merged
merged 4 commits into from
Aug 21, 2023
16 changes: 11 additions & 5 deletions cpp/include/raft/neighbors/cagra_serialize.cuh
@@ -47,12 +47,16 @@ namespace raft::neighbors::cagra {
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index CAGRA index
* @param[in] include_dataset Whether or not to write out the dataset to the file.
*
*/
template <typename T, typename IdxT>
void serialize(raft::resources const& handle, std::ostream& os, const index<T, IdxT>& index)
void serialize(raft::resources const& handle,
std::ostream& os,
const index<T, IdxT>& index,
bool include_dataset = true)
{
detail::serialize(handle, os, index);
detail::serialize(handle, os, index, include_dataset);
}

/**
@@ -77,14 +81,16 @@ void serialize(raft::resources const& handle, std::ostream& os, const index<T, I
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index CAGRA index
* @param[in] include_dataset Whether or not to write out the dataset to the file.
*
*/
template <typename T, typename IdxT>
void serialize(raft::resources const& handle,
const std::string& filename,
const index<T, IdxT>& index)
const index<T, IdxT>& index,
bool include_dataset = true)
{
detail::serialize(handle, filename, index);
detail::serialize(handle, filename, index, include_dataset);
}

/**
@@ -158,4 +164,4 @@ namespace raft::neighbors::experimental::cagra {
using raft::neighbors::cagra::deserialize;
using raft::neighbors::cagra::serialize;

} // namespace raft::neighbors::experimental::cagra
} // namespace raft::neighbors::experimental::cagra
15 changes: 15 additions & 0 deletions cpp/include/raft/neighbors/cagra_types.hpp
@@ -258,6 +258,13 @@ struct index : ann::index {
dataset.data_handle(), dataset.extent(0), dataset.extent(1), dataset.extent(1));
}
}
void update_dataset(raft::resources const& res,
Member

We probably want to keep these as const mdspans. If this change is because of Python, can we use make_const_mdspan() in that layer?

Member Author

I agree that automatically discarding const would be bad, but this does the opposite: it automatically adds const (it converts a non-const mdspan to a const mdspan), which I feel should be allowed by our APIs.

The issue I have is that Cython kinda sucks at respecting const qualifiers, which is why all our Cython APIs use non-const mdspans right now. For example, if I try to add a get_const_hmv_float (to parallel the non-const get_hmv_float we have now), I get an error from Cython, which doesn't recognize const float as a type inside template parameters:

      Error compiling Cython file:
      ------------------------------------------------------------
      ...
          if cai.dtype != np.float32:
              raise TypeError("dtype %s not supported" % cai.dtype)
          if check_shape and len(cai.shape) != 2:
              raise ValueError("Expected a 2D array, got %d D" % len(cai.shape))
          shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1)
          return make_host_matrix_view[const float, int64_t, row_major](
                                             ^
      ------------------------------------------------------------
      
      /home/ben/code/raft/python/pylibraft/pylibraft/common/mdspan.pyx:232:39: Expected ']', found 'float'

I can get around this by adding a Cython typedef (like ctypedef const float const_float), but that introduces the need for other hacks later on: Cython treats const_float and const float as separate types, so when we declare update_dataset for Cython in c_cagra.pxd I can't just use const T as the type and have to introduce a new template parameter =(. I've done this in the last commit; let me know what you think.

raft::device_matrix_view<T, int64_t, row_major> dataset)
{
update_dataset(res,
make_device_matrix_view<const T, int64_t>(
dataset.data_handle(), dataset.extent(0), dataset.extent(1)));
}

/**
* Replace the dataset with a new dataset.
@@ -271,6 +278,14 @@ struct index : ann::index {
copy_padded(res, dataset);
}

void update_dataset(raft::resources const& res,
raft::host_matrix_view<T, int64_t, row_major> dataset)
{
update_dataset(res,
make_host_matrix_view<const T, int64_t>(
dataset.data_handle(), dataset.extent(0), dataset.extent(1)));
}
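
The two `update_dataset` overloads above accept a mutable view and forward it as a `const T` view. A minimal sketch of that const-adding pattern in plain C++ follows. It assumes nothing about RAFT: `toy_matrix_view` and `const_view_index` are hypothetical stand-ins for `device_matrix_view` / `host_matrix_view` and the CAGRA index, used only to show why the extra overload is needed when the view type has no implicit non-const to const conversion.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for an mdspan-style 2D view; not a RAFT type.
template <typename T>
struct toy_matrix_view {
  T* data;
  std::size_t rows;
  std::size_t cols;
};

// Hypothetical stand-in for the index; it only ever reads the dataset.
struct const_view_index {
  std::size_t n_rows = 0;
  double checksum    = 0.0;

  // Primary overload: takes a read-only view.
  void update_dataset(toy_matrix_view<const float> dataset)
  {
    n_rows   = dataset.rows;
    checksum = 0.0;
    for (std::size_t i = 0; i < dataset.rows * dataset.cols; ++i) {
      checksum += dataset.data[i];
    }
  }

  // Convenience overload: accepts a mutable view and adds const before
  // forwarding. Without it, a caller holding toy_matrix_view<float> would
  // not compile, because the two view types are unrelated.
  void update_dataset(toy_matrix_view<float> dataset)
  {
    update_dataset(toy_matrix_view<const float>{dataset.data, dataset.rows, dataset.cols});
  }
};
```

Adding const this way never weakens a guarantee the caller relied on, which is the distinction the review thread draws between adding and discarding const.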

/**
* Replace the graph with a new graph.
*
54 changes: 35 additions & 19 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
@@ -24,8 +24,7 @@

namespace raft::neighbors::cagra::detail {

// Serialization version 1.
constexpr int serialization_version = 2;
constexpr int serialization_version = 3;

// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error
// message.
@@ -50,41 +49,53 @@ template struct check_index_layout<sizeof(index<double, std::uint64_t>), expecte
*
*/
template <typename T, typename IdxT>
void serialize(raft::resources const& res, std::ostream& os, const index<T, IdxT>& index_)
void serialize(raft::resources const& res,
std::ostream& os,
const index<T, IdxT>& index_,
bool include_dataset)
{
RAFT_LOG_DEBUG(
"Saving CAGRA index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());

std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype<T>().to_string();
dtype_string.resize(4);
os << dtype_string;

serialize_scalar(res, os, serialization_version);
serialize_scalar(res, os, index_.size());
serialize_scalar(res, os, index_.dim());
serialize_scalar(res, os, index_.graph_degree());
serialize_scalar(res, os, index_.metric());
auto dataset = index_.dataset();
// Remove padding before saving the dataset
auto host_dataset = make_host_matrix<T, int64_t>(dataset.extent(0), dataset.extent(1));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(),
sizeof(T) * host_dataset.extent(1),
dataset.data_handle(),
sizeof(T) * dataset.stride(0),
sizeof(T) * host_dataset.extent(1),
dataset.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
resource::sync_stream(res);
serialize_mdspan(res, os, host_dataset.view());
serialize_mdspan(res, os, index_.graph());

serialize_scalar(res, os, include_dataset);
if (include_dataset) {
auto dataset = index_.dataset();
// Remove padding before saving the dataset
auto host_dataset = make_host_matrix<T, int64_t>(dataset.extent(0), dataset.extent(1));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(),
sizeof(T) * host_dataset.extent(1),
dataset.data_handle(),
sizeof(T) * dataset.stride(0),
sizeof(T) * host_dataset.extent(1),
dataset.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
resource::sync_stream(res);
serialize_mdspan(res, os, host_dataset.view());
}
}
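
The "remove padding" step in `serialize` copies a row-strided dataset into a dense host matrix with `cudaMemcpy2DAsync`: the source pitch is `sizeof(T) * dataset.stride(0)`, while the destination pitch and copy width are both `sizeof(T) * extent(1)`. As a rough host-only illustration of the same idea, here is a sketch in standard C++. `strip_row_padding` is a hypothetical helper, not part of RAFT, and it works in elements rather than bytes.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Copy `rows` rows of `cols` elements each from a padded source, whose rows
// start `src_stride` elements apart, into a dense vector whose rows start
// `cols` elements apart.
template <typename T>
std::vector<T> strip_row_padding(const T* src,
                                 std::size_t rows,
                                 std::size_t cols,
                                 std::size_t src_stride)
{
  assert(src_stride >= cols);  // the stride must cover at least one full row
  std::vector<T> dense(rows * cols);
  for (std::size_t r = 0; r < rows; ++r) {
    // Only the first `cols` elements of each source row are payload;
    // the remaining `src_stride - cols` elements are padding and are skipped.
    std::copy_n(src + r * src_stride, cols, dense.begin() + r * cols);
  }
  return dense;
}
```

For a 2 x 3 matrix stored with a stride of 4, the call `strip_row_padding(ptr, 2, 3, 4)` yields six contiguous elements, which is the layout the serialized file stores.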

template <typename T, typename IdxT>
void serialize(raft::resources const& res,
const std::string& filename,
const index<T, IdxT>& index_)
const index<T, IdxT>& index_,
bool include_dataset)
{
std::ofstream of(filename, std::ios::out | std::ios::binary);
if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

detail::serialize(res, of, index_);
detail::serialize(res, of, index_, include_dataset);

of.close();
if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); }
@@ -102,6 +113,9 @@ void serialize(raft::resources const& res,
template <typename T, typename IdxT>
auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
{
char dtype_string[4];
is.read(dtype_string, 4);

auto ver = deserialize_scalar<int>(res, is);
if (ver != serialization_version) {
RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver);
@@ -113,9 +127,11 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>

auto dataset = raft::make_host_matrix<T, int64_t>(n_rows, dim);
auto graph = raft::make_host_matrix<IdxT, int64_t>(n_rows, graph_degree);
deserialize_mdspan(res, is, dataset.view());
deserialize_mdspan(res, is, graph.view());

bool has_dataset = deserialize_scalar<bool>(res, is);
if (has_dataset) { deserialize_mdspan(res, is, dataset.view()); }

return index<T, IdxT>(
res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view()));
}
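
The new layout writes the graph first, then a boolean flag, then the dataset only when the flag is set; the reader checks the version and consults the flag before reading the dataset. A minimal sketch of that flag-prefixed optional section, using only standard streams, is shown below. Every name here (`toy_graph_index`, `serialize_toy`, `deserialize_toy`, the helpers) is hypothetical, the byte layout is not the actual CAGRA file format, and values are written in host byte order, so this is an illustration of the idea rather than a compatible reader or writer.

```cpp
#include <cassert>
#include <cstdint>
#include <istream>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <vector>

constexpr std::int32_t toy_format_version = 3;  // assumed value, mirrors the diff

struct toy_graph_index {
  std::vector<std::uint32_t> graph;
  std::vector<float> dataset;  // left empty when the file carries no dataset
};

template <typename T>
void write_scalar(std::ostream& os, T value)
{
  os.write(reinterpret_cast<const char*>(&value), sizeof(T));
}

template <typename T>
T read_scalar(std::istream& is)
{
  T value{};
  is.read(reinterpret_cast<char*>(&value), sizeof(T));
  if (!is) { throw std::runtime_error("unexpected end of stream"); }
  return value;
}

template <typename T>
void write_vector(std::ostream& os, const std::vector<T>& v)
{
  write_scalar<std::uint64_t>(os, v.size());
  if (!v.empty()) {
    os.write(reinterpret_cast<const char*>(v.data()),
             static_cast<std::streamsize>(v.size() * sizeof(T)));
  }
}

template <typename T>
std::vector<T> read_vector(std::istream& is)
{
  const auto n = read_scalar<std::uint64_t>(is);
  std::vector<T> v(n);
  if (n > 0) {
    is.read(reinterpret_cast<char*>(v.data()), static_cast<std::streamsize>(n * sizeof(T)));
    if (!is) { throw std::runtime_error("unexpected end of stream"); }
  }
  return v;
}

void serialize_toy(std::ostream& os, const toy_graph_index& index, bool include_dataset = true)
{
  write_scalar(os, toy_format_version);
  write_vector(os, index.graph);
  // The flag is always written, so the reader knows whether a dataset follows.
  write_scalar(os, static_cast<std::uint8_t>(include_dataset));
  if (include_dataset) { write_vector(os, index.dataset); }
}

toy_graph_index deserialize_toy(std::istream& is)
{
  const auto version = read_scalar<std::int32_t>(is);
  if (version != toy_format_version) { throw std::runtime_error("serialization version mismatch"); }
  toy_graph_index index;
  index.graph            = read_vector<std::uint32_t>(is);
  const bool has_dataset = read_scalar<std::uint8_t>(is) != 0;
  if (has_dataset) { index.dataset = read_vector<float>(is); }
  return index;
}
```

When the dataset is omitted, the loaded index has a graph but no usable vectors, so the caller is expected to attach them afterwards, as the updated test does with `index.update_dataset(handle_, database_view)`.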
6 changes: 4 additions & 2 deletions cpp/include/raft_runtime/neighbors/cagra.hpp
@@ -56,14 +56,16 @@ namespace raft::runtime::neighbors::cagra {
raft::device_matrix_view<float, int64_t, row_major> distances); \
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<T, IdxT>& index); \
const raft::neighbors::cagra::index<T, IdxT>& index, \
bool include_dataset = true); \
\
void deserialize_file(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::cagra::index<T, IdxT>* index); \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<T, IdxT>& index); \
const raft::neighbors::cagra::index<T, IdxT>& index, \
bool include_dataset = true); \
\
void deserialize(raft::resources const& handle, \
const std::string& str, \
10 changes: 6 additions & 4 deletions cpp/src/raft_runtime/neighbors/cagra_serialize.cu
@@ -27,9 +27,10 @@ namespace raft::runtime::neighbors::cagra {
#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index) \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
raft::neighbors::cagra::serialize(handle, filename, index); \
raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \
}; \
\
void deserialize_file(raft::resources const& handle, \
@@ -41,10 +42,11 @@ namespace raft::runtime::neighbors::cagra {
}; \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index) \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
std::stringstream os; \
raft::neighbors::cagra::serialize(handle, os, index); \
raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \
str = os.str(); \
} \
\
13 changes: 10 additions & 3 deletions cpp/test/neighbors/ann_cagra.cuh
@@ -137,6 +137,7 @@ struct AnnCagraInputs {
int search_width;
raft::distance::DistanceType metric;
bool host_dataset;
bool include_serialized_dataset;
// std::optional<double>
double min_recall; // = std::nullopt;
};
@@ -217,9 +218,11 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
} else {
index = cagra::build<DataT, IdxT>(handle_, index_params, database_view);
};
cagra::serialize(handle_, "cagra_index", index);
cagra::serialize(handle_, "cagra_index", index, ps.include_serialized_dataset);
}

auto index = cagra::deserialize<DataT, IdxT>(handle_, "cagra_index");
if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); }

auto search_queries_view = raft::make_device_matrix_view<const DataT, int64_t>(
search_queries.data(), ps.n_queries, ps.dim);
@@ -340,9 +343,7 @@ class AnnCagraSortTest : public ::testing::TestWithParam<AnnCagraInputs> {

void SetUp() override
{
std::cout << "Resizing database: " << ps.n_rows * ps.dim << std::endl;
database.resize(((size_t)ps.n_rows) * ps.dim, handle_.get_stream());
std::cout << "Done.\nRuning rng" << std::endl;
raft::random::Rng r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
GenerateRoundingErrorFreeDataset(database.data(), ps.n_rows, ps.dim, r, handle_.get_stream());
@@ -379,6 +380,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{1},
{raft::distance::DistanceType::L2Expanded},
{false},
{true},
{0.995});

auto inputs2 = raft::util::itertools::product<AnnCagraInputs>(
@@ -393,6 +395,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{1},
{raft::distance::DistanceType::L2Expanded},
{false},
{true},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());
inputs2 =
@@ -407,6 +410,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{1},
{raft::distance::DistanceType::L2Expanded},
{false},
{false},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

@@ -422,6 +426,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{1},
{raft::distance::DistanceType::L2Expanded},
{false},
{true},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

@@ -437,6 +442,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{1},
{raft::distance::DistanceType::L2Expanded},
{false, true},
{false},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

@@ -452,6 +458,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{1},
{raft::distance::DistanceType::L2Expanded},
{false, true},
{true},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());
