Skip to content

Commit

Permalink
[FEA] Helpers and CodePacker for IVF-PQ (#1826)
Browse files Browse the repository at this point in the history
- [x] Codepacking for compressed on-device (flat) PQ codes
- [x] Testing

Authors:
  - Tarang Jain (https://github.com/tarang-jain)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1826
  • Loading branch information
tarang-jain authored Nov 17, 2023
1 parent 6e64e0f commit 52a5e4c
Show file tree
Hide file tree
Showing 7 changed files with 866 additions and 84 deletions.
66 changes: 66 additions & 0 deletions cpp/include/raft/neighbors/detail/div_utils.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifdef _RAFT_HAS_CUDA
#include <raft/util/pow2_utils.cuh>
#else
#include <raft/util/integer_utils.hpp>
#endif

/**
* @brief A simple wrapper for raft::Pow2 which uses Pow2 utils only when available and regular
* integer division otherwise. This is done to allow a common interface for division arithmetic for
* non CUDA headers.
*
* @tparam Value_ a compile-time value representable as a power-of-two.
*/
namespace raft::neighbors::detail {
template <auto Value_>
struct div_utils {
typedef decltype(Value_) Type;
static constexpr Type Value = Value_;

template <typename T>
static constexpr _RAFT_HOST_DEVICE inline auto roundDown(T x)
{
#if defined(_RAFT_HAS_CUDA)
return Pow2<Value_>::roundDown(x);
#else
return raft::round_down_safe(x, Value_);
#endif
}

template <typename T>
static constexpr _RAFT_HOST_DEVICE inline auto mod(T x)
{
#if defined(_RAFT_HAS_CUDA)
return Pow2<Value_>::mod(x);
#else
return x % Value_;
#endif
}

template <typename T>
static constexpr _RAFT_HOST_DEVICE inline auto div(T x)
{
#if defined(_RAFT_HAS_CUDA)
return Pow2<Value_>::div(x);
#else
return x / Value_;
#endif
}
};
} // namespace raft::neighbors::detail
290 changes: 243 additions & 47 deletions cpp/include/raft/neighbors/detail/ivf_pq_build.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,59 @@ auto calculate_offsets_and_indices(IdxT n_rows,
return max_cluster_size;
}

template <typename IdxT>
void set_centers(raft::resources const& handle, index<IdxT>* index, const float* cluster_centers)
{
auto stream = resource::get_cuda_stream(handle);
auto* device_memory = resource::get_workspace_resource(handle);

// combine cluster_centers and their norms
RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle(),
sizeof(float) * index->dim_ext(),
cluster_centers,
sizeof(float) * index->dim(),
sizeof(float) * index->dim(),
index->n_lists(),
cudaMemcpyDefault,
stream));

rmm::device_uvector<float> center_norms(index->n_lists(), stream, device_memory);
raft::linalg::rowNorm(center_norms.data(),
cluster_centers,
index->dim(),
index->n_lists(),
raft::linalg::L2Norm,
true,
stream);
RAFT_CUDA_TRY(cudaMemcpy2DAsync(index->centers().data_handle() + index->dim(),
sizeof(float) * index->dim_ext(),
center_norms.data(),
sizeof(float),
sizeof(float),
index->n_lists(),
cudaMemcpyDefault,
stream));

// Rotate cluster_centers
float alpha = 1.0;
float beta = 0.0;
linalg::gemm(handle,
true,
false,
index->rot_dim(),
index->n_lists(),
index->dim(),
&alpha,
index->rotation_matrix().data_handle(),
index->dim(),
cluster_centers,
index->dim(),
&beta,
index->centers_rot().data_handle(),
index->rot_dim(),
resource::get_cuda_stream(handle));
}

template <typename IdxT>
void transpose_pq_centers(const resources& handle,
index<IdxT>& index,
Expand Down Expand Up @@ -613,6 +666,100 @@ void unpack_list_data(raft::resources const& res,
resource::get_cuda_stream(res));
}

