Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

IVF-PQ: manipulating individual lists #1298

Merged
merged 48 commits into from
Apr 17, 2023
Merged
Changes from 1 commit
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
d421375
Add a method to reconstruct the compressed index data
achirkin Feb 23, 2023
c7b5574
Fix a typo in the docs
achirkin Feb 23, 2023
087919f
Relax the constraints a bit
achirkin Feb 23, 2023
b5e9844
Add the public interface
achirkin Feb 23, 2023
ea81c1e
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
achirkin Feb 27, 2023
6567f5c
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
achirkin Feb 28, 2023
3ccd48f
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
achirkin Mar 6, 2023
fdb8a6a
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
achirkin Mar 7, 2023
56ffc84
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
achirkin Mar 8, 2023
ebd1b1c
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
achirkin Mar 15, 2023
2773a86
Fix the merge errors
achirkin Mar 15, 2023
82be40c
Add an option to reconstruct cluster data by in-cluster indices
achirkin Mar 15, 2023
5fc1c53
Update the docs for the new function
achirkin Mar 15, 2023
ddc4466
Detach reconstruction logic from the data reading logic
achirkin Mar 15, 2023
56eb716
Merge remote-tracking branch 'rapidsai/branch-23.04' into fea-ivf-pq-…
achirkin Mar 16, 2023
3a2c622
Implement unpack_list_data
achirkin Mar 16, 2023
e7c35e7
Public interface for unpack_list_data
achirkin Mar 16, 2023
839cd48
Fix (unrelated to the PR) test condition being just a little bit too …
achirkin Mar 16, 2023
46d5a84
Fix a miswording in the docs
achirkin Mar 16, 2023
9fc9818
Add public api for extending individual lists
achirkin Mar 16, 2023
2f1f6f7
Merge remote-tracking branch 'rapidsai/branch-23.04' into fea-ivf-pq-…
achirkin Mar 16, 2023
ac8aa18
Implemented pack_list_data
achirkin Mar 16, 2023
d5b2d1b
Merge remote-tracking branch 'rapidsai/branch-23.04' into fea-ivf-pq-…
achirkin Mar 16, 2023
f62273d
Reuse write_vector inside the process_and_fill_codes_kernel
achirkin Mar 16, 2023
38895bb
Initial implementation of extend_list (failing tests)
achirkin Mar 17, 2023
a674768
Adjust the scheduling of the encode_list_data_kernel
achirkin Mar 17, 2023
1be2f6b
Factor code-packing out of the build file
achirkin Mar 17, 2023
d6cad17
Fix failing tests
achirkin Mar 17, 2023
2625666
Merge remote-tracking branch 'rapidsai/branch-23.04' into fea-ivf-pq-…
achirkin Mar 17, 2023
172673b
Merge remote-tracking branch 'rapidsai/branch-23.04' into fea-ivf-pq-…
achirkin Mar 17, 2023
1504027
Relax the test criterion eps a little bit
achirkin Mar 17, 2023
1326ac7
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
achirkin Mar 19, 2023
08cfacb
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
achirkin Mar 20, 2023
a399023
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
achirkin Mar 21, 2023
d935568
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
cjnolet Mar 23, 2023
d1dd238
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
cjnolet Mar 23, 2023
5514d7d
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
achirkin Mar 25, 2023
a317622
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
achirkin Mar 25, 2023
9367936
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
achirkin Mar 28, 2023
09f1a0d
Move ivf_pq helpers to separate file and namespace
tfeher Mar 31, 2023
b7e811a
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
tfeher Mar 31, 2023
926f510
Add public API for pack_list_data
tfeher Mar 31, 2023
8e311d9
Increase tolerance for vector reconstruction test
tfeher Apr 2, 2023
158189c
Merge branch 'branch-23.04' into fea-ivf-pq-reconstruct
tfeher Apr 3, 2023
280c040
Correct number of list elements to compare
tfeher Apr 3, 2023
a7e5419
Add ivf_pq::helpers::codepacker::pack / unpack
tfeher Apr 5, 2023
09abadb
Merge branch 'branch-23.06' into fea-ivf-pq-reconstruct
cjnolet Apr 12, 2023
923c3b5
Merge branch 'branch-23.06' into fea-ivf-pq-reconstruct
cjnolet Apr 14, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Implement unpack_list_data
achirkin committed Mar 16, 2023
commit 3a2c62235861e1e6aa7a1c64c4fa9e1e4036ae99
209 changes: 145 additions & 64 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
@@ -557,6 +557,150 @@ void train_per_cluster(raft::device_resources const& handle,
transpose_pq_centers(handle, index, pq_centers_tmp.data());
}

