CAGRA - separable compilation for distance computation #296

Merged
Changes from 30 commits
56 commits
0dbe5b2
[WIP] CAGRA - separable compilation for distance computation
achirkin Aug 16, 2024
93b0439
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Aug 16, 2024
ba52b13
Fix style
achirkin Aug 16, 2024
434e50a
Add missing multi-kernel implementation
achirkin Aug 19, 2024
6352550
Move common code out of virtual functions scope (aiming for more inli…
achirkin Aug 19, 2024
d161f79
Make small descriptor functions into fields
achirkin Aug 20, 2024
35c3813
Minor updates to improve reg count
achirkin Aug 20, 2024
4b5dcd3
Refactor distance_core -> compute_distance, and update the instance list
achirkin Aug 21, 2024
e5878db
Merge remote-tracking branch 'rapidsai/branch-24.10' into enh-cagra-s…
achirkin Aug 21, 2024
385a8c4
Make the compute_distance instances controlled from a single place
achirkin Aug 21, 2024
3f77cda
Refactor usage of init_kernel to make sure it instantiated in the sam…
achirkin Aug 22, 2024
ddb0488
Reduce the register usage in distance functions
achirkin Aug 22, 2024
c244ead
Partially implemented manual dispatch
achirkin Aug 23, 2024
7eb6a27
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Aug 23, 2024
ff2fdbe
Finish manual dispatch
achirkin Aug 23, 2024
78a9809
Change instance generator to have blockdim/team_size ratio 16
achirkin Aug 23, 2024
6082bf7
Trying various minor things to reduce register spilling
achirkin Aug 23, 2024
fc7d832
Move the metric parameter to the compute_distance template
achirkin Aug 26, 2024
118808e
Further reduce register pressure by moving code out of the non-inlina…
achirkin Aug 27, 2024
abec125
Manually unroll device::team_sum
achirkin Aug 27, 2024
cf0101c
Remove the test of a compute_distance instance that is not compiled (…
achirkin Aug 28, 2024
b3e6d26
Hide previously not hidden kernels
achirkin Aug 28, 2024
f231828
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Aug 28, 2024
dc75f7a
Reduce register usage by minimizing the part of descriptor struct pas…
achirkin Sep 2, 2024
6630a99
Further reduce the size of the dataset descriptor and add explic…
achirkin Sep 2, 2024
790e79c
Cache dataset descriptors to recover small batch performance
achirkin Sep 2, 2024
7599331
Reduce the register usage in compute_distance_standard further
achirkin Sep 3, 2024
4d9241e
Reduce the generated code volume
achirkin Sep 3, 2024
5fdcdd0
More explicit ldg cache behavior and a few smaller things
achirkin Sep 4, 2024
5984596
Simplify vpq indexing arithmetics a bit
achirkin Sep 4, 2024
af0cc12
Bring back the fatbin.ld link option
achirkin Sep 5, 2024
9023e68
relax the config for checking the raft_cutlass symbol exclusion (see …
achirkin Sep 5, 2024
99d2bd3
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 6, 2024
75a2dac
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 9, 2024
6a1b898
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 10, 2024
c1eed0e
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 10, 2024
d4673cf
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 11, 2024
0046a73
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 11, 2024
267902e
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 16, 2024
c0f5715
Add pointer hints and reduce the instruction count a bit
achirkin Sep 18, 2024
f65cfd7
Reorganize the compute-similarity code to allow the compiler optimize…
achirkin Sep 18, 2024
0504129
Disable swizzling and reduce the instruction count in VPQ distance
achirkin Sep 18, 2024
27f6581
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 18, 2024
b605061
Don't apply swizzling when the bank conflicts are not possible (small…
achirkin Sep 19, 2024
478a824
Minor improvements to multi-cta kernel
achirkin Sep 20, 2024
5090ebb
Transpose query buffer instead of swizzling in VPQ distance to reduce…
achirkin Sep 20, 2024
9f069af
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 20, 2024
6fac19b
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 23, 2024
d0eb9b3
VPQ distance: don't pass n_subspace as parameter, because it can be c…
achirkin Sep 23, 2024
7bce6da
Docs and readability: device_common.hpp and factory.cuh
achirkin Sep 23, 2024
5154892
Remove unused distance instances (with uint64_t index type)
achirkin Sep 23, 2024
a0c54e3
compute_distance.hpp: document and slightly simplify the dataset desc…
achirkin Sep 23, 2024
9ba3e3f
Document the dataset/distance descriptor selection logic
achirkin Sep 23, 2024
f77c1b0
Remove commented-out code sections
achirkin Sep 23, 2024
eabb3ae
Merge branch 'branch-24.10' into enh-cagra-separable-compilation
achirkin Sep 24, 2024
f1426cf
Remove empty comment
achirkin Sep 25, 2024
220 changes: 119 additions & 101 deletions cpp/CMakeLists.txt

Large diffs are not rendered by default.
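
The CMake diff is not rendered here, but the heart of the PR is separable compilation: the CAGRA distance code is built as relocatable device code in its own translation units and resolved by the device linker. As a rough sketch of the mechanism only — the file names, the l2_distance function, and the build line below are illustrative assumptions, not taken from this PR's build setup:

// compute_distance.cu -- a separate translation unit holding the distance routine
__device__ float l2_distance(const float* a, const float* b, int dim)
{
  float acc = 0.0f;
  for (int i = 0; i < dim; ++i) {
    float diff = a[i] - b[i];
    acc += diff * diff;
  }
  return acc;
}

// search.cu -- another translation unit; only the declaration is visible here,
// and the definition is resolved at device-link time, not at compile time.
extern __device__ float l2_distance(const float* a, const float* b, int dim);

__global__ void search_kernel(const float* a, const float* b, int dim, float* out)
{
  if (threadIdx.x == 0 && blockIdx.x == 0) { *out = l2_distance(a, b, dim); }
}

int main()
{
  float *a, *b, *out;
  cudaMallocManaged(&a, 4 * sizeof(float));
  cudaMallocManaged(&b, 4 * sizeof(float));
  cudaMallocManaged(&out, sizeof(float));
  for (int i = 0; i < 4; ++i) { a[i] = float(i); b[i] = 0.0f; }
  search_kernel<<<1, 32>>>(a, b, 4, out);
  cudaDeviceSynchronize();
  // Expected squared L2 distance: 0 + 1 + 4 + 9 = 14
  return out[0] == 14.0f ? 0 : 1;
}

// Build with relocatable device code so the cross-TU device call links:
//   nvcc -rdc=true compute_distance.cu search.cu -o search_demo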

41 changes: 34 additions & 7 deletions cpp/include/cuvs/neighbors/common.hpp
@@ -172,6 +172,22 @@ struct owning_dataset : public strided_dataset<DataT, IdxT> {
};
};

template <typename DatasetT>
struct is_strided_dataset : std::false_type {};

template <typename DataT, typename IdxT>
struct is_strided_dataset<strided_dataset<DataT, IdxT>> : std::true_type {};

template <typename DataT, typename IdxT>
struct is_strided_dataset<non_owning_dataset<DataT, IdxT>> : std::true_type {};

template <typename DataT, typename IdxT, typename LayoutPolicy, typename ContainerPolicy>
struct is_strided_dataset<owning_dataset<DataT, IdxT, LayoutPolicy, ContainerPolicy>>
: std::true_type {};

template <typename DatasetT>
inline constexpr bool is_strided_dataset_v = is_strided_dataset<DatasetT>::value;

/**
* @brief Construct a strided matrix from any mdarray or mdspan.
*
@@ -284,23 +300,25 @@ auto make_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t
*/
template <typename MathT, typename IdxT>
struct vpq_dataset : public dataset<IdxT> {
using index_type = IdxT;
using math_type = MathT;
/** Vector Quantization codebook - "coarse cluster centers". */
raft::device_matrix<MathT, uint32_t, raft::row_major> vq_code_book;
raft::device_matrix<math_type, uint32_t, raft::row_major> vq_code_book;
/** Product Quantization codebook - "fine cluster centers". */
raft::device_matrix<MathT, uint32_t, raft::row_major> pq_code_book;
raft::device_matrix<math_type, uint32_t, raft::row_major> pq_code_book;
/** Compressed dataset. */
raft::device_matrix<uint8_t, IdxT, raft::row_major> data;
raft::device_matrix<uint8_t, index_type, raft::row_major> data;

vpq_dataset(raft::device_matrix<MathT, uint32_t, raft::row_major>&& vq_code_book,
raft::device_matrix<MathT, uint32_t, raft::row_major>&& pq_code_book,
raft::device_matrix<uint8_t, IdxT, raft::row_major>&& data)
vpq_dataset(raft::device_matrix<math_type, uint32_t, raft::row_major>&& vq_code_book,
raft::device_matrix<math_type, uint32_t, raft::row_major>&& pq_code_book,
raft::device_matrix<uint8_t, index_type, raft::row_major>&& data)
: vq_code_book{std::move(vq_code_book)},
pq_code_book{std::move(pq_code_book)},
data{std::move(data)}
{
}

[[nodiscard]] auto n_rows() const noexcept -> IdxT final { return data.extent(0); }
[[nodiscard]] auto n_rows() const noexcept -> index_type final { return data.extent(0); }
[[nodiscard]] auto dim() const noexcept -> uint32_t final { return vq_code_book.extent(1); }
[[nodiscard]] auto is_owning() const noexcept -> bool final { return true; }

@@ -354,6 +372,15 @@ struct vpq_dataset : public dataset<IdxT> {
}
};

template <typename DatasetT>
struct is_vpq_dataset : std::false_type {};

template <typename MathT, typename IdxT>
struct is_vpq_dataset<vpq_dataset<MathT, IdxT>> : std::true_type {};

template <typename DatasetT>
inline constexpr bool is_vpq_dataset_v = is_vpq_dataset<DatasetT>::value;

namespace filtering {

/* A filter that filters nothing. This is the default behavior. */
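
The new is_strided_dataset / is_vpq_dataset traits (and their _v shortcuts) let calling code pick a code path at compile time from the concrete dataset type. A minimal consumption sketch, assuming the cuvs::neighbors types above are in scope — the select_descriptor function itself is hypothetical and not part of this diff:

template <typename DatasetT>
void select_descriptor(const DatasetT& dataset)
{
  if constexpr (is_strided_dataset_v<DatasetT>) {
    // Plain (possibly padded) row-major data: take the standard distance path.
  } else if constexpr (is_vpq_dataset_v<DatasetT>) {
    // VQ + PQ compressed data: take the vpq distance path.
  } else {
    // Fails to compile only when instantiated with an unsupported dataset type.
    static_assert(is_strided_dataset_v<DatasetT> || is_vpq_dataset_v<DatasetT>,
                  "unsupported dataset type");
  }
}
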
16 changes: 8 additions & 8 deletions cpp/src/neighbors/detail/ann_utils.cuh
@@ -224,7 +224,7 @@ inline void memzero(T* ptr, IdxT n_elems, rmm::cuda_stream_view stream)
}

template <typename T, typename IdxT>
RAFT_KERNEL outer_add_kernel(const T* a, IdxT len_a, const T* b, IdxT len_b, T* c)
static __global__ void outer_add_kernel(const T* a, IdxT len_a, const T* b, IdxT len_b, T* c)
{
IdxT gid = threadIdx.x + blockDim.x * static_cast<IdxT>(blockIdx.x);
IdxT i = gid / len_b;
@@ -234,12 +234,12 @@ RAFT_KERNEL outer_add_kernel(const T* a, IdxT len_a, const T* b, IdxT len_b, T*
}

template <typename T, typename IdxT>
RAFT_KERNEL block_copy_kernel(const IdxT* in_offsets,
const IdxT* out_offsets,
IdxT n_blocks,
const T* in_data,
T* out_data,
IdxT n_mult)
static __global__ void block_copy_kernel(const IdxT* in_offsets,
const IdxT* out_offsets,
IdxT n_blocks,
const T* in_data,
T* out_data,
IdxT n_mult)
{
IdxT i = static_cast<IdxT>(blockDim.x) * static_cast<IdxT>(blockIdx.x) + threadIdx.x;
// find the source offset using the binary search.
@@ -317,7 +317,7 @@ void outer_add(const T* a, IdxT len_a, const T* b, IdxT len_b, T* c, rmm::cuda_s
}

template <typename T, typename S, typename IdxT, typename LabelT>
RAFT_KERNEL copy_selected_kernel(
static __global__ void copy_selected_kernel(
IdxT n_rows, IdxT n_cols, const S* src, const LabelT* row_ids, IdxT ld_src, T* dst, IdxT ld_dst)
{
IdxT gid = threadIdx.x + blockDim.x * static_cast<IdxT>(blockIdx.x);
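
These helper kernels are now declared static __global__ void instead of via the RAFT_KERNEL macro. The static keyword gives a kernel internal linkage, so the symbol stays private to its translation unit: it cannot clash with an identically named kernel elsewhere and is not exported from the library. A minimal illustration of the linkage difference — the kernel names below are made up:

// External linkage: the symbol is visible to other translation units and shows
// up in the binary's symbol table.
__global__ void visible_kernel(float* out) { out[threadIdx.x] = 0.0f; }

// Internal linkage: the symbol stays private to this translation unit.
static __global__ void hidden_kernel(float* out) { out[threadIdx.x] = 1.0f; }
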
37 changes: 28 additions & 9 deletions cpp/src/neighbors/detail/cagra/bitonic.hpp
@@ -26,7 +26,7 @@ namespace bitonic {
namespace detail {

template <class K, class V>
_RAFT_DEVICE inline void swap_if_needed(K& k0, V& v0, K& k1, V& v1, const bool asc)
RAFT_DEVICE_INLINE_FUNCTION void swap_if_needed(K& k0, V& v0, K& k1, V& v1, const bool asc)
{
if ((k0 != k1) && ((k0 < k1) != asc)) {
const auto tmp_k = k0;
@@ -39,7 +39,10 @@ _RAFT_DEVICE inline void swap_if_needed(K& k0, V& v0, K& k1, V& v1, const bool a
}

template <class K, class V>
_RAFT_DEVICE inline void swap_if_needed(K& k0, V& v0, const unsigned lane_offset, const bool asc)
RAFT_DEVICE_INLINE_FUNCTION void swap_if_needed(K& k0,
V& v0,
const unsigned lane_offset,
const bool asc)
{
auto k1 = __shfl_xor_sync(~0u, k0, lane_offset);
auto v1 = __shfl_xor_sync(~0u, v0, lane_offset);
@@ -51,7 +54,10 @@ _RAFT_DEVICE inline void swap_if_needed(K& k0, V& v0, const unsigned lane_offset

template <class K, class V, unsigned N, unsigned warp_size = 32>
struct warp_merge_core {
_RAFT_DEVICE inline void operator()(K k[N], V v[N], const std::uint32_t range, const bool asc)
RAFT_DEVICE_INLINE_FUNCTION void operator()(K k[N],
V v[N],
const std::uint32_t range,
const bool asc)
{
const auto lane_id = threadIdx.x % warp_size;

@@ -93,7 +99,10 @@ struct warp_merge_core {

template <class K, class V, unsigned warp_size>
struct warp_merge_core<K, V, 6, warp_size> {
_RAFT_DEVICE inline void operator()(K k[6], V v[6], const std::uint32_t range, const bool asc)
RAFT_DEVICE_INLINE_FUNCTION void operator()(K k[6],
V v[6],
const std::uint32_t range,
const bool asc)
{
constexpr unsigned N = 6;
const auto lane_id = threadIdx.x % warp_size;
@@ -141,7 +150,10 @@ struct warp_merge_core<K, V, 6, warp_size> {

template <class K, class V, unsigned warp_size>
struct warp_merge_core<K, V, 3, warp_size> {
_RAFT_DEVICE inline void operator()(K k[3], V v[3], const std::uint32_t range, const bool asc)
RAFT_DEVICE_INLINE_FUNCTION void operator()(K k[3],
V v[3],
const std::uint32_t range,
const bool asc)
{
constexpr unsigned N = 3;
const auto lane_id = threadIdx.x % warp_size;
@@ -171,7 +183,10 @@ struct warp_merge_core<K, V, 3, warp_size> {

template <class K, class V, unsigned warp_size>
struct warp_merge_core<K, V, 2, warp_size> {
_RAFT_DEVICE inline void operator()(K k[2], V v[2], const std::uint32_t range, const bool asc)
RAFT_DEVICE_INLINE_FUNCTION void operator()(K k[2],
V v[2],
const std::uint32_t range,
const bool asc)
{
constexpr unsigned N = 2;
const auto lane_id = threadIdx.x % warp_size;
@@ -197,7 +212,10 @@ struct warp_merge_core<K, V, 2, warp_size> {

template <class K, class V, unsigned warp_size>
struct warp_merge_core<K, V, 1, warp_size> {
_RAFT_DEVICE inline void operator()(K k[1], V v[1], const std::uint32_t range, const bool asc)
RAFT_DEVICE_INLINE_FUNCTION void operator()(K k[1],
V v[1],
const std::uint32_t range,
const bool asc)
{
const auto lane_id = threadIdx.x % warp_size;
const std::uint32_t b = range;
@@ -211,14 +229,15 @@ struct warp_merge_core<K, V, 1, warp_size> {
} // namespace detail

template <class K, class V, unsigned N, unsigned warp_size = 32>
__device__ void warp_merge(K k[N], V v[N], unsigned range, const bool asc = true)
RAFT_DEVICE_INLINE_FUNCTION void warp_merge(K k[N], V v[N], unsigned range, const bool asc = true)
{
detail::warp_merge_core<K, V, N, warp_size>{}(k, v, range, asc);
}

template <class K, class V, unsigned N, unsigned warp_size = 32>
__device__ void warp_sort(K k[N], V v[N], const bool asc = true)
RAFT_DEVICE_INLINE_FUNCTION void warp_sort(K k[N], V v[N], const bool asc = true)
{
#pragma unroll
for (std::uint32_t range = 1; range <= warp_size; range <<= 1) {
warp_merge<K, V, N, warp_size>(k, v, range, asc);
}
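
For context, warp_sort and warp_merge operate on small key/value arrays that each thread keeps in registers, with the whole warp cooperating through shuffles. A rough usage sketch, assuming a single 32-thread warp, the enclosing namespaces omitted, and a hypothetical demo kernel; the exact lane-to-rank layout of the sorted output follows this header's bitonic convention rather than a simple blocked order:

__global__ void warp_sort_demo(float* keys, unsigned* values)
{
  constexpr unsigned N = 2;  // each thread holds N key/value pairs in registers
  float k[N];
  unsigned v[N];
  for (unsigned i = 0; i < N; ++i) {
    k[i] = keys[threadIdx.x * N + i];
    v[i] = values[threadIdx.x * N + i];
  }
  // Cooperative bitonic sort across the warp: after the call the 64 pairs are
  // globally ordered by key (ascending), distributed over lanes and registers.
  bitonic::warp_sort<float, unsigned, N>(k, v, /*asc=*/true);
  for (unsigned i = 0; i < N; ++i) {
    keys[threadIdx.x * N + i] = k[i];
    values[threadIdx.x * N + i] = v[i];
  }
}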