/**
* A consumer for the `run_on_vector` that just flattens PQ codes
* into a tightly packed matrix. That is, the codes are not expanded to one code-per-byte.
*/
template <uint32_t PqBits>
struct unpack_contiguous {
uint8_t* codes;
uint32_t code_size;

/**
* Create a callable to be passed to `run_on_vector`.
*
* @param[in] codes flat compressed PQ codes
*/
__host__ __device__ inline unpack_contiguous(uint8_t* codes, uint32_t pq_dim)
: codes{codes}, code_size{raft::ceildiv<uint32_t>(pq_dim * PqBits, 8)}
{
}

/** Write j-th component (code) of the i-th vector into the output array. */
__host__ __device__ inline void operator()(uint8_t code, uint32_t i, uint32_t j)
{
bitfield_view_t<PqBits> code_view{codes + i * code_size};
code_view[j] = code;
}
};

template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) RAFT_KERNEL unpack_contiguous_list_data_kernel(
uint8_t* out_codes,
device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> in_list_data,
uint32_t n_rows,
uint32_t pq_dim,
std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
run_on_list<PqBits>(
in_list_data, offset_or_indices, n_rows, pq_dim, unpack_contiguous<PqBits>(out_codes, pq_dim));
}

/**
* Unpack flat PQ codes from an existing list by the given offset.
*
* @param[out] codes flat compressed PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)]
* @param[in] list_data the packed ivf::list data.
* @param[in] offset_or_indices how many records in the list to skip or the exact indices.
* @param[in] pq_bits codebook size (1 << pq_bits)
* @param[in] stream
*/
inline void unpack_contiguous_list_data(
uint8_t* codes,
device_mdspan<const uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
uint32_t n_rows,
uint32_t pq_dim,
std::variant<uint32_t, const uint32_t*> offset_or_indices,
uint32_t pq_bits,
rmm::cuda_stream_view stream)
{
if (n_rows == 0) { return; }

constexpr uint32_t kBlockSize = 256;
dim3 blocks(div_rounding_up_safe<uint32_t>(n_rows, kBlockSize), 1, 1);
dim3 threads(kBlockSize, 1, 1);
auto kernel = [pq_bits]() {
switch (pq_bits) {
case 4: return unpack_contiguous_list_data_kernel<kBlockSize, 4>;
case 5: return unpack_contiguous_list_data_kernel<kBlockSize, 5>;
case 6: return unpack_contiguous_list_data_kernel<kBlockSize, 6>;
case 7: return unpack_contiguous_list_data_kernel<kBlockSize, 7>;
case 8: return unpack_contiguous_list_data_kernel<kBlockSize, 8>;
default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits);
}
}();
kernel<<<blocks, threads, 0, stream>>>(codes, list_data, n_rows, pq_dim, offset_or_indices);
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

/** Unpack the list data; see the public interface for the api and usage. */
template <typename IdxT>
void unpack_contiguous_list_data(raft::resources const& res,
const index<IdxT>& index,
uint8_t* out_codes,
uint32_t n_rows,
uint32_t label,
std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
unpack_contiguous_list_data(out_codes,
index.lists()[label]->data.view(),
n_rows,
index.pq_dim(),
offset_or_indices,
index.pq_bits(),
resource::get_cuda_stream(res));
}

/** A consumer for the `run_on_list` and `run_on_vector` that approximates the original input data.
*/
struct reconstruct_vectors {
Expand Down Expand Up @@ -850,6 +997,101 @@ void pack_list_data(raft::resources const& res,
resource::get_cuda_stream(res));
}

/**
* A producer for the `write_vector` reads tightly packed flat codes. That is,
* the codes are not expanded to one code-per-byte.
*/
template <uint32_t PqBits>
struct pack_contiguous {
const uint8_t* codes;
uint32_t code_size;

/**
* Create a callable to be passed to `write_vector`.
*
* @param[in] codes flat compressed PQ codes
*/
__host__ __device__ inline pack_contiguous(const uint8_t* codes, uint32_t pq_dim)
: codes{codes}, code_size{raft::ceildiv<uint32_t>(pq_dim * PqBits, 8)}
{
}

/** Read j-th component (code) of the i-th vector from the source. */
__host__ __device__ inline auto operator()(uint32_t i, uint32_t j) -> uint8_t
{
bitfield_view_t<PqBits> code_view{const_cast<uint8_t*>(codes + i * code_size)};
return uint8_t(code_view[j]);
}
};

template <uint32_t BlockSize, uint32_t PqBits>
__launch_bounds__(BlockSize) RAFT_KERNEL pack_contiguous_list_data_kernel(
device_mdspan<uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
const uint8_t* codes,
uint32_t n_rows,
uint32_t pq_dim,
std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
write_list<PqBits, 1>(
list_data, offset_or_indices, n_rows, pq_dim, pack_contiguous<PqBits>(codes, pq_dim));
}