/**
 * Process a single vector in a list.
 *
 * Decodes the `pq_dim` PQ codes of one encoded vector out of the interleaved list
 * layout and invokes `action` once per code.
 *
 * @tparam PqBits the bit-width of a single PQ code
 * @tparam Action tells how to process a single vector (e.g. reconstruct or just unpack)
 *
 * @param[in] in_list_data the encoded cluster data.
 * @param[in] in_ix in-cluster index of the vector to be decoded (one-per-thread).
 * @param[in] out_ix the output index passed to the action
 * @param[in] pq_dim the number of PQ codes per vector
 * @param action a callable action to be invoked on each PQ code (component of the encoding)
 *   type: void (uint8_t code, uint32_t out_ix, uint32_t j), where j = [0..pq_dim).
 */
template <uint32_t PqBits, typename Action>
__device__ void run_on_vector(
  device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> in_list_data,
  uint32_t in_ix,
  uint32_t out_ix,
  uint32_t pq_dim,
  Action action)
{
  // The list data is interleaved in groups of kIndexGroupSize vectors:
  // split the in-cluster index into the group index and the position within the group.
  using group_align = Pow2<kIndexGroupSize>;
  const uint32_t group_ix = group_align::div(in_ix);
  const uint32_t ingroup_ix = group_align::mod(in_ix);

  pq_vec_t code_chunk;
  bitfield_view_t<PqBits> code_view{reinterpret_cast<uint8_t*>(&code_chunk)};
  // How many PqBits-wide codes fit into one pq_vec_t chunk.
  constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits;
  // j counts decoded codes [0..pq_dim); i counts the chunks read from the list data.
  for (uint32_t j = 0, i = 0; j < pq_dim; i++) {
    // read the chunk
    code_chunk = *reinterpret_cast<const pq_vec_t*>(&in_list_data(group_ix, i, ingroup_ix, 0));
    // read the codes, one/pq_dim at a time
#pragma unroll
    for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) {
      // hand one code (one component of the encoded vector) to the action
      action(code_view[k], out_ix, j);
    }
  }
}

/**
 * Process the given indices or a block of a single list (cluster).
 *
 * One `action` invocation per output index in [0, len); the source position is either
 * `offset + ix` (contiguous block) or `indices[ix]` (gather), depending on which
 * alternative `offset_or_indices` holds.
 */
template <uint32_t PqBits, typename Action>
__device__ void run_on_list(device_vector_view<const uint8_t* const, uint32_t, row_major> data_ptrs,
                            device_vector_view<const uint32_t, uint32_t, row_major> list_sizes,
                            std::variant<uint32_t, const uint32_t*> offset_or_indices,
                            uint32_t len,
                            uint32_t cluster_ix,
                            uint32_t pq_dim,
                            Action action)
{
  auto pq_extents =
    list_spec<uint32_t, uint32_t>{PqBits, pq_dim, true}.make_list_extents(list_sizes[cluster_ix]);
  auto pq_dataset =
    make_mdspan<const uint8_t, uint32_t, row_major, false, true>(data_ptrs[cluster_ix], pq_extents);

  // Grid-stride loop: stride by the total number of threads in the grid so that each
  // output index is visited by exactly one thread. (The previous stride of plain
  // `blockDim.x` made every block redundantly re-process all indices past its own
  // starting offset; the launcher uses ceil(len / blockDim.x) blocks.)
  for (uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; ix < len;
       ix += blockDim.x * gridDim.x) {
    const uint32_t src_ix = std::holds_alternative<uint32_t>(offset_or_indices)
                              ? std::get<uint32_t>(offset_or_indices) + ix
                              : std::get<const uint32_t*>(offset_or_indices)[ix];
    run_on_vector<PqBits>(pq_dataset, src_ix, ix, pq_dim, action);
  }
}

