IVF-PQ: manipulating individual lists #1298

Merged: 48 commits, Apr 17, 2023

Changes shown below are from 9 of the 48 commits.

Commits
d421375 Add a method to reconstruct the compressed index data (achirkin, Feb 23, 2023)
c7b5574 Fix a typo in the docs (achirkin, Feb 23, 2023)
087919f Relax the constraints a bit (achirkin, Feb 23, 2023)
b5e9844 Add the public interface (achirkin, Feb 23, 2023)
ea81c1e Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (achirkin, Feb 27, 2023)
6567f5c Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (achirkin, Feb 28, 2023)
3ccd48f Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (achirkin, Mar 6, 2023)
fdb8a6a Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (achirkin, Mar 7, 2023)
56ffc84 Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (achirkin, Mar 8, 2023)
ebd1b1c Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (achirkin, Mar 15, 2023)
2773a86 Fix the merge errors (achirkin, Mar 15, 2023)
82be40c Add an option to reconstruct cluster data by in-cluster indices (achirkin, Mar 15, 2023)
5fc1c53 Update the docs for the new function (achirkin, Mar 15, 2023)
ddc4466 Detach reconstruction logic from the data reading logic (achirkin, Mar 15, 2023)
56eb716 Merge remote-tracking branch 'rapidsai/branch-23.04' into fea-ivf-pq-… (achirkin, Mar 16, 2023)
3a2c622 Implement unpack_list_data (achirkin, Mar 16, 2023)
e7c35e7 Public interface for unpack_list_data (achirkin, Mar 16, 2023)
839cd48 Fix (unrelated to the PR) test condition being just a little bit too … (achirkin, Mar 16, 2023)
46d5a84 Fix a miswording in the docs (achirkin, Mar 16, 2023)
9fc9818 Add public api for extending individual lists (achirkin, Mar 16, 2023)
2f1f6f7 Merge remote-tracking branch 'rapidsai/branch-23.04' into fea-ivf-pq-… (achirkin, Mar 16, 2023)
ac8aa18 Implemented pack_list_data (achirkin, Mar 16, 2023)
d5b2d1b Merge remote-tracking branch 'rapidsai/branch-23.04' into fea-ivf-pq-… (achirkin, Mar 16, 2023)
f62273d Reuse write_vector inside the process_and_fill_codes_kernel (achirkin, Mar 16, 2023)
38895bb Initial implementation of extend_list (failing tests) (achirkin, Mar 17, 2023)
a674768 Adjust the scheduling of the encode_list_data_kernel (achirkin, Mar 17, 2023)
1be2f6b Factor code-packing out of the build file (achirkin, Mar 17, 2023)
d6cad17 Fix failing tests (achirkin, Mar 17, 2023)
2625666 Merge remote-tracking branch 'rapidsai/branch-23.04' into fea-ivf-pq-… (achirkin, Mar 17, 2023)
172673b Merge remote-tracking branch 'rapidsai/branch-23.04' into fea-ivf-pq-… (achirkin, Mar 17, 2023)
1504027 Relax the test criterion eps a little bit (achirkin, Mar 17, 2023)
1326ac7 Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (achirkin, Mar 19, 2023)
08cfacb Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (achirkin, Mar 20, 2023)
a399023 Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (achirkin, Mar 21, 2023)
d935568 Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (cjnolet, Mar 23, 2023)
d1dd238 Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (cjnolet, Mar 23, 2023)
5514d7d Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (achirkin, Mar 25, 2023)
a317622 Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (achirkin, Mar 25, 2023)
9367936 Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (achirkin, Mar 28, 2023)
09f1a0d Move ivf_pq helpers to separate file and namespace (tfeher, Mar 31, 2023)
b7e811a Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (tfeher, Mar 31, 2023)
926f510 Add public API for pack_list_data (tfeher, Mar 31, 2023)
8e311d9 Increase tolarance for vector reconstruction test (tfeher, Apr 2, 2023)
158189c Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct (tfeher, Apr 3, 2023)
280c040 Correct number of list elements to compare (tfeher, Apr 3, 2023)
a7e5419 Add ivf_pq::helpers::codepacker::pack / unpack (tfeher, Apr 5, 2023)
09abadb Merge branch 'branch-23.06' into fea-ivf-pq-reconstruct (cjnolet, Apr 12, 2023)
923c3b5 Merge branch 'branch-23.06' into fea-ivf-pq-reconstruct (cjnolet, Apr 14, 2023)
286 changes: 250 additions & 36 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
@@ -557,6 +557,178 @@ void train_per_cluster(raft::device_resources const& handle,
transpose_pq_centers(handle, index, pq_centers_tmp.data());
}

/**
* Decode a level-2 PQ-encoded vector in the given list (cluster).
* One vector per thread.
* NB: this function only decodes the PQ (second level) encoding; to get the approximation of the
* original vector, you need to add the cluster centroid and apply the inverse matrix transform to
* the result of this function.
*
* @tparam PqBits
*
* @param[out] out_vector the destination for the decoded vector (one-per-thread).
* @param[in] in_list_data the encoded cluster data.
* @param[in] pq_centers the codebook
* @param[in] codebook_kind
* @param[in] in_ix in-cluster index of the vector to be decoded (one-per-thread).
* @param[in] cluster_ix label/id of the cluster (one-per-thread).
*/
template <uint32_t PqBits>
__device__ void reconstruct_vector(
device_vector_view<float, uint32_t, row_major> out_vector,
device_mdspan<const uint8_t, list_spec<uint32_t>::list_extents, row_major> in_list_data,
device_mdspan<const float, extent_3d<uint32_t>, row_major> pq_centers,
codebook_gen codebook_kind,
uint32_t in_ix,
uint32_t cluster_ix)
{
using group_align = Pow2<kIndexGroupSize>;
const uint32_t group_ix = group_align::div(in_ix);
const uint32_t ingroup_ix = group_align::mod(in_ix);
const uint32_t pq_len = pq_centers.extent(1);
const uint32_t pq_dim = out_vector.extent(0) / pq_len;

using layout_t = typename decltype(out_vector)::layout_type;
using accessor_t = typename decltype(out_vector)::accessor_type;
auto reinterpreted_vector = mdspan<float, extent_2d<uint32_t>, layout_t, accessor_t>(
out_vector.data_handle(), extent_2d<uint32_t>{pq_dim, pq_len});

pq_vec_t code_chunk;
bitfield_view_t<PqBits> code_view{reinterpret_cast<uint8_t*>(&code_chunk)};
constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits;
for (uint32_t j = 0, i = 0; j < pq_dim; i++) {
// read the chunk
code_chunk = *reinterpret_cast<const pq_vec_t*>(&in_list_data(group_ix, i, ingroup_ix, 0));
// read the codes, one/pq_dim at a time
#pragma unroll
for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) {
uint32_t partition_ix;
switch (codebook_kind) {
case codebook_gen::PER_CLUSTER: {
partition_ix = cluster_ix;
} break;
case codebook_gen::PER_SUBSPACE: {
partition_ix = j;
} break;
default: __builtin_unreachable();
}
uint8_t code = code_view[k];
// read a piece of the reconstructed vector
for (uint32_t l = 0; l < pq_len; l++) {
reinterpreted_vector(j, l) = pq_centers(partition_ix, l, code);
}
}
}
}
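
For intuition, here is a minimal host-side sketch of the chunked bitfield decoding performed above. It assumes a 16-byte pq_vec_t, LSB-first bit packing as in bitfield_view_t, and that the chunks of a single vector are contiguous (the real list data interleaves kIndexGroupSize vectors); the helper name is illustrative, not part of the API.

// Illustrative sketch only (not part of this PR).
#include <cstdint>
#include <vector>

// Unpack pq_dim codes of pq_bits bits each from a packed code buffer.
// Codes are grouped into 16-byte chunks; (128 / pq_bits) codes fit in a
// chunk and the remaining bits are padding.
std::vector<uint8_t> unpack_codes_sketch(const uint8_t* data, uint32_t pq_dim, uint32_t pq_bits)
{
  constexpr uint32_t kChunkBytes = 16;  // assumed sizeof(pq_vec_t)
  const uint32_t chunk_size      = (kChunkBytes * 8u) / pq_bits;
  std::vector<uint8_t> codes(pq_dim);
  for (uint32_t j = 0; j < pq_dim; j++) {
    const uint32_t chunk = j / chunk_size;              // which chunk holds code j
    const uint32_t bit   = (j % chunk_size) * pq_bits;  // bit offset within the chunk
    const uint8_t* base  = data + chunk * kChunkBytes + bit / 8u;
    // A code may straddle a byte boundary; read a second byte only if needed.
    uint16_t window = base[0];
    if ((bit % 8u) + pq_bits > 8u) { window |= uint16_t(base[1]) << 8; }
    codes[j] = uint8_t((window >> (bit % 8u)) & ((1u << pq_bits) - 1u));
  }
  return codes;
}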

template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) __global__ void reconstruct_list_data_kernel(
device_matrix_view<float, uint32_t, row_major> out_vectors,
device_vector_view<const uint8_t* const, uint32_t, row_major> data_ptrs,
device_mdspan<const float, extent_3d<uint32_t>, row_major> pq_centers,
device_matrix_view<const float, uint32_t, row_major> centers_rot,
codebook_gen codebook_kind,
uint32_t cluster_ix,
uint32_t n_skip)
{
const auto out_dim = out_vectors.extent(1);
using layout_t = typename decltype(out_vectors)::layout_type;
using accessor_t = typename decltype(out_vectors)::accessor_type;

const uint32_t pq_dim = out_dim / pq_centers.extent(1);
auto pq_extents =
list_spec<uint32_t>{PqBits, pq_dim, true}.make_list_extents(out_vectors.extent(0) + n_skip + 1);
auto pq_dataset =
make_mdspan<const uint8_t, uint32_t, row_major, false, true>(data_ptrs[cluster_ix], pq_extents);

for (uint32_t ix = threadIdx.x + BlockSize * blockIdx.x; ix < out_vectors.extent(0);
ix += BlockSize) {
auto one_vector = mdspan<float, extent_1d<uint32_t>, layout_t, accessor_t>(
&out_vectors(ix, 0), extent_1d<uint32_t>{out_vectors.extent(1)});
reconstruct_vector<PqBits>(
one_vector, pq_dataset, pq_centers, codebook_kind, ix + n_skip, cluster_ix);
for (uint32_t j = 0; j < out_dim; j++) {
one_vector(j) += centers_rot(cluster_ix, j);
}
}
}

/** Decode the list data; see the public interface for the api and usage. */
template <typename T, typename IdxT>
void reconstruct_list_data(raft::device_resources const& res,
const index<IdxT>& index,
device_matrix_view<T, uint32_t, row_major> out_vectors,
uint32_t label,
uint32_t n_skip)
{
auto n_rows = out_vectors.extent(0);
if (n_rows == 0) { return; }
// sic! I'm using the upper bound `list.size` instead of exact `list_sizes(label)`
// to avoid an extra device-host data copy and the stream sync.
RAFT_EXPECTS(n_skip + n_rows <= index.lists()[label]->size.load(),
"n_skip + output size must not exceed the cluster size.");

auto tmp = make_device_mdarray<float>(
res, res.get_workspace_resource(), make_extents<uint32_t>(n_rows, index.rot_dim()));

constexpr uint32_t kBlockSize = 256;
dim3 blocks(div_rounding_up_safe<uint32_t>(n_rows, kBlockSize), 1, 1);
dim3 threads(kBlockSize, 1, 1);
auto kernel = [](uint32_t pq_bits) {
switch (pq_bits) {
case 4: return reconstruct_list_data_kernel<kBlockSize, 4>;
case 5: return reconstruct_list_data_kernel<kBlockSize, 5>;
case 6: return reconstruct_list_data_kernel<kBlockSize, 6>;
case 7: return reconstruct_list_data_kernel<kBlockSize, 7>;
case 8: return reconstruct_list_data_kernel<kBlockSize, 8>;
default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits);
}
}(index.pq_bits());
kernel<<<blocks, threads, 0, res.get_stream()>>>(tmp.view(),
index.data_ptrs(),
index.pq_centers(),
index.centers_rot(),
index.codebook_kind(),
label,
n_skip);
RAFT_CUDA_TRY(cudaPeekAtLastError());

float* out_float_ptr = nullptr;
rmm::device_uvector<float> out_float_buf(0, res.get_stream(), res.get_workspace_resource());
if constexpr (std::is_same_v<T, float>) {
out_float_ptr = out_vectors.data_handle();
} else {
out_float_buf.resize(size_t{n_rows} * size_t{index.dim()}, res.get_stream());
out_float_ptr = out_float_buf.data();
}
// Rotate the results back to the original space
float alpha = 1.0;
float beta = 0.0;
linalg::gemm(res,
false,
false,
index.dim(),
n_rows,
index.rot_dim(),
&alpha,
index.rotation_matrix().data_handle(),
index.dim(),
tmp.data_handle(),
index.rot_dim(),
&beta,
out_float_ptr,
index.dim(),
res.get_stream());
// Transform the data to the original type, if necessary
if constexpr (!std::is_same_v<T, float>) {
linalg::map_k(out_vectors.data_handle(),
out_float_buf.size(),
utils::mapping<T>{},
res.get_stream(),
out_float_ptr);
}
}
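
Tying the two stages above together: the kernel produces the decoded vector in the rotated space, and the gemm maps it back. A sketch of the math, assuming the rotation matrix R = index.rotation_matrix() is orthonormal (so the inverse applied by the gemm coincides with the transpose); P denotes pq_centers, c_j the j-th code of a record, and \ell the cluster label:

\hat{y}_{\,j \cdot pq\_len + k} \;=\; P\big(b(j),\, k,\, c_j\big) \;+\; \mathrm{centers\_rot}(\ell,\; j \cdot pq\_len + k),
\qquad
b(j) \;=\; \begin{cases} \ell, & \text{PER\_CLUSTER} \\ j, & \text{PER\_SUBSPACE} \end{cases}

\hat{x} \;=\; R^{-1}\,\hat{y} \;=\; R^{\top}\,\hat{y}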

/**
* Compute the code: find the closest cluster in each pq_dim-subspace.
*
@@ -625,6 +797,67 @@ __device__ auto compute_pq_code(
return code;
}

/**
* Compute a PQ code for a single input vector per subwarp and write it into the
* appropriate cluster.
* Subwarp size here is the minimum between WarpSize and the codebook size.
*
* @tparam BlockSize
* @tparam PqBits
*
* @param[out] out_list_data the packed (interleaved) data of the target cluster (list).
* @param[in] in_vector input unencoded data, one-per-subwarp
* @param[in] pq_centers codebook
* @param[in] codebook_kind
* @param[in] out_ix in-cluster output index (where to write the encoded data), one-per-subwarp.
* @param[in] cluster_ix label/id of the cluster to fill, one-per-subwarp.
*/
template <uint32_t BlockSize, uint32_t PqBits>
__device__ auto compute_and_write_pq_code(
device_mdspan<uint8_t, list_spec<uint32_t>::list_extents, row_major> out_list_data,
device_vector_view<const float, uint32_t, row_major> in_vector,
device_mdspan<const float, extent_3d<uint32_t>, row_major> pq_centers,
codebook_gen codebook_kind,
uint32_t out_ix,
uint32_t cluster_ix)
{
constexpr uint32_t kSubWarpSize = std::min<uint32_t>(WarpSize, 1u << PqBits);
using subwarp_align = Pow2<kSubWarpSize>;
const uint32_t lane_id = subwarp_align::mod(threadIdx.x);

using group_align = Pow2<kIndexGroupSize>;
const uint32_t group_ix = group_align::div(out_ix);
const uint32_t ingroup_ix = group_align::mod(out_ix);
const uint32_t pq_len = pq_centers.extent(1);
const uint32_t pq_dim = in_vector.extent(0) / pq_len;

using layout_t = typename decltype(in_vector)::layout_type;
using accessor_t = typename decltype(in_vector)::accessor_type;
auto reinterpreted_vector = mdspan<const float, extent_2d<uint32_t>, layout_t, accessor_t>(
in_vector.data_handle(), extent_2d<uint32_t>{pq_dim, pq_len});

__shared__ pq_vec_t codes[subwarp_align::div(BlockSize)];
pq_vec_t& code = codes[subwarp_align::div(threadIdx.x)];
bitfield_view_t<PqBits> out{reinterpret_cast<uint8_t*>(&code)};
constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits;
for (uint32_t j = 0, i = 0; j < pq_dim; i++) {
// clear the chunk for writing
if (lane_id == 0) { code = pq_vec_t{}; }
// fill-in the values, one/pq_dim at a time
#pragma unroll
for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) {
// find the label
auto l = compute_pq_code<kSubWarpSize>(
pq_centers, reinterpreted_vector, codebook_kind, j, cluster_ix);
if (lane_id == 0) { out[k] = l; }
}
// write the chunk into the dataset
if (lane_id == 0) {
*reinterpret_cast<pq_vec_t*>(&out_list_data(group_ix, i, ingroup_ix, 0)) = code;
}
}
}
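
To make the chunk arithmetic above concrete, a small standalone sketch (again assuming a 16-byte pq_vec_t; pq_dim = 96 is an arbitrary example):

// Illustrative sketch only (not part of this PR).
#include <cstdint>
#include <cstdio>

int main()
{
  constexpr uint32_t kChunkBytes = 16;  // assumed sizeof(pq_vec_t)
  const uint32_t pq_dim = 96;           // example number of subspaces
  for (uint32_t pq_bits = 4; pq_bits <= 8; pq_bits++) {
    // Codes per chunk; the tail bits of the last chunk are padding.
    const uint32_t chunk_size = (kChunkBytes * 8u) / pq_bits;
    const uint32_t n_chunks   = (pq_dim + chunk_size - 1) / chunk_size;
    std::printf("pq_bits=%u: %2u codes/chunk, %u chunks (%3u bytes) per vector\n",
                pq_bits, chunk_size, n_chunks, n_chunks * kChunkBytes);
  }
  return 0;
}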

template <uint32_t BlockSize, uint32_t PqBits, typename IdxT>
__launch_bounds__(BlockSize) __global__ void process_and_fill_codes_kernel(
device_matrix_view<const float, IdxT, row_major> new_vectors,
@@ -639,15 +872,15 @@ __launch_bounds__(BlockSize) __global__ void process_and_fill_codes_kernel(
constexpr uint32_t kSubWarpSize = std::min<uint32_t>(WarpSize, 1u << PqBits);
using subwarp_align = Pow2<kSubWarpSize>;
const uint32_t lane_id = subwarp_align::mod(threadIdx.x);
- const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{blockDim.x} * IdxT{blockIdx.x});
+ const IdxT row_ix = subwarp_align::div(IdxT{threadIdx.x} + IdxT{BlockSize} * IdxT{blockIdx.x});
if (row_ix >= new_vectors.extent(0)) { return; }

const uint32_t cluster_ix = new_labels[row_ix];
uint32_t out_ix;
if (lane_id == 0) { out_ix = atomicAdd(&list_sizes(cluster_ix), 1); }
out_ix = shfl(out_ix, 0, kSubWarpSize);

- // write the label
+ // write the label (one record per subwarp)
auto pq_indices = inds_ptrs(cluster_ix);
if (lane_id == 0) {
if (std::holds_alternative<IdxT>(src_offset_or_indices)) {
@@ -657,40 +890,21 @@ __launch_bounds__(BlockSize) __global__ void process_and_fill_codes_kernel(
}
}

- // write the codes
- using group_align = Pow2<kIndexGroupSize>;
- const uint32_t group_ix = group_align::div(out_ix);
- const uint32_t ingroup_ix = group_align::mod(out_ix);
- const uint32_t pq_len = pq_centers.extent(1);
- const uint32_t pq_dim = new_vectors.extent(1) / pq_len;
-
- auto pq_extents = list_spec<uint32_t>{PqBits, pq_dim, true}.make_list_extents(out_ix + 1);
- auto pq_extents_vectorized =
- make_extents<uint32_t>(pq_extents.extent(0), pq_extents.extent(1), pq_extents.extent(2));
- auto pq_dataset = make_mdspan<pq_vec_t, uint32_t, row_major, false, true>(
- reinterpret_cast<pq_vec_t*>(data_ptrs[cluster_ix]), pq_extents_vectorized);
-
- __shared__ pq_vec_t codes[subwarp_align::div(BlockSize)];
- pq_vec_t& code = codes[subwarp_align::div(threadIdx.x)];
- bitfield_view_t<PqBits> out{reinterpret_cast<uint8_t*>(&code)};
- constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits;
- for (uint32_t j = 0, i = 0; j < pq_dim; i++) {
- // clear the chunk for writing
- if (lane_id == 0) { code = pq_vec_t{}; }
- // fill-in the values, one/pq_dim at a time
- #pragma unroll
- for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) {
- // find the label
- using layout_t = typename decltype(new_vectors)::layout_type;
- using accessor_t = typename decltype(new_vectors)::accessor_type;
- auto one_vector = mdspan<const float, extent_2d<uint32_t>, layout_t, accessor_t>(
- &new_vectors(row_ix, 0), extent_2d<uint32_t>{pq_dim, pq_len});
- auto l = compute_pq_code<kSubWarpSize>(pq_centers, one_vector, codebook_kind, j, cluster_ix);
- if (lane_id == 0) { out[k] = l; }
- }
- // write the chunk into the dataset
- if (lane_id == 0) { pq_dataset(group_ix, i, ingroup_ix) = code; }
- }
+ // write the codes (one record per subwarp):
+ // 1. select input row
+ using layout_t = typename decltype(new_vectors)::layout_type;
+ using accessor_t = typename decltype(new_vectors)::accessor_type;
+ const auto in_dim = new_vectors.extent(1);
+ auto one_vector =
+ mdspan<const float, extent_1d<uint32_t>, layout_t, accessor_t>(&new_vectors(row_ix, 0), in_dim);
+ // 2. select output cluster
+ const uint32_t pq_dim = in_dim / pq_centers.extent(1);
+ auto pq_extents = list_spec<uint32_t>{PqBits, pq_dim, true}.make_list_extents(out_ix + 1);
+ auto pq_dataset =
+ make_mdspan<uint8_t, uint32_t, row_major, false, true>(data_ptrs[cluster_ix], pq_extents);
+ // 3. compute and write the vector
+ compute_and_write_pq_code<BlockSize, PqBits>(
+ pq_dataset, one_vector, pq_centers, codebook_kind, out_ix, cluster_ix);
}

42 changes: 42 additions & 0 deletions cpp/include/raft/neighbors/ivf_pq.cuh
@@ -75,6 +75,48 @@ auto build(raft::device_resources const& handle,
return detail::build(handle, params, dataset, n_rows, dim);
}

/**
* @brief Decode `n_take` consecutive records of a single list (cluster) in the compressed index
* starting at the given offset `n_skip`.
*
* Usage example:
* @code{.cpp}
* // We will reconstruct the fourth cluster
* uint32_t label = 3;
* // Get the list size
* uint32_t list_size = 0;
* raft::copy(&list_size, index.list_sizes().data_handle() + label, 1, res.get_stream());
* res.sync_stream();
* // allocate the buffer for the output
* auto decoded_vectors = raft::make_device_matrix<float>(res, list_size, index.dim());
* // decode the whole list
* ivf_pq::reconstruct_list_data(res, index, decoded_vectors.view(), label, 0);
* @endcode
*
* @tparam T data element type
* @tparam IdxT type of the indices in the source dataset
*
* @param[in] res
* @param[in] index
* @param[out] out_vectors
* the destination buffer [n_take, index.dim()].
* The length `n_take` defines how many records to reconstruct;
* `n_skip + n_take` must not exceed the list size.
* @param[in] label
* The id of the list (cluster) to decode.
* @param[in] n_skip
* How many records in the list to skip.
*/
template <typename T, typename IdxT>
void reconstruct_list_data(raft::device_resources const& res,
const index<IdxT>& index,
device_matrix_view<T, uint32_t, row_major> out_vectors,
uint32_t label,
uint32_t n_skip)
achirkin (Contributor, author):

@cjnolet, could you please check if this interface is suitable for integration with FAISS?

cjnolet (Member):
A couple of important points that the FAISS developers have pointed out:

  1. We need to be sure that we don't assume indexes are drawn from a strictly monotonically increasing set.
  2. Following from 1, what's been asked for is a reconstruct_batch API which accepts a collection of ids (which are not necessarily contiguous) and reconstructs those vectors.

From what I can tell, the API you are proposing here doesn't implement the above, so I would say it's not suitable for the FAISS integration.
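
(Purely as a hypothetical sketch of the shape of that ask, not an agreed-upon signature; all names below are illustrative:)

// Hypothetical API shape for the reconstruct_batch ask described above.
template <typename T, typename IdxT>
void reconstruct_batch(raft::device_resources const& res,
                       const index<IdxT>& index,
                       device_vector_view<const IdxT, uint32_t> ids,  // arbitrary, possibly unordered
                       device_matrix_view<T, uint32_t, row_major> out_vectors);  // [ids.extent(0), index.dim()]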

achirkin (Contributor, author):
A quick clarification: are the ids we should pass (A) the same indices we pass during construction, or (B) internal ids of the ivf-pq index?
If (B), should these be a pair (label, collection of in-cluster-ids), or a collection of (label, in-cluster-id) pairs? The former, perhaps, can be slightly faster.
If (A), I can hardly imagine an efficient function that can do this: we don't sort or organize the input indices in any way, so a full search in the database needs to be done for every index. Perhaps we can split the task in two: search indices -> then call interface (B)?

achirkin (Contributor, author):
Also don't you need something like this interface to be able to convert our GPU index to their CPU index type (more-or-less efficiently, converting one cluster at a time)?

cjnolet (Member):
> Also don't you need something like this interface to be able to convert our GPU index to their CPU index type (more-or-less efficiently, converting one cluster at a time)?

There are a couple of different asks here. We have an ask to be able to convert from one interleaved (and potentially non-interleaved) encoded format to another, which is required for being able to convert between CPU and GPU indices. That ask is for us to have functions available that can be used by their CodePacker API.

In addition to this, there's the ask to be able to reconstruct a batch of vectors given a set of indices, which might not be in order. This is the corresponding issue for that ask (which is directly from Jeff Johnson). The indices passed into the reconstruct_batch would also coincide with the indices which would be specified when invoking extend here.

achirkin (Contributor, author):
> That says to the world "We will not change these APIs, nor the way they function, without deprecation first".

Oh, now I see what you mean. I agree, we should export detail::pack_list_data and detail::unpack_list_data into a public namespace, to make sure we don't change their signature and break faiss integration by accident. I'm also not opposed to the idea of moving the construct/reconstruct and other similar functions to the ivf_pq::helpers public namespace.
The only thing that bothers me is that the overloads we need for the codepacker really go against the current design of the public API. Maybe, as an exception, we export these two into something like ivf_pq::codepacker to discourage other users from using them? Or even implement the whole codepacker API in raft in that namespace?

cjnolet (Member):
> to discourage other users from using them?

I really don't think this is a good way to approach any public API. Either the API is public and users can use it, or it's not and they should touch that code at their own risk. Trying to force this middle ground for some users and not others is not sustainable and I'm not sure why there's a problem just exposing it outright.

I also don't agree that these functions should not just be part of the public API: if users need the ability to call a function, it needs to be wrapped into the public API so it can be clearly documented, not just exported.

We had tried the exporting initially when we first started exposing formal public APIs and that was a mistake. Anything the user should be interacting with directly should be available and clearly documented (which for almost all of the functions in RAFT just means wrapped).

I'm not opposed to using a nested "CodePacker" namespace, but I'd prefer that it is also nested inside helpers, so we stay consistent there w.r.t. the standard API namespaces and the "extra" functionality.

tfeher (Contributor):
I have added a public API for pack_list_data.

achirkin (Contributor, author):

Note: for the codepacker, we'd probably need to expose this overload:

inline void unpack_list_data(
device_matrix_view<uint8_t, uint32_t, row_major> codes,
device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
std::variant<uint32_t, const uint32_t*> offset_or_indices,
uint32_t pq_bits,
rmm::cuda_stream_view stream)

tfeher (Contributor), Apr 5, 2023:
I had missed this message before. I have exposed the pack/unpack versions that take list_data directly under the ivf_pq::helpers::codepacker namespace.

In our implementation we still need to provide the pq_bits parameter. Otherwise it is very similar to the CodePacker API.

One question: what does block refer to in CodePacker? In our API it refers to a set of interleaved vectors, which can be the data for a whole cluster, or less.
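
For illustration, a hypothetical call of the unpack_list_data overload quoted above (obtaining list_data, the interleaved storage of one list, is elided here; all sizes are arbitrary placeholders):

// Illustrative sketch only; shapes and values are placeholders.
uint32_t pq_bits = 8, pq_dim = 64, n_take = 100;
// One row of flat (unpacked) PQ codes per record:
auto codes = raft::make_device_matrix<uint8_t, uint32_t>(res, n_take, pq_dim);
// Unpack n_take records starting at in-list offset 0:
unpack_list_data(codes.view(), list_data, uint32_t{0}, pq_bits, res.get_stream());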

{
return detail::reconstruct_list_data(res, index, out_vectors, label, n_skip);
}

/**
* @brief Build a new index containing the data of the original plus new extra vectors.
*