/**
* Write flat PQ codes into an existing list by the given offset.
*
* NB: no memory allocation happens here; the list must fit the data (offset + n_rows).
*
* @param[out] list_data the packed ivf::list data.
* @param[in] codes flat compressed PQ codes [n_rows, ceildiv(pq_dim * pq_bits, 8)]
* @param[in] offset_or_indices how many records in the list to skip or the exact indices.
* @param[in] pq_bits codebook size (1 << pq_bits)
* @param[in] stream
*/
inline void pack_contiguous_list_data(
device_mdspan<uint8_t, list_spec<uint32_t, uint32_t>::list_extents, row_major> list_data,
const uint8_t* codes,
uint32_t n_rows,
uint32_t pq_dim,
std::variant<uint32_t, const uint32_t*> offset_or_indices,
uint32_t pq_bits,
rmm::cuda_stream_view stream)
{
if (n_rows == 0) { return; }

constexpr uint32_t kBlockSize = 256;
dim3 blocks(div_rounding_up_safe<uint32_t>(n_rows, kBlockSize), 1, 1);
dim3 threads(kBlockSize, 1, 1);
auto kernel = [pq_bits]() {
switch (pq_bits) {
case 4: return pack_contiguous_list_data_kernel<kBlockSize, 4>;
case 5: return pack_contiguous_list_data_kernel<kBlockSize, 5>;
case 6: return pack_contiguous_list_data_kernel<kBlockSize, 6>;
case 7: return pack_contiguous_list_data_kernel<kBlockSize, 7>;
case 8: return pack_contiguous_list_data_kernel<kBlockSize, 8>;
default: RAFT_FAIL("Invalid pq_bits (%u), the value must be within [4, 8]", pq_bits);
}
}();
kernel<<<blocks, threads, 0, stream>>>(list_data, codes, n_rows, pq_dim, offset_or_indices);
RAFT_CUDA_TRY(cudaPeekAtLastError());
}

template <typename IdxT>
void pack_contiguous_list_data(raft::resources const& res,
index<IdxT>* index,
const uint8_t* new_codes,
uint32_t n_rows,
uint32_t label,
std::variant<uint32_t, const uint32_t*> offset_or_indices)
{
pack_contiguous_list_data(index->lists()[label]->data.view(),
new_codes,
n_rows,
index->pq_dim(),
offset_or_indices,
index->pq_bits(),
resource::get_cuda_stream(res));
}

/**
*
* A producer for the `write_list` and `write_vector` that encodes level-1 input vector residuals
Expand Down Expand Up @@ -1634,60 +1876,14 @@ auto build(raft::resources const& handle,
labels_view,
utils::mapping<float>());

{
// combine cluster_centers and their norms
RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle(),
sizeof(float) * index.dim_ext(),
cluster_centers,
sizeof(float) * index.dim(),
sizeof(float) * index.dim(),
index.n_lists(),
cudaMemcpyDefault,
stream));

rmm::device_uvector<float> center_norms(index.n_lists(), stream, device_memory);
raft::linalg::rowNorm(center_norms.data(),
cluster_centers,
index.dim(),
index.n_lists(),
raft::linalg::L2Norm,
true,
stream);
RAFT_CUDA_TRY(cudaMemcpy2DAsync(index.centers().data_handle() + index.dim(),
sizeof(float) * index.dim_ext(),
center_norms.data(),
sizeof(float),
sizeof(float),
index.n_lists(),
cudaMemcpyDefault,
stream));
}

// Make rotation matrix
make_rotation_matrix(handle,
params.force_random_rotation,
index.rot_dim(),
index.dim(),
index.rotation_matrix().data_handle());

// Rotate cluster_centers
float alpha = 1.0;
float beta = 0.0;
linalg::gemm(handle,
true,
false,
index.rot_dim(),
index.n_lists(),
index.dim(),
&alpha,
index.rotation_matrix().data_handle(),
index.dim(),
cluster_centers,
index.dim(),
&beta,
index.centers_rot().data_handle(),
index.rot_dim(),
stream);
set_centers(handle, &index, cluster_centers);

// Train PQ codebooks
switch (index.codebook_kind()) {
Expand Down
Loading

0 comments on commit 52a5e4c

Please sign in to comment.