/**
 * A consumer for `run_on_list` and `run_on_vector` that flattens PQ codes
 * one-per-byte: independent of the code width (pq_bits), each code occupies
 * a whole byte, so a single vector takes pq_dim bytes in the output.
 */
struct unpack_codes {
  device_matrix_view<uint8_t, uint32_t, row_major> out_codes;

  /**
   * Create a callable to be passed to `run_on_list`.
   *
   * @param[out] out_codes the destination for the read codes.
   */
  __device__ inline unpack_codes(device_matrix_view<uint8_t, uint32_t, row_major> out_codes)
    : out_codes{out_codes}
  {
  }

  /** Store the j-th component (code) of the i-th vector in the output matrix. */
  __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j)
  {
    out_codes(i, j) = code;
  }
};

/** Kernel: flatten the codes of one cluster (or selected rows of it) into `out_codes`. */
template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) __global__ void unpack_list_data_kernel(
  device_matrix_view<uint8_t, uint32_t, row_major> out_codes,
  device_vector_view<const uint8_t* const, uint32_t, row_major> data_ptrs,
  device_vector_view<const uint32_t, uint32_t, row_major> list_sizes,
  uint32_t cluster_ix,
  std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
  // out_codes is [n_rows, pq_dim]; the action writes one byte per code.
  run_on_list<PqBits>(data_ptrs,
                      list_sizes,
                      offset_or_indices,
                      out_codes.extent(0),
                      cluster_ix,
                      out_codes.extent(1),
                      unpack_codes{out_codes});
}

/** Decode the list data; see the public interface for the api and usage. */
template <typename IdxT>
void unpack_list_data(raft::device_resources const& res,
                      const index<IdxT>& index,
                      device_matrix_view<uint8_t, uint32_t, row_major> out_codes,
                      uint32_t label,
                      std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
  const auto n_rows = out_codes.extent(0);
  if (n_rows == 0) { return; }
  // When a plain offset is given, validate it against the list size on the host.
  if (std::holds_alternative<uint32_t>(offset_or_indices)) {
    const auto n_skip = std::get<uint32_t>(offset_or_indices);
    // sic! I'm using the upper bound `list.size` instead of exact `list_sizes(label)`
    // to avoid an extra device-host data copy and the stream sync.
    RAFT_EXPECTS(n_skip + n_rows <= index.lists()[label]->size.load(),
                 "offset + output size must be not bigger than the cluster size.");
  }

  constexpr uint32_t kBlockSize = 256;
  const dim3 threads(kBlockSize, 1, 1);
  const dim3 blocks(div_rounding_up_safe<uint32_t>(n_rows, kBlockSize), 1, 1);

  // Select the kernel instantiation matching the runtime code width.
  auto kernel = unpack_list_data_kernel<kBlockSize, 8>;
  switch (index.pq_bits()) {
    case 4: kernel = unpack_list_data_kernel<kBlockSize, 4>; break;
    case 5: kernel = unpack_list_data_kernel<kBlockSize, 5>; break;
    case 6: kernel = unpack_list_data_kernel<kBlockSize, 6>; break;
    case 7: kernel = unpack_list_data_kernel<kBlockSize, 7>; break;
    case 8: kernel = unpack_list_data_kernel<kBlockSize, 8>; break;
    default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", index.pq_bits());
  }
  kernel<<<blocks, threads, 0, res.get_stream()>>>(
    out_codes, index.data_ptrs(), index.list_sizes(), label, offset_or_indices);
  RAFT_CUDA_TRY(cudaPeekAtLastError());
}

