Skip to content

Commit

Permalink
Minor follow-up fixes for ivf-flat (#796)
Browse files Browse the repository at this point in the history
Small fixes to ivf-flat and its dependency warp_sort:

- fix the template parameter for the call of `calc_smem_size_for_block_wide` (may reduce the use of shared memory);
- force initialize warp_sort internal value buffers to avoid uninitialized output in case of very small input data size;
- small readability fixes.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #796
  • Loading branch information
achirkin authored Sep 2, 2022
1 parent ff133d4 commit 8b97c1b
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 8 deletions.
9 changes: 4 additions & 5 deletions cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -696,8 +696,8 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
copy_vectorized(query_shared, query, std::min(dim, query_smem_elems));
__syncthreads();

topk::block_sort<topk::warp_sort_filtered, Capacity, Ascending, float, IdxT> queue(
k, interleaved_scan_kernel_smem + query_smem_elems * sizeof(T));
using block_sort_t = topk::block_sort<topk::warp_sort_filtered, Capacity, Ascending, float, IdxT>;
block_sort_t queue(k, interleaved_scan_kernel_smem + query_smem_elems * sizeof(T));

{
using align_warp = Pow2<WarpSize>;
Expand Down Expand Up @@ -766,8 +766,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock)
}

// Enqueue one element per thread
constexpr float kDummy = Ascending ? upper_bound<float>() : lower_bound<float>();
const float val = valid ? static_cast<float>(dist) : kDummy;
const float val = valid ? static_cast<float>(dist) : block_sort_t::queue_t::kDummy;
const size_t idx = valid ? static_cast<size_t>(list_indices[list_offset + vec_id]) : 0;
queue.add(val, idx);
}
Expand Down Expand Up @@ -826,7 +825,7 @@ void launch_kernel(Lambda lambda,
std::min<int>(max_query_smem / sizeof(T), Pow2<Veclen * WarpSize>::roundUp(index.dim()));
int smem_size = query_smem_elems * sizeof(T);
constexpr int kSubwarpSize = std::min<int>(Capacity, WarpSize);
smem_size += raft::spatial::knn::detail::topk::calc_smem_size_for_block_wide<AccT, size_t>(
smem_size += raft::spatial::knn::detail::topk::calc_smem_size_for_block_wide<AccT, IdxT>(
kThreadsPerBlock / kSubwarpSize, k);

// power-of-two less than cuda limit (for better addr alignment)
Expand Down
6 changes: 5 additions & 1 deletion cpp/include/raft/spatial/knn/detail/topk/warpsort_topk.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ constexpr auto calc_capacity(int k) -> int
template <int Capacity, bool Ascending, typename T, typename IdxT>
class warp_sort {
static_assert(isPo2(Capacity));
static_assert(std::is_default_constructible_v<IdxT>);

public:
/**
Expand All @@ -158,6 +159,7 @@ class warp_sort {
#pragma unroll
for (int i = 0; i < kMaxArrLen; i++) {
val_arr_[i] = kDummy;
idx_arr_[i] = IdxT{};
}
}

Expand Down Expand Up @@ -280,6 +282,7 @@ class warp_sort_filtered : public warp_sort<Capacity, Ascending, T, IdxT> {
#pragma unroll
for (int i = 0; i < kMaxBufLen; i++) {
val_buf_[i] = kDummy;
idx_buf_[i] = IdxT{};
}
}

Expand Down Expand Up @@ -371,6 +374,7 @@ class warp_sort_immediate : public warp_sort<Capacity, Ascending, T, IdxT> {
#pragma unroll
for (int i = 0; i < kMaxArrLen; i++) {
val_buf_[i] = kDummy;
idx_buf_[i] = IdxT{};
}
}

Expand Down Expand Up @@ -429,9 +433,9 @@ template <template <int, bool, typename, typename> class WarpSortWarpWide,
typename T,
typename IdxT>
class block_sort {
public:
using queue_t = WarpSortWarpWide<Capacity, Ascending, T, IdxT>;

public:
__device__ block_sort(int k, uint8_t* smem_buf) : queue_(k)
{
val_smem_ = reinterpret_cast<T*>(smem_buf);
Expand Down
7 changes: 5 additions & 2 deletions cpp/test/spatial/ann_ivf_flat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#include <raft/core/logger.hpp>
#include <raft/distance/distance_type.hpp>
#include <raft/random/rng.cuh>
#include <raft/sparse/detail/utils.h>
#include <raft/spatial/knn/ann.cuh>
#include <raft/spatial/knn/ivf_flat.cuh>
#include <raft/spatial/knn/knn.cuh>
Expand All @@ -30,6 +29,8 @@

#include <gtest/gtest.h>

#include <thrust/sequence.h>

#include <cstddef>
#include <iostream>
#include <vector>
Expand Down Expand Up @@ -209,7 +210,9 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs> {
ivf_flat::build(handle_, index_params, database.data(), int64_t(ps.num_db_vecs), ps.dim);

rmm::device_uvector<int64_t> vector_indices(ps.num_db_vecs, stream_);
sparse::iota_fill(vector_indices.data(), int64_t(ps.num_db_vecs), int64_t(1), stream_);
thrust::sequence(handle_.get_thrust_policy(),
thrust::device_pointer_cast(vector_indices.data()),
thrust::device_pointer_cast(vector_indices.data() + ps.num_db_vecs));
handle_.sync_stream(stream_);

int64_t half_of_data = ps.num_db_vecs / 2;
Expand Down

0 comments on commit 8b97c1b

Please sign in to comment.