Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add python serialization API's for ivf-pq and ivf_flat #186

Merged
merged 4 commits into from
Jun 14, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,50 @@ cuvsError_t cuvsIvfFlatSearch(cuvsResources_t res,
* @}
*/

/**
* @defgroup ivf_flat_c_serialize IVF-Flat C-API serialize functions
* @{
*/
/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <cuvs/neighbors/ivf_flat.h>
*
* // Create cuvsResources_t
* cuvsResources_t res;
* cuvsError_t res_create_status = cuvsResourcesCreate(&res);
*
* // create an index with `cuvsIvfFlatBuild`
* cuvsIvfFlatSerialize(res, "/path/to/index", index, true);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the file name for saving the index
* @param[in] index IVF-Flat index
*/
cuvsError_t cuvsIvfFlatSerialize(cuvsResources_t res,
const char* filename,
cuvsIvfFlatIndex_t index);

/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the name of the file that stores the index
* @param[out] index IVF-Flat index loaded disk
*/
cuvsError_t cuvsIvfFlatDeserialize(cuvsResources_t res,
const char* filename,
cuvsIvfFlatIndex_t index);
/**
* @}
*/

#ifdef __cplusplus
}
#endif
9 changes: 9 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_flat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ struct index : cuvs::neighbors::index {
index& operator=(const index&) = delete;
index& operator=(index&&) = default;
~index() = default;

/**
* @brief Construct an empty index.
*
* Constructs an empty index. This index will either need to be trained with `build`
* or loaded from a saved copy with `deserialize`
*/
index(raft::resources const& res);

/** Construct an empty index. It needs to be trained and then populated. */
index(raft::resources const& res, const index_params& params, uint32_t dim);
/** Construct an empty index. It needs to be trained and then populated. */
Expand Down
40 changes: 40 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_pq.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,46 @@ cuvsError_t cuvsIvfPqSearch(cuvsResources_t res,
* @}
*/

/**
* @defgroup ivf_pq_c_serialize IVF-PQ C-API serialize functions
* @{
*/
/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <cuvs/neighbors/ivf_pq.h>
*
* // Create cuvsResources_t
* cuvsResources_t res;
* cuvsError_t res_create_status = cuvsResourcesCreate(&res);
*
* // create an index with `cuvsIvfPqBuild`
* cuvsIvfPqSerialize(res, "/path/to/index", index, true);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the file name for saving the index
* @param[in] index IVF-PQ index
*/
cuvsError_t cuvsIvfPqSerialize(cuvsResources_t res, const char* filename, cuvsIvfPqIndex_t index);

/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the name of the file that stores the index
* @param[out] index IVF-PQ index loaded disk
*/
cuvsError_t cuvsIvfPqDeserialize(cuvsResources_t res, const char* filename, cuvsIvfPqIndex_t index);
/**
* @}
*/

#ifdef __cplusplus
}
#endif
14 changes: 11 additions & 3 deletions cpp/include/cuvs/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,14 @@ struct index : cuvs::neighbors::index {
auto operator=(index&&) -> index& = default;
~index() = default;

/**
* @brief Construct an empty index.
*
* Constructs an empty index. This index will either need to be trained with `build`
* or loaded from a saved copy with `deserialize`
*/
index(raft::resources const& handle);

/** Construct an empty index. It needs to be trained and then populated. */
index(raft::resources const& handle,
cuvs::distance::DistanceType metric,
Expand Down Expand Up @@ -1366,7 +1374,7 @@ void serialize(raft::resources const& handle,
*
* using IdxT = int64_t; // type of the index
* // create an empty index
* cuvs::neighbors::ivf_pq::index<IdxT> index(handl, index_params, dim);
* cuvs::neighbors::ivf_pq::index<IdxT> index(handle);
*
* cuvs::neighbors::ivf_pq::deserialize(handle, is, index);
* @endcode
Expand All @@ -1390,8 +1398,8 @@ void deserialize(raft::resources const& handle,
* // create a string with a filepath
* std::string filename("/path/to/index");
* using IdxT = int64_t; // type of the index
* // create an empty index with
* ivf_pq::index<IdxT> index(handle, index_params, dim);
* // create an empty index
* ivf_pq::index<IdxT> index(handle);
*
* cuvs::neighbors::ivf_pq::deserialize(handle, filename, &index);
* @endcode
Expand Down
68 changes: 34 additions & 34 deletions cpp/src/neighbors/detail/cagra/cagra_build.cpp
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/neighbors/ivf_pq.hpp>
#include <raft/core/mdspan.hpp>
namespace cuvs::neighbors::cagra::graph_build_params {
ivf_pq_params::ivf_pq_params(raft::matrix_extent<int64_t> dataset_extents,
cuvs::distance::DistanceType metric)
{
build_params = cuvs::neighbors::ivf_pq::index_params::from_dataset(dataset_extents, metric);
search_params = cuvs::neighbors::ivf_pq::search_params{};
search_params.n_probes = std::max<uint32_t>(10, build_params.n_lists * 0.01);
search_params.lut_dtype = CUDA_R_16F;
search_params.internal_distance_dtype = CUDA_R_16F;
refinement_rate = 2;
}
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/neighbors/ivf_pq.hpp>
#include <raft/core/mdspan.hpp>

namespace cuvs::neighbors::cagra::graph_build_params {
ivf_pq_params::ivf_pq_params(raft::matrix_extent<int64_t> dataset_extents,
cuvs::distance::DistanceType metric)
{
build_params = cuvs::neighbors::ivf_pq::index_params::from_dataset(dataset_extents, metric);

search_params = cuvs::neighbors::ivf_pq::search_params{};
search_params.n_probes = std::max<uint32_t>(10, build_params.n_lists * 0.01);
search_params.lut_dtype = CUDA_R_16F;
search_params.internal_distance_dtype = CUDA_R_16F;

refinement_rate = 2;
}
} // namespace cuvs::neighbors::cagra::graph_build_params
66 changes: 63 additions & 3 deletions cpp/src/neighbors/ivf_flat_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/core/error.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resources.hpp>
#include <raft/core/serialize.hpp>

#include <cuvs/core/c_api.h>
#include <cuvs/core/exceptions.hpp>
Expand Down Expand Up @@ -83,6 +84,22 @@ void _search(cuvsResources_t res,
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
}

template <typename T, typename IdxT>
void _serialize(cuvsResources_t res, const char* filename, cuvsIvfFlatIndex index)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_flat::index<T, IdxT>*>(index.addr);
cuvs::neighbors::ivf_flat::serialize(*res_ptr, std::string(filename), *index_ptr);
}

template <typename T, typename IdxT>
void* _deserialize(cuvsResources_t res, const char* filename)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index = new cuvs::neighbors::ivf_flat::index<T, IdxT>(*res_ptr);
cuvs::neighbors::ivf_flat::deserialize(*res_ptr, std::string(filename), index);
return index;
}
} // namespace

extern "C" cuvsError_t cuvsIvfFlatIndexCreate(cuvsIvfFlatIndex_t* index)
Expand Down Expand Up @@ -120,18 +137,16 @@ extern "C" cuvsError_t cuvsIvfFlatBuild(cuvsResources_t res,
return cuvs::core::translate_exceptions([=] {
auto dataset = dataset_tensor->dl_tensor;

index->dtype = dataset.dtype;
if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
index->addr =
reinterpret_cast<uintptr_t>(_build<float, int64_t>(res, *params, dataset_tensor));
index->dtype.code = kDLFloat;
} else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) {
index->addr =
reinterpret_cast<uintptr_t>(_build<int8_t, int64_t>(res, *params, dataset_tensor));
index->dtype.code = kDLInt;
} else if (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8) {
index->addr =
reinterpret_cast<uintptr_t>(_build<uint8_t, int64_t>(res, *params, dataset_tensor));
index->dtype.code = kDLUInt;
} else {
RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d",
dataset.dtype.code,
Expand Down Expand Up @@ -213,3 +228,48 @@ extern "C" cuvsError_t cuvsIvfFlatSearchParamsDestroy(cuvsIvfFlatSearchParams_t
{
return cuvs::core::translate_exceptions([=] { delete params; });
}

extern "C" cuvsError_t cuvsIvfFlatDeserialize(cuvsResources_t res,
const char* filename,
cuvsIvfFlatIndex_t index)
{
return cuvs::core::translate_exceptions([=] {
// read the numpy dtype from the beginning of the file
std::ifstream is(filename, std::ios::in | std::ios::binary);
if (!is) { RAFT_FAIL("Cannot open file %s", filename); }
char dtype_string[4];
is.read(dtype_string, 4);
auto dtype = raft::detail::numpy_serializer::parse_descr(std::string(dtype_string, 4));

index->dtype.bits = dtype.itemsize * 8;
if (dtype.kind == 'f' && dtype.itemsize == 4) {
index->addr = reinterpret_cast<uintptr_t>(_deserialize<float, int64_t>(res, filename));
index->dtype.code = kDLFloat;
} else if (dtype.kind == 'i' && dtype.itemsize == 1) {
index->addr = reinterpret_cast<uintptr_t>(_deserialize<int8_t, int64_t>(res, filename));
index->dtype.code = kDLInt;
} else if (dtype.kind == 'u' && dtype.itemsize == 1) {
index->addr = reinterpret_cast<uintptr_t>(_deserialize<uint8_t, int64_t>(res, filename));
index->dtype.code = kDLUInt;
} else {
RAFT_FAIL("Unsupported dtype in file %s", filename);
benfred marked this conversation as resolved.
Show resolved Hide resolved
}
});
}

extern "C" cuvsError_t cuvsIvfFlatSerialize(cuvsResources_t res,
const char* filename,
cuvsIvfFlatIndex_t index)
{
return cuvs::core::translate_exceptions([=] {
if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
_serialize<float, int64_t>(res, filename, *index);
} else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
_serialize<int8_t, int64_t>(res, filename, *index);
} else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) {
_serialize<uint8_t, int64_t>(res, filename, *index);
} else {
RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits);
}
});
}
6 changes: 6 additions & 0 deletions cpp/src/neighbors/ivf_flat_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@

namespace cuvs::neighbors::ivf_flat {

template <typename T, typename IdxT>
index<T, IdxT>::index(raft::resources const& res)
: index(res, cuvs::distance::DistanceType::L2Expanded, 0, false, false, 0)
{
}

template <typename T, typename IdxT>
index<T, IdxT>::index(raft::resources const& res, const index_params& params, uint32_t dim)
: index(res,
Expand Down
Loading
Loading