/** A consumer for the `run_on_list` and `run_on_vec` that approximates the original input data. */
struct reconstruct_vectors {
codebook_gen codebook_kind;
uint32_t cluster_ix;
@@ -566,7 +710,7 @@ struct reconstruct_vectors {
device_mdspan<float, extent_3d<uint32_t>, row_major> out_vectors;

/**
* Create the functor to be passed to `run_on_list`.
* Create a callable to be passed to `run_on_list`.
*
* @param[out] out_vectors the destination for the decoded vectors.
* @param[in] pq_centers the codebook
@@ -626,69 +770,6 @@ struct reconstruct_vectors {
}
};

/**
 * Process a single vector in a list.
 *
 * Decodes the `pq_dim` PQ codes of one encoded vector out of the interleaved list
 * layout and invokes `action` once per code.
 *
 * @tparam PqBits the bit-width of a single PQ code
 * @tparam Action tells how to process a single vector (e.g. reconstruct or just unpack)
 *
 * @param[in] in_list_data the encoded cluster data.
 * @param[in] in_ix in-cluster index of the vector to be decoded (one-per-thread).
 * @param[in] out_ix the output index passed to the action
 * @param[in] pq_dim the number of PQ codes per vector
 * @param action a callable action to be invoked on each PQ code (component of the encoding)
 *   type: void (uint8_t code, uint32_t out_ix, uint32_t j), where j = [0..pq_dim).
 */
template <uint32_t PqBits, typename Action>
__device__ void run_on_vector(
  device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> in_list_data,
  uint32_t in_ix,
  uint32_t out_ix,
  uint32_t pq_dim,
  Action action)
{
  // The list data is interleaved in groups of kIndexGroupSize vectors:
  // split the in-cluster index into the group index and the position within the group.
  using group_align = Pow2<kIndexGroupSize>;
  const uint32_t group_ix = group_align::div(in_ix);
  const uint32_t ingroup_ix = group_align::mod(in_ix);

  pq_vec_t code_chunk;
  bitfield_view_t<PqBits> code_view{reinterpret_cast<uint8_t*>(&code_chunk)};
  // How many PqBits-wide codes fit into one pq_vec_t chunk.
  constexpr uint32_t kChunkSize = (sizeof(pq_vec_t) * 8u) / PqBits;
  // j counts decoded codes [0..pq_dim); i counts the chunks read from the list data.
  for (uint32_t j = 0, i = 0; j < pq_dim; i++) {
    // read the chunk
    code_chunk = *reinterpret_cast<const pq_vec_t*>(&in_list_data(group_ix, i, ingroup_ix, 0));
    // read the codes, one/pq_dim at a time
#pragma unroll
    for (uint32_t k = 0; k < kChunkSize && j < pq_dim; k++, j++) {
      // hand one code (one component of the encoded vector) to the action
      action(code_view[k], out_ix, j);
    }
  }
}

/**
 * Process the given indices or a block of a single list (cluster).
 *
 * One `action` invocation per output index in [0, len); the source position is either
 * `offset + ix` (contiguous block) or `indices[ix]` (gather), depending on which
 * alternative `offset_or_indices` holds.
 */
template <uint32_t PqBits, typename Action>
__device__ void run_on_list(device_vector_view<const uint8_t* const, uint32_t, row_major> data_ptrs,
                            device_vector_view<const uint32_t, uint32_t, row_major> list_sizes,
                            std::variant<uint32_t, const uint32_t*> offset_or_indices,
                            uint32_t len,
                            uint32_t cluster_ix,
                            uint32_t pq_dim,
                            Action action)
{
  auto pq_extents =
    list_spec<uint32_t, uint32_t>{PqBits, pq_dim, true}.make_list_extents(list_sizes[cluster_ix]);
  auto pq_dataset =
    make_mdspan<const uint8_t, uint32_t, row_major, false, true>(data_ptrs[cluster_ix], pq_extents);

  // Grid-stride loop: stride by the total number of threads in the grid so that each
  // output index is visited by exactly one thread. (The previous stride of plain
  // `blockDim.x` made every block redundantly re-process all indices past its own
  // starting offset; the launcher uses ceil(len / blockDim.x) blocks.)
  for (uint32_t ix = threadIdx.x + blockDim.x * blockIdx.x; ix < len;
       ix += blockDim.x * gridDim.x) {
    const uint32_t src_ix = std::holds_alternative<uint32_t>(offset_or_indices)
                              ? std::get<uint32_t>(offset_or_indices) + ix
                              : std::get<const uint32_t*>(offset_or_indices)[ix];
    run_on_vector<PqBits>(pq_dataset, src_ix, ix, pq_dim, action);
  }
}

template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) __global__ void reconstruct_list_data_kernel(
device_matrix_view<float, uint32_t, row_major> out_vectors,