Skip to content

Commit

Permalink
Add python serialization API's for ivf-pq and ivf_flat (rapidsai#186)
Browse files Browse the repository at this point in the history
Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: rapidsai/cuvs#186
  • Loading branch information
benfred authored Jun 14, 2024
1 parent 615749e commit 9dc3a4d
Show file tree
Hide file tree
Showing 20 changed files with 538 additions and 142 deletions.
44 changes: 44 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_flat.h
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,50 @@ cuvsError_t cuvsIvfFlatSearch(cuvsResources_t res,
* @}
*/

/**
* @defgroup ivf_flat_c_serialize IVF-Flat C-API serialize functions
* @{
*/
/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <cuvs/neighbors/ivf_flat.h>
*
* // Create cuvsResources_t
* cuvsResources_t res;
* cuvsError_t res_create_status = cuvsResourcesCreate(&res);
*
* // create an index with `cuvsIvfFlatBuild`
* cuvsIvfFlatSerialize(res, "/path/to/index", index, true);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the file name for saving the index
* @param[in] index IVF-Flat index
*/
cuvsError_t cuvsIvfFlatSerialize(cuvsResources_t res,
const char* filename,
cuvsIvfFlatIndex_t index);

/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the name of the file that stores the index
* @param[out] index IVF-Flat index loaded disk
*/
cuvsError_t cuvsIvfFlatDeserialize(cuvsResources_t res,
const char* filename,
cuvsIvfFlatIndex_t index);
/**
* @}
*/

#ifdef __cplusplus
}
#endif
9 changes: 9 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_flat.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ struct index : cuvs::neighbors::index {
index& operator=(const index&) = delete;
index& operator=(index&&) = default;
~index() = default;

/**
* @brief Construct an empty index.
*
* Constructs an empty index. This index will either need to be trained with `build`
* or loaded from a saved copy with `deserialize`
*/
index(raft::resources const& res);

/** Construct an empty index. It needs to be trained and then populated. */
index(raft::resources const& res, const index_params& params, uint32_t dim);
/** Construct an empty index. It needs to be trained and then populated. */
Expand Down
40 changes: 40 additions & 0 deletions cpp/include/cuvs/neighbors/ivf_pq.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,46 @@ cuvsError_t cuvsIvfPqSearch(cuvsResources_t res,
* @}
*/

/**
* @defgroup ivf_pq_c_serialize IVF-PQ C-API serialize functions
* @{
*/
/**
* Save the index to file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @code{.cpp}
* #include <cuvs/neighbors/ivf_pq.h>
*
* // Create cuvsResources_t
* cuvsResources_t res;
* cuvsError_t res_create_status = cuvsResourcesCreate(&res);
*
* // create an index with `cuvsIvfPqBuild`
* cuvsIvfPqSerialize(res, "/path/to/index", index, true);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the file name for saving the index
* @param[in] index IVF-PQ index
*/
cuvsError_t cuvsIvfPqSerialize(cuvsResources_t res, const char* filename, cuvsIvfPqIndex_t index);

/**
* Load index from file.
*
* Experimental, both the API and the serialization format are subject to change.
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] filename the name of the file that stores the index
* @param[out] index IVF-PQ index loaded disk
*/
cuvsError_t cuvsIvfPqDeserialize(cuvsResources_t res, const char* filename, cuvsIvfPqIndex_t index);
/**
* @}
*/

#ifdef __cplusplus
}
#endif
14 changes: 11 additions & 3 deletions cpp/include/cuvs/neighbors/ivf_pq.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,14 @@ struct index : cuvs::neighbors::index {
auto operator=(index&&) -> index& = default;
~index() = default;

/**
* @brief Construct an empty index.
*
* Constructs an empty index. This index will either need to be trained with `build`
* or loaded from a saved copy with `deserialize`
*/
index(raft::resources const& handle);

/** Construct an empty index. It needs to be trained and then populated. */
index(raft::resources const& handle,
cuvs::distance::DistanceType metric,
Expand Down Expand Up @@ -1366,7 +1374,7 @@ void serialize(raft::resources const& handle,
*
* using IdxT = int64_t; // type of the index
* // create an empty index
* cuvs::neighbors::ivf_pq::index<IdxT> index(handl, index_params, dim);
* cuvs::neighbors::ivf_pq::index<IdxT> index(handle);
*
* cuvs::neighbors::ivf_pq::deserialize(handle, is, index);
* @endcode
Expand All @@ -1390,8 +1398,8 @@ void deserialize(raft::resources const& handle,
* // create a string with a filepath
* std::string filename("/path/to/index");
* using IdxT = int64_t; // type of the index
* // create an empty index with
* ivf_pq::index<IdxT> index(handle, index_params, dim);
* // create an empty index
* ivf_pq::index<IdxT> index(handle);
*
* cuvs::neighbors::ivf_pq::deserialize(handle, filename, &index);
* @endcode
Expand Down
68 changes: 34 additions & 34 deletions cpp/src/neighbors/detail/cagra/cagra_build.cpp
Original file line number Diff line number Diff line change
@@ -1,35 +1,35 @@
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/neighbors/ivf_pq.hpp>
#include <raft/core/mdspan.hpp>

namespace cuvs::neighbors::cagra::graph_build_params {
ivf_pq_params::ivf_pq_params(raft::matrix_extent<int64_t> dataset_extents,
cuvs::distance::DistanceType metric)
{
build_params = cuvs::neighbors::ivf_pq::index_params::from_dataset(dataset_extents, metric);

search_params = cuvs::neighbors::ivf_pq::search_params{};
search_params.n_probes = std::max<uint32_t>(10, build_params.n_lists * 0.01);
search_params.lut_dtype = CUDA_R_16F;
search_params.internal_distance_dtype = CUDA_R_16F;

refinement_rate = 2;
}
/*
* Copyright (c) 2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cuvs/distance/distance.hpp>
#include <cuvs/neighbors/cagra.hpp>
#include <cuvs/neighbors/ivf_pq.hpp>
#include <raft/core/mdspan.hpp>

namespace cuvs::neighbors::cagra::graph_build_params {
ivf_pq_params::ivf_pq_params(raft::matrix_extent<int64_t> dataset_extents,
cuvs::distance::DistanceType metric)
{
build_params = cuvs::neighbors::ivf_pq::index_params::from_dataset(dataset_extents, metric);

search_params = cuvs::neighbors::ivf_pq::search_params{};
search_params.n_probes = std::max<uint32_t>(10, build_params.n_lists * 0.01);
search_params.lut_dtype = CUDA_R_16F;
search_params.internal_distance_dtype = CUDA_R_16F;

refinement_rate = 2;
}
} // namespace cuvs::neighbors::cagra::graph_build_params
67 changes: 64 additions & 3 deletions cpp/src/neighbors/ivf_flat_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/core/error.hpp>
#include <raft/core/mdspan_types.hpp>
#include <raft/core/resources.hpp>
#include <raft/core/serialize.hpp>

#include <cuvs/core/c_api.h>
#include <cuvs/core/exceptions.hpp>
Expand Down Expand Up @@ -83,6 +84,22 @@ void _search(cuvsResources_t res,
*res_ptr, search_params, *index_ptr, queries_mds, neighbors_mds, distances_mds);
}

template <typename T, typename IdxT>
void _serialize(cuvsResources_t res, const char* filename, cuvsIvfFlatIndex index)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index_ptr = reinterpret_cast<cuvs::neighbors::ivf_flat::index<T, IdxT>*>(index.addr);
cuvs::neighbors::ivf_flat::serialize(*res_ptr, std::string(filename), *index_ptr);
}

template <typename T, typename IdxT>
void* _deserialize(cuvsResources_t res, const char* filename)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
auto index = new cuvs::neighbors::ivf_flat::index<T, IdxT>(*res_ptr);
cuvs::neighbors::ivf_flat::deserialize(*res_ptr, std::string(filename), index);
return index;
}
} // namespace

extern "C" cuvsError_t cuvsIvfFlatIndexCreate(cuvsIvfFlatIndex_t* index)
Expand Down Expand Up @@ -120,18 +137,16 @@ extern "C" cuvsError_t cuvsIvfFlatBuild(cuvsResources_t res,
return cuvs::core::translate_exceptions([=] {
auto dataset = dataset_tensor->dl_tensor;

index->dtype = dataset.dtype;
if (dataset.dtype.code == kDLFloat && dataset.dtype.bits == 32) {
index->addr =
reinterpret_cast<uintptr_t>(_build<float, int64_t>(res, *params, dataset_tensor));
index->dtype.code = kDLFloat;
} else if (dataset.dtype.code == kDLInt && dataset.dtype.bits == 8) {
index->addr =
reinterpret_cast<uintptr_t>(_build<int8_t, int64_t>(res, *params, dataset_tensor));
index->dtype.code = kDLInt;
} else if (dataset.dtype.code == kDLUInt && dataset.dtype.bits == 8) {
index->addr =
reinterpret_cast<uintptr_t>(_build<uint8_t, int64_t>(res, *params, dataset_tensor));
index->dtype.code = kDLUInt;
} else {
RAFT_FAIL("Unsupported dataset DLtensor dtype: %d and bits: %d",
dataset.dtype.code,
Expand Down Expand Up @@ -213,3 +228,49 @@ extern "C" cuvsError_t cuvsIvfFlatSearchParamsDestroy(cuvsIvfFlatSearchParams_t
{
return cuvs::core::translate_exceptions([=] { delete params; });
}

extern "C" cuvsError_t cuvsIvfFlatDeserialize(cuvsResources_t res,
const char* filename,
cuvsIvfFlatIndex_t index)
{
return cuvs::core::translate_exceptions([=] {
// read the numpy dtype from the beginning of the file
std::ifstream is(filename, std::ios::in | std::ios::binary);
if (!is) { RAFT_FAIL("Cannot open file %s", filename); }
char dtype_string[4];
is.read(dtype_string, 4);
auto dtype = raft::detail::numpy_serializer::parse_descr(std::string(dtype_string, 4));

index->dtype.bits = dtype.itemsize * 8;
if (dtype.kind == 'f' && dtype.itemsize == 4) {
index->addr = reinterpret_cast<uintptr_t>(_deserialize<float, int64_t>(res, filename));
index->dtype.code = kDLFloat;
} else if (dtype.kind == 'i' && dtype.itemsize == 1) {
index->addr = reinterpret_cast<uintptr_t>(_deserialize<int8_t, int64_t>(res, filename));
index->dtype.code = kDLInt;
} else if (dtype.kind == 'u' && dtype.itemsize == 1) {
index->addr = reinterpret_cast<uintptr_t>(_deserialize<uint8_t, int64_t>(res, filename));
index->dtype.code = kDLUInt;
} else {
RAFT_FAIL(
"Unsupported dtype in file %s itemsize %i kind %i", filename, dtype.itemsize, dtype.kind);
}
});
}

extern "C" cuvsError_t cuvsIvfFlatSerialize(cuvsResources_t res,
const char* filename,
cuvsIvfFlatIndex_t index)
{
return cuvs::core::translate_exceptions([=] {
if (index->dtype.code == kDLFloat && index->dtype.bits == 32) {
_serialize<float, int64_t>(res, filename, *index);
} else if (index->dtype.code == kDLInt && index->dtype.bits == 8) {
_serialize<int8_t, int64_t>(res, filename, *index);
} else if (index->dtype.code == kDLUInt && index->dtype.bits == 8) {
_serialize<uint8_t, int64_t>(res, filename, *index);
} else {
RAFT_FAIL("Unsupported index dtype: %d and bits: %d", index->dtype.code, index->dtype.bits);
}
});
}
6 changes: 6 additions & 0 deletions cpp/src/neighbors/ivf_flat_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@

namespace cuvs::neighbors::ivf_flat {

template <typename T, typename IdxT>
index<T, IdxT>::index(raft::resources const& res)
: index(res, cuvs::distance::DistanceType::L2Expanded, 0, false, false, 0)
{
}

template <typename T, typename IdxT>
index<T, IdxT>::index(raft::resources const& res, const index_params& params, uint32_t dim)
: index(res,
Expand Down
Loading

0 comments on commit 9dc3a4d

Please sign in to comment.