diff --git a/cpp/bench/spatial/knn.cuh b/cpp/bench/spatial/knn.cuh index 921932c791..bb01320cdf 100644 --- a/cpp/bench/spatial/knn.cuh +++ b/cpp/bench/spatial/knn.cuh @@ -30,6 +30,9 @@ #if defined RAFT_NN_COMPILED #include +#if defined RAFT_DISTANCE_COMPILED +#include +#endif #endif #include diff --git a/cpp/include/raft/distance/detail/fused_l2_nn.cuh b/cpp/include/raft/distance/detail/fused_l2_nn.cuh index 6bd622853d..8aae7d40f4 100644 --- a/cpp/include/raft/distance/detail/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/detail/fused_l2_nn.cuh @@ -59,7 +59,7 @@ struct MinAndDistanceReduceOpImpl { DI void init(DataT* out, DataT maxVal) { *out = maxVal; } DI void init(KVP* out, DataT maxVal) { - out->key = -1; + out->key = 0; out->value = maxVal; } }; @@ -150,7 +150,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, KVPair val[P::AccRowsPerTh]; #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; + val[i] = {0, maxVal}; } // epilogue operation lambda for final value calculation @@ -222,7 +222,7 @@ __global__ __launch_bounds__(P::Nthreads, 2) void fusedL2NNkernel(OutT* min, // reset the val array. #pragma unroll for (int i = 0; i < P::AccRowsPerTh; ++i) { - val[i] = {-1, maxVal}; + val[i] = {0, maxVal}; } }; diff --git a/cpp/include/raft/distance/fused_l2_nn.cuh b/cpp/include/raft/distance/fused_l2_nn.cuh index ed781f1d18..2915bce360 100644 --- a/cpp/include/raft/distance/fused_l2_nn.cuh +++ b/cpp/include/raft/distance/fused_l2_nn.cuh @@ -101,10 +101,6 @@ void fusedL2NN(OutT* min, bool initOutBuffer, cudaStream_t stream) { - // Assigning -1 to unsigned integers results in a compiler error. - // Enforce a signed IdxT here with a clear error message. - static_assert(std::is_signed_v, "fusedL2NN only supports signed index types."); - // When k is smaller than 32, the Policy4x4 results in redundant calculations // as it uses tiles that have k=32. Therefore, use a "skinny" policy instead // that uses tiles with a smaller value of k. 
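These fused-L2-NN hunks change the key sentinel from -1 to 0 and drop the signed-index static_assert, so the reduction can now be instantiated with unsigned index types, as the k-means code further down does. Below is a minimal usage sketch, not part of the patch: the function and buffer names (nearest_center, dataset, centers) are made up, but the call sequence mirrors the one introduced in ann_kmeans_balanced.cuh.

#include <raft/distance/fused_l2_nn.cuh>
#include <raft/linalg/norm.cuh>

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>

#include <cub/cub.cuh>

// Find the nearest center (key) and its squared L2 distance (value) for every row.
// IdxT may be unsigned (e.g. uint32_t) now that the sentinel key is 0 rather than -1.
template <typename IdxT>
void nearest_center(const float* dataset,                 // [n_rows, dim], row-major
                    const float* centers,                 // [n_clusters, dim], row-major
                    IdxT n_rows,
                    IdxT n_clusters,
                    IdxT dim,
                    cub::KeyValuePair<IdxT, float>* out,  // [n_rows]
                    rmm::cuda_stream_view stream)
{
  rmm::device_uvector<float> dataset_norm(n_rows, stream);
  rmm::device_uvector<float> centers_norm(n_clusters, stream);
  rmm::device_uvector<int> workspace(n_rows, stream);  // one int-sized mutex per output row

  // Squared L2 row norms, required by the expanded-L2 path.
  raft::linalg::rowNorm(
    dataset_norm.data(), dataset, dim, n_rows, raft::linalg::L2Norm, true, stream);
  raft::linalg::rowNorm(
    centers_norm.data(), centers, dim, n_clusters, raft::linalg::L2Norm, true, stream);

  raft::distance::fusedL2NNMinReduce<float, cub::KeyValuePair<IdxT, float>, IdxT>(
    out,
    dataset,
    centers,
    dataset_norm.data(),
    centers_norm.data(),
    n_rows,
    n_clusters,
    dim,
    workspace.data(),
    /* sqrt          = */ false,
    /* initOutBuffer = */ true,
    stream);
}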
diff --git a/cpp/include/raft/linalg/coalesced_reduction.cuh b/cpp/include/raft/linalg/coalesced_reduction.cuh index 518667c5f1..6ef0d52e62 100644 --- a/cpp/include/raft/linalg/coalesced_reduction.cuh +++ b/cpp/include/raft/linalg/coalesced_reduction.cuh @@ -70,7 +70,8 @@ void coalescedReduction(OutType* dots, ReduceLambda reduce_op = raft::Sum(), FinalLambda final_op = raft::Nop()) { - detail::coalescedReduction(dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + detail::coalescedReduction( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } /** diff --git a/cpp/include/raft/linalg/detail/norm.cuh b/cpp/include/raft/linalg/detail/norm.cuh index 03d03497e9..a0b557211c 100644 --- a/cpp/include/raft/linalg/detail/norm.cuh +++ b/cpp/include/raft/linalg/detail/norm.cuh @@ -37,32 +37,32 @@ void rowNormCaller(Type* dots, { switch (type) { case L1Norm: - raft::linalg::reduce(dots, - data, - D, - N, - (Type)0, - rowMajor, - true, - stream, - false, - raft::L1Op(), - raft::Sum(), - fin_op); + raft::linalg::reduce(dots, + data, + D, + N, + (Type)0, + rowMajor, + true, + stream, + false, + raft::L1Op(), + raft::Sum(), + fin_op); break; case L2Norm: - raft::linalg::reduce(dots, - data, - D, - N, - (Type)0, - rowMajor, - true, - stream, - false, - raft::L2Op(), - raft::Sum(), - fin_op); + raft::linalg::reduce(dots, + data, + D, + N, + (Type)0, + rowMajor, + true, + stream, + false, + raft::L2Op(), + raft::Sum(), + fin_op); break; default: ASSERT(false, "Invalid norm type passed! [%d]", type); }; @@ -80,32 +80,32 @@ void colNormCaller(Type* dots, { switch (type) { case L1Norm: - raft::linalg::reduce(dots, - data, - D, - N, - (Type)0, - rowMajor, - false, - stream, - false, - raft::L1Op(), - raft::Sum(), - fin_op); + raft::linalg::reduce(dots, + data, + D, + N, + (Type)0, + rowMajor, + false, + stream, + false, + raft::L1Op(), + raft::Sum(), + fin_op); break; case L2Norm: - raft::linalg::reduce(dots, - data, - D, - N, - (Type)0, - rowMajor, - false, - stream, - false, - raft::L2Op(), - raft::Sum(), - fin_op); + raft::linalg::reduce(dots, + data, + D, + N, + (Type)0, + rowMajor, + false, + stream, + false, + raft::L2Op(), + raft::Sum(), + fin_op); break; default: ASSERT(false, "Invalid norm type passed! 
[%d]", type); }; diff --git a/cpp/include/raft/linalg/detail/reduce.cuh b/cpp/include/raft/linalg/detail/reduce.cuh index cc86716a8d..3022973b43 100644 --- a/cpp/include/raft/linalg/detail/reduce.cuh +++ b/cpp/include/raft/linalg/detail/reduce.cuh @@ -44,16 +44,16 @@ void reduce(OutType* dots, FinalLambda final_op = raft::Nop()) { if (rowMajor && alongRows) { - raft::linalg::coalescedReduction( + raft::linalg::coalescedReduction( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else if (rowMajor && !alongRows) { - raft::linalg::stridedReduction( + raft::linalg::stridedReduction( dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } else if (!rowMajor && alongRows) { - raft::linalg::stridedReduction( + raft::linalg::stridedReduction( dots, data, N, D, init, stream, inplace, main_op, reduce_op, final_op); } else { - raft::linalg::coalescedReduction( + raft::linalg::coalescedReduction( dots, data, N, D, init, stream, inplace, main_op, reduce_op, final_op); } } diff --git a/cpp/include/raft/linalg/reduce.cuh b/cpp/include/raft/linalg/reduce.cuh index 9c349ccb4f..9b3f4ee347 100644 --- a/cpp/include/raft/linalg/reduce.cuh +++ b/cpp/include/raft/linalg/reduce.cuh @@ -75,7 +75,7 @@ void reduce(OutType* dots, ReduceLambda reduce_op = raft::Sum(), FinalLambda final_op = raft::Nop()) { - detail::reduce( + detail::reduce( dots, data, D, N, init, rowMajor, alongRows, stream, inplace, main_op, reduce_op, final_op); } diff --git a/cpp/include/raft/linalg/strided_reduction.cuh b/cpp/include/raft/linalg/strided_reduction.cuh index 6927269821..9147692c03 100644 --- a/cpp/include/raft/linalg/strided_reduction.cuh +++ b/cpp/include/raft/linalg/strided_reduction.cuh @@ -71,7 +71,8 @@ void stridedReduction(OutType* dots, ReduceLambda reduce_op = raft::Sum(), FinalLambda final_op = raft::Nop()) { - detail::stridedReduction(dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); + detail::stridedReduction( + dots, data, D, N, init, stream, inplace, main_op, reduce_op, final_op); } /** diff --git a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh index f64c5549a4..6d3289e14c 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_kmeans_balanced.cuh @@ -18,12 +18,17 @@ #include "ann_utils.cuh" +#include +#include + #include #include #include #include #include +#include #include +#include #include #include #include @@ -45,67 +50,108 @@ constexpr static inline const float kAdjustCentersWeight = 7.0f; * NB: no minibatch splitting is done here, it may require large amount of temporary memory (n_rows * * n_cluster * sizeof(float)). 
* + * @tparam IdxT index type + * @tparam LabelT label type + * * @param handle * @param[in] centers a pointer to the row-major matrix of cluster centers [n_clusters, dim] * @param n_clusters number of clusters/centers * @param dim dimensionality of the data * @param[in] dataset a pointer to the data [n_rows, dim] + * @param[in] dataset_norm pointer to the precomputed norm (for L2 metrics only) [n_rows] * @param n_rows number samples in the `dataset` * @param[out] labels output predictions [n_rows] * @param metric * @param stream * @param mr (optional) memory resource to use for temporary allocations */ +template inline void predict_float_core(const handle_t& handle, const float* centers, uint32_t n_clusters, uint32_t dim, const float* dataset, - size_t n_rows, - uint32_t* labels, + const float* dataset_norm, + IdxT n_rows, + LabelT* labels, raft::distance::DistanceType metric, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* mr) { - rmm::device_uvector distances(n_rows * n_clusters, stream, mr); - - float alpha; - float beta; switch (metric) { - case raft::distance::DistanceType::InnerProduct: { - alpha = -1.0; - beta = 0.0; - } break; case raft::distance::DistanceType::L2Expanded: - case raft::distance::DistanceType::L2Unexpanded: { - rmm::device_uvector sqsum_centers(n_clusters, stream, mr); - rmm::device_uvector sqsum_data(n_rows, stream, mr); - utils::dots_along_rows(n_clusters, dim, centers, sqsum_centers.data(), stream); - utils::dots_along_rows(n_rows, dim, dataset, sqsum_data.data(), stream); - utils::outer_add( - sqsum_data.data(), n_rows, sqsum_centers.data(), n_clusters, distances.data(), stream); - alpha = -2.0; - beta = 1.0; - } break; - // NB: update the description of `knn::ivf_flat::build` when adding here a new metric. - default: RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); + case raft::distance::DistanceType::L2SqrtExpanded: { + auto workspace = raft::make_device_mdarray( + handle, mr, make_extents((sizeof(int)) * n_rows)); + + auto minClusterAndDistance = raft::make_device_mdarray, IdxT>( + handle, mr, make_extents(n_rows)); + cub::KeyValuePair initial_value(0, std::numeric_limits::max()); + thrust::fill(handle.get_thrust_policy(), + minClusterAndDistance.data_handle(), + minClusterAndDistance.data_handle() + minClusterAndDistance.size(), + initial_value); + + auto centroidsNorm = + raft::make_device_mdarray(handle, mr, make_extents(n_clusters)); + raft::linalg::rowNorm( + centroidsNorm.data_handle(), centers, dim, n_clusters, raft::linalg::L2Norm, true, stream); + + raft::distance::fusedL2NNMinReduce, IdxT>( + minClusterAndDistance.data_handle(), + dataset, + centers, + dataset_norm, + centroidsNorm.data_handle(), + n_rows, + n_clusters, + dim, + (void*)workspace.data_handle(), + (metric == raft::distance::DistanceType::L2Expanded) ? false : true, + false, + stream); + + // todo(lsugy): use KVP + iterator in caller. 
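// Note: fusedL2NN produced one cub::KeyValuePair<IdxT, float> per row (index of the
// nearest center and its distance); only the key is needed as a label, so the
// transform below casts kvp.key to LabelT and discards the distance.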
+ // Copy keys to output labels + thrust::transform(handle.get_thrust_policy(), + minClusterAndDistance.data_handle(), + minClusterAndDistance.data_handle() + n_rows, + labels, + [=] __device__(cub::KeyValuePair kvp) { + return static_cast(kvp.key); + }); + break; + } + case raft::distance::DistanceType::InnerProduct: { + // TODO: pass buffer + rmm::device_uvector distances(n_rows * n_clusters, stream, mr); + + float alpha = -1.0; + float beta = 0.0; + + linalg::gemm(handle, + true, + false, + n_clusters, + n_rows, + dim, + &alpha, + centers, + dim, + dataset, + dim, + &beta, + distances.data(), + n_clusters, + stream); + utils::argmin_along_rows( + n_rows, static_cast(n_clusters), distances.data(), labels, stream); + break; + } + default: { + RAFT_FAIL("The chosen distance metric is not supported (%d)", int(metric)); + } } - linalg::gemm(handle, - true, - false, - n_clusters, - n_rows, - dim, - &alpha, - centers, - dim, - dataset, - dim, - &beta, - distances.data(), - n_clusters, - stream); - utils::argmin_along_rows(n_rows, n_clusters, distances.data(), labels, stream); } /** @@ -118,16 +164,32 @@ inline void predict_float_core(const handle_t& handle, * @param n_rows dataset size * @return a suggested minibatch size */ -constexpr inline auto calc_minibatch_size(uint32_t n_clusters, size_t n_rows) -> uint32_t +template +constexpr inline auto calc_minibatch_size(uint32_t n_clusters, + IdxT n_rows, + uint32_t dim, + raft::distance::DistanceType metric, + bool is_float) -> IdxT { - n_clusters = std::max(1, n_clusters); - uint32_t minibatch_size = (1 << 20); - if (minibatch_size > (1 << 28) / n_clusters) { - minibatch_size = (1 << 28) / n_clusters; - minibatch_size += 32; - minibatch_size -= minibatch_size % 64; + n_clusters = std::max(1, n_clusters); + + // Estimate memory needs per row (i.e element of the batch). + IdxT mem_per_row = 0; + /* fusedL2NN only needs one integer per row for a mutex. + * Other metrics require storing a distance matrix. */ + if (metric != raft::distance::DistanceType::L2Expanded && + metric != raft::distance::DistanceType::L2SqrtExpanded) { + mem_per_row += sizeof(float) * n_clusters; + } else { + mem_per_row += sizeof(int); } - minibatch_size = uint32_t(std::min(minibatch_size, n_rows)); + // If we need to convert to float, space required for the converted batch. + if (!is_float) { mem_per_row += sizeof(float) * dim; } + + // Heuristic: calculate the minibatch size in order to use at most 1GB of memory. + IdxT minibatch_size = (1 << 30) / mem_per_row; + minibatch_size = 64 * ceildiv(minibatch_size, (IdxT)64); + minibatch_size = std::min(minibatch_size, n_rows); return minibatch_size; } @@ -154,7 +216,9 @@ constexpr inline auto calc_minibatch_size(uint32_t n_clusters, size_t n_rows) -> * 1. All pointers are on the device. * 2. All pointers are on the host, but `centers` and `cluster_sizes` are accessible from GPU. * - * @tparam T element type + * @tparam T element type + * @tparam IdxT index type + * @tparam LabelT label type * * @param[inout] centers pointer to the output [n_clusters, dim] * @param[inout] cluster_sizes number of rows in each cluster [n_clusters] @@ -168,14 +232,14 @@ constexpr inline auto calc_minibatch_size(uint32_t n_clusters, size_t n_rows) -> * the weighted average principle. 
* @param stream */ -template +template void calc_centers_and_sizes(float* centers, uint32_t* cluster_sizes, uint32_t n_clusters, uint32_t dim, const T* dataset, - size_t n_rows, - const uint32_t* labels, + IdxT n_rows, + const LabelT* labels, bool reset_counters, rmm::cuda_stream_view stream) { @@ -201,76 +265,125 @@ void calc_centers_and_sizes(float* centers, stream); } +/** Computes the L2 norm of the dataset, converting to float if necessary */ +template +void compute_norm(float* dataset_norm, + const T* dataset, + IdxT dim, + IdxT n_rows, + rmm::cuda_stream_view stream, + rmm::mr::device_memory_resource* mr = nullptr) +{ + common::nvtx::range fun_scope("kmeans::compute_norm"); + if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } + rmm::device_uvector dataset_float(0, stream, mr); + + const float* dataset_ptr = nullptr; + + if (std::is_same_v) { + dataset_ptr = reinterpret_cast(dataset); + } else { + dataset_float.resize(n_rows * dim, stream); + + linalg::unaryOp(dataset_float.data(), dataset, n_rows * dim, utils::mapping{}, stream); + + dataset_ptr = (const float*)dataset_float.data(); + } + + raft::linalg::rowNorm( + dataset_norm, dataset_ptr, dim, n_rows, raft::linalg::L2Norm, true, stream); +} + /** * @brief Predict labels for the dataset. * - * @tparam T element type + * @tparam T element type + * @tparam IdxT index type + * @tparam LabelT label type * * @param handle * @param[in] centers a pointer to the row-major matrix of cluster centers [n_clusters, dim] * @param n_clusters number of clusters/centers * @param dim dimensionality of the data * @param[in] dataset a pointer to the data [n_rows, dim] + * @param[in] dataset_norm pointer to the precomputed norm (for L2 metrics only) [n_rows] * @param n_rows number samples in the `dataset` * @param[out] labels output predictions [n_rows] * @param metric * @param stream * @param mr (optional) memory resource to use for temporary allocations */ - -template +template void predict(const handle_t& handle, const float* centers, uint32_t n_clusters, uint32_t dim, const T* dataset, - size_t n_rows, - uint32_t* labels, + IdxT n_rows, + LabelT* labels, raft::distance::DistanceType metric, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* mr = nullptr) + rmm::mr::device_memory_resource* mr = nullptr, + const float* dataset_norm = nullptr) { common::nvtx::range fun_scope( - "kmeans::predict(%zu, %u)", n_rows, n_clusters); + "kmeans::predict(%zu, %u)", static_cast(n_rows), n_clusters); if (mr == nullptr) { mr = rmm::mr::get_current_device_resource(); } - const uint32_t max_minibatch_size = calc_minibatch_size(n_clusters, n_rows); + IdxT max_minibatch_size = + calc_minibatch_size(n_clusters, n_rows, dim, metric, std::is_same_v); rmm::device_uvector cur_dataset( std::is_same_v ? 0 : max_minibatch_size * dim, stream, mr); - auto cur_dataset_ptr = cur_dataset.data(); - for (size_t offset = 0; offset < n_rows; offset += max_minibatch_size) { - auto minibatch_size = std::min(max_minibatch_size, n_rows - offset); + bool need_compute_norm = + dataset_norm == nullptr && (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded); + rmm::device_uvector cur_dataset_norm( + need_compute_norm ? 
max_minibatch_size : 0, stream, mr); + const float* dataset_norm_ptr = nullptr; + auto cur_dataset_ptr = cur_dataset.data(); + for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { + IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); if constexpr (std::is_same_v) { cur_dataset_ptr = const_cast(dataset + offset * dim); } else { linalg::unaryOp(cur_dataset_ptr, dataset + offset * dim, - minibatch_size * dim, + (IdxT)(minibatch_size * dim), utils::mapping{}, stream); } - predict_float_core(handle, - centers, - n_clusters, - dim, - cur_dataset_ptr, - minibatch_size, - labels + offset, - metric, - stream, - mr); + // Compute the norm now if it hasn't been pre-computed. + if (need_compute_norm) { + compute_norm( + cur_dataset_norm.data(), cur_dataset_ptr, (IdxT)dim, (IdxT)minibatch_size, stream, mr); + dataset_norm_ptr = cur_dataset_norm.data(); + } else if (dataset_norm != nullptr) { + dataset_norm_ptr = dataset_norm + offset; + } + + predict_float_core(handle, + centers, + n_clusters, + dim, + cur_dataset_ptr, + dataset_norm_ptr, + minibatch_size, + labels + offset, + metric, + stream, + mr); } } -template +template __global__ void __launch_bounds__((WarpSize * BlockDimY)) adjust_centers_kernel(float* centers, // [n_clusters, dim] uint32_t n_clusters, uint32_t dim, const T* dataset, // [n_rows, dim] - size_t n_rows, - const uint32_t* labels, // [n_rows] + IdxT n_rows, + const LabelT* labels, // [n_rows] const uint32_t* cluster_sizes, // [n_clusters] float threshold, uint32_t average, @@ -284,11 +397,11 @@ __global__ void __launch_bounds__((WarpSize * BlockDimY)) if (csize > static_cast(average * threshold)) return; // choose a "random" i that belongs to a rather large cluster - size_t i; + IdxT i; uint32_t j = laneId(); if (j == 0) { do { - auto old = static_cast(atomicAdd(count, 1)); + auto old = static_cast(atomicAdd(count, 1)); i = (seed * (old + 1)) % n_rows; } while (cluster_sizes[labels[i]] < average); } @@ -296,7 +409,7 @@ __global__ void __launch_bounds__((WarpSize * BlockDimY)) // Adjust the center of the selected smaller cluster to gravitate towards // a sample from the selected larger cluster. - const size_t li = labels[i]; + const IdxT li = static_cast(labels[i]); // Weight of the current center for the weighted average. // We dump it for anomalously small clusters, but keep constant overwise. const float wc = csize > kAdjustCentersWeight ? kAdjustCentersWeight : float(csize); @@ -305,7 +418,7 @@ __global__ void __launch_bounds__((WarpSize * BlockDimY)) for (; j < dim; j += WarpSize) { float val = 0; val += wc * centers[j + dim * li]; - val += wd * utils::mapping{}(dataset[j + size_t(dim) * i]); + val += wd * utils::mapping{}(dataset[j + static_cast(dim) * i]); val /= wc + wd; centers[j + dim * l] = val; } @@ -338,30 +451,30 @@ __global__ void __launch_bounds__((WarpSize * BlockDimY)) * * @return whether any of the centers has been updated (and thus, `labels` need to be recalculated). 
*/ -template +template auto adjust_centers(float* centers, uint32_t n_clusters, uint32_t dim, const T* dataset, - size_t n_rows, - const uint32_t* labels, + IdxT n_rows, + const LabelT* labels, const uint32_t* cluster_sizes, float threshold, rmm::cuda_stream_view stream, rmm::mr::device_memory_resource* device_memory) -> bool { common::nvtx::range fun_scope( - "kmeans::adjust_centers(%zu, %u)", n_rows, n_clusters); + "kmeans::adjust_centers(%zu, %u)", static_cast(n_rows), n_clusters); if (n_clusters == 0) { return false; } constexpr static std::array kPrimes{29, 71, 113, 173, 229, 281, 349, 409, 463, 541, 601, 659, 733, 809, 863, 941, 1013, 1069, 1151, 1223, 1291, 1373, 1451, 1511, 1583, 1657, 1733, 1811, 1889, 1987, 2053, 2129, 2213, 2287, 2357, 2423, 2531, 2617, 2687, 2741}; - static size_t i = 0; - static size_t i_primes = 0; + static IdxT i = 0; + static IdxT i_primes = 0; bool adjusted = false; - uint32_t average = static_cast(n_rows / size_t(n_clusters)); + uint32_t average = static_cast(n_rows / static_cast(n_clusters)); uint32_t ofst; do { i_primes = (i_primes + 1) % kPrimes.size(); @@ -400,7 +513,7 @@ auto adjust_centers(float* centers, } while (cluster_sizes[labels[i]] < average); // Adjust the center of the selected smaller cluster to gravitate towards // a sample from the selected larger cluster. - const size_t li = labels[i]; + const IdxT li = static_cast(labels[i]); // Weight of the current center for the weighted average. // We dump it for anomalously small clusters, but keep constant overwise. const float wc = std::min(csize, kAdjustCentersWeight); @@ -409,7 +522,7 @@ auto adjust_centers(float* centers, for (uint32_t j = 0; j < dim; j++) { float val = 0; val += wc * centers[j + dim * li]; - val += wd * utils::mapping{}(dataset[j + size_t(dim) * i]); + val += wd * utils::mapping{}(dataset[j + static_cast(dim) * i]); val /= wc + wd; centers[j + dim * l] = val; } @@ -429,12 +542,15 @@ auto adjust_centers(float* centers, * Thus, this function can be used for fine-tuning existing clusters; * to train from scratch, use `build_clusters` function below. 
* - * @tparam T element type + * @tparam T element type + * @tparam IdxT index type + * @tparam LabelT label type * * @param handle * @param n_iters the requested number of iteration * @param dim the dimensionality of the dataset * @param[in] dataset a pointer to a managed row-major array [n_rows, dim] + * @param[in] dataset_norm pointer to the precomputed norm (for L2 metrics only) [n_rows] * @param n_rows the number of rows in the dataset * @param n_cluster the requested number of clusters * @param[inout] cluster_centers a pointer to a managed row-major array [n_clusters, dim] @@ -453,15 +569,16 @@ auto adjust_centers(float* centers, * @param device_memory * a memory resource for device allocations (makes sense to provide a memory pool here) */ -template +template void balancing_em_iters(const handle_t& handle, uint32_t n_iters, uint32_t dim, const T* dataset, - size_t n_rows, + const float* dataset_norm, + IdxT n_rows, uint32_t n_clusters, float* cluster_centers, - uint32_t* cluster_labels, + LabelT* cluster_labels, uint32_t* cluster_sizes, raft::distance::DistanceType metric, uint32_t balancing_pullback, @@ -498,16 +615,17 @@ void balancing_em_iters(const handle_t& handle, default: break; } // E: Expectation step - predict labels - predict(handle, - cluster_centers, - n_clusters, - dim, - dataset, - n_rows, - cluster_labels, - metric, - stream, - device_memory); + predict(handle, + cluster_centers, + n_clusters, + dim, + dataset, + n_rows, + cluster_labels, + metric, + stream, + device_memory, + dataset_norm); // M: Maximization step - calculate optimal cluster centers calc_centers_and_sizes(cluster_centers, cluster_sizes, @@ -522,51 +640,58 @@ void balancing_em_iters(const handle_t& handle, } /** Randomly initialize cluster centers and then call `balancing_em_iters`. */ -template +template void build_clusters(const handle_t& handle, uint32_t n_iters, uint32_t dim, const T* dataset, - size_t n_rows, + IdxT n_rows, uint32_t n_clusters, float* cluster_centers, - uint32_t* cluster_labels, + LabelT* cluster_labels, uint32_t* cluster_sizes, raft::distance::DistanceType metric, rmm::cuda_stream_view stream, - rmm::mr::device_memory_resource* device_memory) + rmm::mr::device_memory_resource* device_memory, + const float* dataset_norm = nullptr) { + RAFT_EXPECTS(static_cast(n_rows) * static_cast(dim) <= + static_cast(std::numeric_limits::max()), + "the chosen index type cannot represent all indices for the given dataset"); + // "randomly initialize labels" - auto f = [n_clusters] __device__(uint32_t * out, size_t i) { - *out = uint32_t(i % size_t(n_clusters)); + auto f = [n_clusters] __device__(LabelT * out, IdxT i) { + *out = LabelT(i % static_cast(n_clusters)); }; - linalg::writeOnlyUnaryOp(cluster_labels, n_rows, f, stream); + linalg::writeOnlyUnaryOp(cluster_labels, n_rows, f, stream); // update centers to match the initialized labels. calc_centers_and_sizes( cluster_centers, cluster_sizes, n_clusters, dim, dataset, n_rows, cluster_labels, true, stream); // run EM - balancing_em_iters(handle, - n_iters, - dim, - dataset, - n_rows, - n_clusters, - cluster_centers, - cluster_labels, - cluster_sizes, - metric, - 2, - 0.25f, - stream, - device_memory); + balancing_em_iters(handle, + n_iters, + dim, + dataset, + dataset_norm, + n_rows, + n_clusters, + cluster_centers, + cluster_labels, + cluster_sizes, + metric, + 2, + 0.25f, + stream, + device_memory); } /** Calculate how many fine clusters should belong to each mesocluster. 
*/ +template inline auto arrange_fine_clusters(uint32_t n_clusters, uint32_t n_mesoclusters, - size_t n_rows, + IdxT n_rows, const uint32_t* mesocluster_sizes) { std::vector fine_clusters_nums(n_mesoclusters); @@ -578,8 +703,8 @@ inline auto arrange_fine_clusters(uint32_t n_clusters, for (uint32_t i = 0; i < n_mesoclusters; i++) { n_nonempty_ms_rem += mesocluster_sizes[i] > 0 ? 1 : 0; } - size_t n_rows_rem = n_rows; - size_t mesocluster_size_sum = 0; + IdxT n_rows_rem = n_rows; + IdxT mesocluster_size_sum = 0; uint32_t mesocluster_size_max = 0; uint32_t fine_clusters_nums_max = 0; for (uint32_t i = 0; i < n_mesoclusters; i++) { @@ -609,14 +734,12 @@ inline auto arrange_fine_clusters(uint32_t n_clusters, RAFT_EXPECTS(mesocluster_size_sum == n_rows, "mesocluster sizes do not add up (%zu) to the total trainset size (%zu)", - mesocluster_size_sum, - n_rows); + static_cast(mesocluster_size_sum), + static_cast(n_rows)); RAFT_EXPECTS(fine_clusters_csum[n_mesoclusters] == n_clusters, "fine cluster numbers do not add up (%u) to the total number of clusters (%u)", fine_clusters_csum[n_mesoclusters], - n_clusters - - ); + n_clusters); return std::make_tuple(mesocluster_size_max, fine_clusters_nums_max, @@ -637,13 +760,14 @@ inline auto arrange_fine_clusters(uint32_t n_clusters, * this function returns the total number of fine clusters, which can be checked to be * the same as the requested number of clusters. */ -template +template auto build_fine_clusters(const handle_t& handle, uint32_t n_iters, uint32_t dim, const T* dataset_mptr, - const uint32_t* labels_mptr, - size_t n_rows, + const float* dataset_norm_mptr, + const LabelT* labels_mptr, + IdxT n_rows, const uint32_t* fine_clusters_nums, const uint32_t* fine_clusters_csum, const uint32_t* mesocluster_sizes, @@ -656,13 +780,15 @@ auto build_fine_clusters(const handle_t& handle, rmm::mr::device_memory_resource* device_memory, rmm::cuda_stream_view stream) -> uint32_t { - rmm::device_uvector mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory); + rmm::device_uvector mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory); rmm::device_uvector mc_trainset_buf(mesocluster_size_max * dim, stream, device_memory); - auto mc_trainset_ids = mc_trainset_ids_buf.data(); - auto mc_trainset = mc_trainset_buf.data(); + rmm::device_uvector mc_trainset_norm_buf(mesocluster_size_max, stream, device_memory); + auto mc_trainset_ids = mc_trainset_ids_buf.data(); + auto mc_trainset = mc_trainset_buf.data(); + auto mc_trainset_norm = mc_trainset_norm_buf.data(); // label (cluster ID) of each vector - rmm::device_uvector mc_trainset_labels(mesocluster_size_max, stream, device_memory); + rmm::device_uvector mc_trainset_labels(mesocluster_size_max, stream, device_memory); rmm::device_uvector mc_trainset_ccenters( fine_clusters_nums_max * dim, stream, device_memory); @@ -674,10 +800,11 @@ auto build_fine_clusters(const handle_t& handle, uint32_t n_clusters_done = 0; for (uint32_t i = 0; i < n_mesoclusters; i++) { uint32_t k = 0; - for (size_t j = 0; j < n_rows; j++) { - if (labels_mptr[j] == i) { mc_trainset_ids[k++] = j; } + for (IdxT j = 0; j < n_rows; j++) { + if (labels_mptr[j] == (LabelT)i) { mc_trainset_ids[k++] = j; } } - RAFT_EXPECTS(k == mesocluster_sizes[i], "Incorrect mesocluster size at %d.", i); + if (k != mesocluster_sizes[i]) + RAFT_LOG_WARN("Incorrect mesocluster size at %d. 
%d vs %d", i, k, mesocluster_sizes[i]); if (k == 0) { RAFT_LOG_DEBUG("Empty cluster %d", i); RAFT_EXPECTS(fine_clusters_nums[i] == 0, @@ -689,21 +816,36 @@ auto build_fine_clusters(const handle_t& handle, "Number of fine clusters must be non-zero for a non-empty mesocluster"); } - utils::copy_selected( - mesocluster_sizes[i], dim, dataset_mptr, mc_trainset_ids, dim, mc_trainset, dim, stream); + utils::copy_selected((IdxT)mesocluster_sizes[i], + (IdxT)dim, + dataset_mptr, + mc_trainset_ids, + (IdxT)dim, + mc_trainset, + (IdxT)dim, + stream); + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + thrust::gather(handle.get_thrust_policy(), + mc_trainset_ids, + mc_trainset_ids + mesocluster_sizes[i], + dataset_norm_mptr, + mc_trainset_norm); + } - build_clusters(handle, - n_iters, - dim, - mc_trainset, - mesocluster_sizes[i], - fine_clusters_nums[i], - mc_trainset_ccenters.data(), - mc_trainset_labels.data(), - mc_trainset_csizes_tmp.data(), - metric, - stream, - device_memory); + build_clusters(handle, + n_iters, + dim, + mc_trainset, + mesocluster_sizes[i], + fine_clusters_nums[i], + mc_trainset_ccenters.data(), + mc_trainset_labels.data(), + mc_trainset_csizes_tmp.data(), + metric, + stream, + device_memory, + mc_trainset_norm); raft::copy(cluster_centers + (dim * fine_clusters_csum[i]), mc_trainset_ccenters.data(), @@ -718,7 +860,9 @@ auto build_fine_clusters(const handle_t& handle, /** * @brief Hierarchical balanced k-means * - * @tparam T element type + * @tparam T element type + * @tparam IdxT index type + * @tparam LabelT label type * * @param handle * @param n_iters number of training iterations @@ -730,50 +874,76 @@ auto build_fine_clusters(const handle_t& handle, * @param metric the distance type * @param stream */ -template +template void build_hierarchical(const handle_t& handle, uint32_t n_iters, uint32_t dim, const T* dataset, - size_t n_rows, + IdxT n_rows, float* cluster_centers, uint32_t n_clusters, raft::distance::DistanceType metric, rmm::cuda_stream_view stream) { + using LabelT = uint32_t; + + RAFT_EXPECTS(static_cast(n_rows) * static_cast(dim) <= + static_cast(std::numeric_limits::max()), + "the chosen index type cannot represent all indices for the given dataset"); + common::nvtx::range fun_scope( - "kmeans::build_hierarchical(%zu, %u)", n_rows, n_clusters); + "kmeans::build_hierarchical(%zu, %u)", static_cast(n_rows), n_clusters); uint32_t n_mesoclusters = std::min(n_clusters, std::sqrt(n_clusters) + 0.5); RAFT_LOG_DEBUG("kmeans::build_hierarchical: n_mesoclusters: %u", n_mesoclusters); rmm::mr::managed_memory_resource managed_memory; rmm::mr::device_memory_resource* device_memory = nullptr; - auto pool_guard = raft::get_pool_memory_resource( - device_memory, kmeans::calc_minibatch_size(n_mesoclusters, n_rows) * dim * 4); + IdxT max_minibatch_size = + calc_minibatch_size(n_clusters, n_rows, dim, metric, std::is_same_v); + auto pool_guard = raft::get_pool_memory_resource(device_memory, max_minibatch_size * dim * 4); if (pool_guard) { RAFT_LOG_DEBUG( "kmeans::build_hierarchical: using pool memory resource with initial size %zu bytes", pool_guard->pool_size()); } + // Precompute the L2 norm of the dataset if relevant. 
+ const float* dataset_norm = nullptr; + rmm::device_uvector dataset_norm_buf(0, stream, device_memory); + if (metric == raft::distance::DistanceType::L2Expanded || + metric == raft::distance::DistanceType::L2SqrtExpanded) { + dataset_norm_buf.resize(n_rows, stream); + for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) { + IdxT minibatch_size = std::min(max_minibatch_size, n_rows - offset); + compute_norm(dataset_norm_buf.data() + offset, + dataset + dim * offset, + (IdxT)dim, + (IdxT)minibatch_size, + stream, + device_memory); + } + dataset_norm = (const float*)dataset_norm_buf.data(); + } + // build coarse clusters (mesoclusters) - rmm::device_uvector mesocluster_labels_buf(n_rows, stream, &managed_memory); + rmm::device_uvector mesocluster_labels_buf(n_rows, stream, &managed_memory); rmm::device_uvector mesocluster_sizes_buf(n_mesoclusters, stream, &managed_memory); { rmm::device_uvector mesocluster_centers_buf(n_mesoclusters * dim, stream, device_memory); - build_clusters(handle, - n_iters, - dim, - dataset, - n_rows, - n_mesoclusters, - mesocluster_centers_buf.data(), - mesocluster_labels_buf.data(), - mesocluster_sizes_buf.data(), - metric, - stream, - device_memory); + build_clusters(handle, + n_iters, + dim, + dataset, + n_rows, + n_mesoclusters, + mesocluster_centers_buf.data(), + mesocluster_labels_buf.data(), + mesocluster_sizes_buf.data(), + metric, + stream, + device_memory, + dataset_norm); } auto mesocluster_sizes = mesocluster_sizes_buf.data(); @@ -791,27 +961,28 @@ void build_hierarchical(const handle_t& handle, RAFT_LOG_TRACE_VEC(fine_clusters_nums.data(), n_mesoclusters); } - auto n_clusters_done = build_fine_clusters(handle, - n_iters, - dim, - dataset, - mesocluster_labels, - n_rows, - fine_clusters_nums.data(), - fine_clusters_csum.data(), - mesocluster_sizes, - n_mesoclusters, - mesocluster_size_max, - fine_clusters_nums_max, - cluster_centers, - metric, - &managed_memory, - device_memory, - stream); + auto n_clusters_done = build_fine_clusters(handle, + n_iters, + dim, + dataset, + dataset_norm, + mesocluster_labels, + n_rows, + fine_clusters_nums.data(), + fine_clusters_csum.data(), + mesocluster_sizes, + n_mesoclusters, + mesocluster_size_max, + fine_clusters_nums_max, + cluster_centers, + metric, + &managed_memory, + device_memory, + stream); RAFT_EXPECTS(n_clusters_done == n_clusters, "Didn't process all clusters."); rmm::device_uvector cluster_sizes(n_clusters, stream, device_memory); - rmm::device_uvector labels(n_rows, stream, device_memory); + rmm::device_uvector labels(n_rows, stream, device_memory); // Fine-tuning kmeans for all clusters // @@ -821,20 +992,21 @@ void build_hierarchical(const handle_t& handle, // is a possibility that the clusters could be unbalanced here, // in which case the actual number of iterations would be increased. 
// - balancing_em_iters(handle, - std::max(n_iters / 10, 2), - dim, - dataset, - n_rows, - n_clusters, - cluster_centers, - labels.data(), - cluster_sizes.data(), - metric, - 5, - 0.2f, - stream, - device_memory); + balancing_em_iters(handle, + std::max(n_iters / 10, 2), + dim, + dataset, + dataset_norm, + n_rows, + n_clusters, + cluster_centers, + labels.data(), + cluster_sizes.data(), + metric, + 5, + 0.2f, + stream, + device_memory); } } // namespace raft::spatial::knn::detail::kmeans diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh index a48fad2737..8dda574314 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh @@ -135,8 +135,8 @@ struct mapping { * @param[in] value * @param[in] n_bytes */ -template -inline void memzero(T* ptr, size_t n_elems, rmm::cuda_stream_view stream) +template +inline void memzero(T* ptr, IdxT n_elems, rmm::cuda_stream_view stream) { switch (check_pointer_residency(ptr)) { case pointer_residency::host_and_device: @@ -152,7 +152,7 @@ inline void memzero(T* ptr, size_t n_elems, rmm::cuda_stream_view stream) } template -__global__ void argmin_along_rows_kernel(IdxT n_rows, IdxT n_cols, const float* a, OutT* out) +__global__ void argmin_along_rows_kernel(IdxT n_rows, uint32_t n_cols, const float* a, OutT* out) { __shared__ OutT shm_ids[1024]; // NOLINT __shared__ float shm_vals[1024]; // NOLINT @@ -214,7 +214,7 @@ inline void argmin_along_rows( template __global__ void dots_along_rows_kernel(IdxT n_rows, IdxT n_cols, const float* a, float* out) { - IdxT i = threadIdx.y + (blockDim.y * blockIdx.x); + IdxT i = threadIdx.y + (blockDim.y * static_cast(blockIdx.x)); if (i >= n_rows) return; float sqsum = 0.0; @@ -259,19 +259,19 @@ inline void dots_along_rows( */ } -template -__global__ void accumulate_into_selected_kernel(uint32_t n_rows, +template +__global__ void accumulate_into_selected_kernel(IdxT n_rows, uint32_t n_cols, float* output, uint32_t* selection_counters, const T* input, - const uint32_t* row_ids) + const LabelT* row_ids) { - uint64_t gid = threadIdx.x + (blockDim.x * static_cast(blockIdx.x)); - uint64_t j = gid % n_cols; - uint64_t i = gid / n_cols; + IdxT gid = threadIdx.x + (blockDim.x * static_cast(blockIdx.x)); + IdxT j = gid % n_cols; + IdxT i = gid / n_cols; if (i >= n_rows) return; - uint64_t l = row_ids[i]; + IdxT l = static_cast(row_ids[i]); if (j == 0) { atomicAdd(&(selection_counters[l]), 1); } atomicAdd(&(output[j + n_cols * l]), mapping{}(input[gid])); } @@ -281,7 +281,9 @@ __global__ void accumulate_into_selected_kernel(uint32_t n_rows, * (cast and possibly scale the data input type). Count the number of times every output * row was selected along the way. 
* - * @tparam T + * @tparam T element type + * @tparam IdxT index type + * @tparam LabelT label type * * @param n_cols number of columns in all matrices * @param[out] output output matrix [..., n_cols] @@ -290,13 +292,13 @@ __global__ void accumulate_into_selected_kernel(uint32_t n_rows, * @param[in] input row-major input matrix [n_rows, n_cols] * @param[in] row_ids row indices in the output matrix [n_rows] */ -template -void accumulate_into_selected(size_t n_rows, +template +void accumulate_into_selected(IdxT n_rows, uint32_t n_cols, float* output, uint32_t* selection_counters, const T* input, - const uint32_t* row_ids, + const LabelT* row_ids, rmm::cuda_stream_view stream) { switch (check_pointer_residency(output, input, selection_counters, row_ids)) { @@ -304,16 +306,16 @@ void accumulate_into_selected(size_t n_rows, case pointer_residency::device_only: { uint32_t block_dim = 128; auto grid_dim = - static_cast(ceildiv(n_rows * static_cast(n_cols), block_dim)); + static_cast(ceildiv(n_rows * static_cast(n_cols), block_dim)); accumulate_into_selected_kernel<<>>( n_rows, n_cols, output, selection_counters, input, row_ids); } break; case pointer_residency::host_only: { stream.synchronize(); - for (size_t i = 0; i < n_rows; i++) { - uint32_t l = row_ids[i]; + for (IdxT i = 0; i < n_rows; i++) { + IdxT l = static_cast(row_ids[i]); selection_counters[l]++; - for (uint32_t j = 0; j < n_cols; j++) { + for (IdxT j = 0; j < n_cols; j++) { output[j + n_cols * l] += mapping{}(input[j + n_cols * i]); } } @@ -326,7 +328,7 @@ void accumulate_into_selected(size_t n_rows, template __global__ void normalize_rows_kernel(IdxT n_rows, IdxT n_cols, float* a) { - uint64_t i = threadIdx.y + (blockDim.y * blockIdx.x); + IdxT i = threadIdx.y + (blockDim.y * static_cast(blockIdx.x)); if (i >= n_rows) return; float sqsum = 0.0; @@ -366,12 +368,12 @@ inline void normalize_rows(IdxT n_rows, IdxT n_cols, float* a, rmm::cuda_stream_ normalize_rows_kernel<<>>(n_rows, n_cols, a); } -template +template __global__ void map_along_rows_kernel( - uint32_t n_rows, uint32_t n_cols, float* a, const uint32_t* d, Lambda map) + IdxT n_rows, uint32_t n_cols, float* a, const uint32_t* d, Lambda map) { - uint64_t gid = threadIdx.x + blockDim.x * blockIdx.x; - uint64_t i = gid / n_cols; + IdxT gid = threadIdx.x + blockDim.x * static_cast(blockIdx.x); + IdxT i = gid / n_cols; if (i >= n_rows) return; float& x = a[gid]; x = map(x, d[i]); @@ -383,6 +385,7 @@ __global__ void map_along_rows_kernel( * * NB: device-only function * + * @tparam IdxT index type * @tparam Lambda * * @param n_rows @@ -391,8 +394,8 @@ __global__ void map_along_rows_kernel( * @param[in] v device pointer to a vector [n_rows] * @param op the binary operation to apply on every element of matrix rows and of the vector */ -template -inline void map_along_rows(uint32_t n_rows, +template +inline void map_along_rows(IdxT n_rows, uint32_t n_cols, float* m, const uint32_t* v, @@ -400,19 +403,16 @@ inline void map_along_rows(uint32_t n_rows, rmm::cuda_stream_view stream) { dim3 threads(128, 1, 1); - dim3 blocks( - ceildiv(static_cast(n_rows) * static_cast(n_cols), threads.x), - 1, - 1); + dim3 blocks(ceildiv(n_rows * n_cols, threads.x), 1, 1); map_along_rows_kernel<<>>(n_rows, n_cols, m, v, op); } -template -__global__ void outer_add_kernel(const T* a, uint32_t len_a, const T* b, uint32_t len_b, T* c) +template +__global__ void outer_add_kernel(const T* a, IdxT len_a, const T* b, IdxT len_b, T* c) { - uint64_t gid = threadIdx.x + blockDim.x * blockIdx.x; - uint64_t i = gid / 
len_b; - uint64_t j = gid % len_b; + IdxT gid = threadIdx.x + blockDim.x * static_cast(blockIdx.x); + IdxT i = gid / len_b; + IdxT j = gid % len_b; if (i >= len_a) return; c[gid] = (a == nullptr ? T(0) : a[i]) + (b == nullptr ? T(0) : b[j]); } @@ -425,7 +425,7 @@ __global__ void block_copy_kernel(const IdxT* in_offsets, T* out_data, IdxT n_mult) { - IdxT i = IdxT(blockDim.x) * IdxT(blockIdx.x) + threadIdx.x; + IdxT i = static_cast(blockDim.x) * static_cast(blockIdx.x) + threadIdx.x; // find the source offset using the binary search. uint32_t l = 0; uint32_t r = n_blocks; @@ -482,7 +482,8 @@ void block_copy(const IdxT* in_offsets, * * NB: device-only function * - * @tparam T element type + * @tparam T element type + * @tparam IdxT index type * * @param[in] a device pointer to a vector [len_a] * @param len_a number of elements in `a` @@ -491,32 +492,23 @@ void block_copy(const IdxT* in_offsets, * @param[out] c row-major matrix [len_a, len_b] * @param stream */ -template -void outer_add( - const T* a, uint32_t len_a, const T* b, uint32_t len_b, T* c, rmm::cuda_stream_view stream) +template +void outer_add(const T* a, IdxT len_a, const T* b, IdxT len_b, T* c, rmm::cuda_stream_view stream) { dim3 threads(128, 1, 1); - dim3 blocks( - ceildiv(static_cast(len_a) * static_cast(len_b), threads.x), - 1, - 1); + dim3 blocks(ceildiv(len_a * len_b, threads.x), 1, 1); outer_add_kernel<<>>(a, len_a, b, len_b, c); } -template -__global__ void copy_selected_kernel(uint64_t n_rows, - uint64_t n_cols, - const S* src, - const IdxT* row_ids, - uint64_t ld_src, - T* dst, - uint64_t ld_dst) +template +__global__ void copy_selected_kernel( + IdxT n_rows, IdxT n_cols, const S* src, const LabelT* row_ids, IdxT ld_src, T* dst, IdxT ld_dst) { - uint64_t gid = threadIdx.x + uint64_t{blockDim.x} * uint64_t{blockIdx.x}; - uint64_t j = gid % n_cols; - uint64_t i_dst = gid / n_cols; + IdxT gid = threadIdx.x + blockDim.x * static_cast(blockIdx.x); + IdxT j = gid % n_cols; + IdxT i_dst = gid / n_cols; if (i_dst >= n_rows) return; - auto i_src = static_cast(row_ids[i_dst]); + auto i_src = static_cast(row_ids[i_dst]); dst[ld_dst * i_dst + j] = mapping{}(src[ld_src * i_src + j]); } @@ -524,8 +516,10 @@ __global__ void copy_selected_kernel(uint64_t n_rows, * @brief Copy selected rows of a matrix while mapping the data from the source to the target * type. 
* - * @tparam T target type - * @tparam S source type + * @tparam T target type + * @tparam S source type + * @tparam IdxT index type + * @tparam LabelT label type * * @param n_rows * @param n_cols @@ -536,29 +530,29 @@ __global__ void copy_selected_kernel(uint64_t n_rows, * @param ld_dst number of cols in the output (ld_dst >= n_cols) * @param stream */ -template -void copy_selected(uint64_t n_rows, - uint64_t n_cols, +template +void copy_selected(IdxT n_rows, + IdxT n_cols, const S* src, - const IdxT* row_ids, - uint64_t ld_src, + const LabelT* row_ids, + IdxT ld_src, T* dst, - uint64_t ld_dst, + IdxT ld_dst, rmm::cuda_stream_view stream) { switch (check_pointer_residency(src, dst, row_ids)) { case pointer_residency::host_and_device: case pointer_residency::device_only: { - uint64_t block_dim = 128; - uint64_t grid_dim = ceildiv(n_rows * n_cols, block_dim); + IdxT block_dim = 128; + IdxT grid_dim = ceildiv(n_rows * n_cols, block_dim); copy_selected_kernel <<>>(n_rows, n_cols, src, row_ids, ld_src, dst, ld_dst); } break; case pointer_residency::host_only: { stream.synchronize(); - for (uint64_t i_dst = 0; i_dst < n_rows; i_dst++) { - auto i_src = static_cast(row_ids[i_dst]); - for (uint64_t j = 0; j < n_cols; j++) { + for (IdxT i_dst = 0; i_dst < n_rows; i_dst++) { + auto i_src = static_cast(row_ids[i_dst]); + for (IdxT j = 0; j < n_cols; j++) { dst[ld_dst * i_dst + j] = mapping{}(src[ld_src * i_src + j]); } } diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh index 2bc91d1b3b..af1cb97d36 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_build.cuh @@ -44,8 +44,9 @@ using namespace raft::spatial::knn::detail; // NOLINT * X dimension must cover the dataset (n_rows), YZ are not used; * there are no dependencies between threads, hence no constraints on the block size. * - * @tparam T the element type. - * @tparam IdxT type of the indices in the source source_vecs + * @tparam T element type. + * @tparam IdxT type of the indices in the source source_vecs + * @tparam LabelT label type * * @param[in] labels device pointer to the cluster ids for each row [n_rows] * @param[in] list_offsets device pointer to the cluster offsets in the output (index) [n_lists] @@ -60,8 +61,8 @@ using namespace raft::spatial::knn::detail; // NOLINT * @param veclen size of vectorized loads/stores; must satisfy `dim % veclen == 0`. 
* */ -template -__global__ void build_index_kernel(const uint32_t* labels, +template +__global__ void build_index_kernel(const LabelT* labels, const IdxT* list_offsets, const T* source_vecs, const IdxT* source_ixs, @@ -110,6 +111,8 @@ inline auto extend(const handle_t& handle, const IdxT* new_indices, IdxT n_rows) -> index { + using LabelT = uint32_t; + auto stream = handle.get_stream(); auto n_lists = orig_index.n_lists(); auto dim = orig_index.dim(); @@ -119,16 +122,16 @@ inline auto extend(const handle_t& handle, RAFT_EXPECTS(new_indices != nullptr || orig_index.size() == 0, "You must pass data indices when the index is non-empty."); - rmm::device_uvector new_labels(n_rows, stream); - kmeans::predict(handle, - orig_index.centers().data_handle(), - n_lists, - dim, - new_vectors, - n_rows, - new_labels.data(), - orig_index.metric(), - stream); + rmm::device_uvector new_labels(n_rows, stream); + kmeans::predict(handle, + orig_index.centers().data_handle(), + n_lists, + dim, + new_vectors, + n_rows, + new_labels.data(), + orig_index.metric(), + stream); index ext_index(handle, orig_index.metric(), n_lists, dim); @@ -206,6 +209,7 @@ inline auto extend(const handle_t& handle, // Precompute the centers vector norms for L2Expanded distance if (ext_index.center_norms().has_value()) { + // todo(lsugy): use other prim and remove this one utils::dots_along_rows(n_lists, dim, ext_index.centers().data_handle(), @@ -250,15 +254,15 @@ inline auto build( n_rows_train, cudaMemcpyDefault, stream)); - kmeans::build_hierarchical(handle, - params.kmeans_n_iters, - index.dim(), - trainset.data(), - n_rows_train, - index.centers().data_handle(), - index.n_lists(), - index.metric(), - stream); + kmeans::build_hierarchical(handle, + params.kmeans_n_iters, + index.dim(), + trainset.data(), + n_rows_train, + index.centers().data_handle(), + index.n_lists(), + index.metric(), + stream); } // add the data if necessary diff --git a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh index 770530b77c..f37bccaadb 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_flat_search.cuh @@ -1098,15 +1098,16 @@ void search_impl(const handle_t& handle, float alpha = 1.0f; float beta = 0.0f; + // todo(lsugy): raft distance? 
(if performance is similar/better than gemm) if (index.metric() == raft::distance::DistanceType::L2Expanded) { alpha = -2.0f; beta = 1.0f; utils::dots_along_rows( n_queries, index.dim(), converted_queries_ptr, query_norm_dev.data(), stream); utils::outer_add(query_norm_dev.data(), - n_queries, + (IdxT)n_queries, index.center_norms()->data_handle(), - index.n_lists(), + (IdxT)index.n_lists(), distance_buffer_dev.data(), stream); RAFT_LOG_TRACE_VEC(index.center_norms()->data_handle(), std::min(20, index.dim())); diff --git a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh index 5a146c18fe..823f2d98a0 100644 --- a/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh +++ b/cpp/include/raft/spatial/knn/detail/ivf_pq_build.cuh @@ -225,7 +225,8 @@ void select_residuals(const handle_t& handle, { auto stream = handle.get_stream(); rmm::device_uvector tmp(n_rows * dim, stream, device_memory); - utils::copy_selected(n_rows, dim, dataset, row_ids, dim, tmp.data(), dim, stream); + utils::copy_selected( + n_rows, (IdxT)dim, dataset, row_ids, (IdxT)dim, tmp.data(), (IdxT)dim, stream); raft::matrix::linewiseOp( tmp.data(), @@ -482,12 +483,12 @@ void train_per_subset(const handle_t& handle, // Get the rotated cluster centers for each training vector. // This will be subtracted from the input vectors afterwards. utils::copy_selected(n_rows, - index.pq_len(), + (IdxT)index.pq_len(), index.centers_rot().data_handle() + index.pq_len() * j, labels, - index.rot_dim(), + (IdxT)index.rot_dim(), sub_trainset.data(), - index.pq_len(), + (IdxT)index.pq_len(), stream); // sub_trainset is the slice of: rotate(trainset) - centers_rot @@ -870,13 +871,13 @@ inline auto extend(const handle_t& handle, new_cluster_size, stream); } else { - utils::copy_selected(new_cluster_size, - 1, + utils::copy_selected((IdxT)new_cluster_size, + (IdxT)1, new_indices, new_data_indices.data() + new_cluster_offsets.data()[k], - 1, + (IdxT)1, ext_indices + ext_cluster_offsets[l] + old_cluster_size, - 1, + (IdxT)1, stream); } } diff --git a/cpp/test/spatial/ann_ivf_flat.cu b/cpp/test/spatial/ann_ivf_flat.cu index 241c4f6547..01af7ea0bd 100644 --- a/cpp/test/spatial/ann_ivf_flat.cu +++ b/cpp/test/spatial/ann_ivf_flat.cu @@ -32,6 +32,10 @@ #include +#if defined RAFT_DISTANCE_COMPILED && defined RAFT_NN_COMPILED +#include +#endif + #include #include #include @@ -110,7 +114,7 @@ class AnnIVFFlatTest : public ::testing::TestWithParam> { &index, dynamic_cast(&ivfParams), ps.metric, - 0, + (IdxT)0, database.data(), ps.num_db_vecs, ps.dim);
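The linalg changes above keep the wrapper signatures of reduce, coalescedReduction, and stridedReduction intact; the index type used inside the kernels follows the type of the extent arguments. A small sketch with assumed names (row_sums, data, sums), showing a row-wise sum that selects a 64-bit index type this way:

#include <raft/linalg/reduce.cuh>

#include <cstdint>

// Sum each row of a [n_rows, dim] row-major float matrix into sums[0..n_rows).
// Passing 64-bit extents makes the reduction instantiate with a 64-bit IdxType.
void row_sums(
  const float* data, float* sums, std::int64_t dim, std::int64_t n_rows, cudaStream_t stream)
{
  raft::linalg::reduce(sums,
                       data,
                       dim,     // D: length of each row
                       n_rows,  // N: number of rows
                       0.0f,    // init
                       true,    // rowMajor
                       true,    // alongRows
                       stream);
}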
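The balanced k-means helpers are now templated on an index type IdxT and a label type LabelT, and take an optional precomputed L2 norm that feeds fusedL2NN. The sketch below shows one way the internal kmeans::predict could be called under this contract; assign_labels and its parameter names are illustrative assumptions, and the template parameters are left to be deduced from the argument types.

#include <raft/spatial/knn/detail/ann_kmeans_balanced.cuh>

#include <raft/distance/distance_type.hpp>
#include <raft/handle.hpp>

#include <rmm/cuda_stream_view.hpp>

#include <cstdint>

// Assign every row of `dataset` to its nearest center. IdxT = int64_t is deduced from
// n_rows, LabelT = uint32_t from the label buffer. The precomputed row norms are
// optional; when omitted, predict() computes them per minibatch.
void assign_labels(const raft::handle_t& handle,
                   const float* centers,       // [n_clusters, dim]
                   std::uint32_t n_clusters,
                   std::uint32_t dim,
                   const float* dataset,       // [n_rows, dim]
                   const float* dataset_norm,  // [n_rows], squared L2 norms
                   std::int64_t n_rows,
                   std::uint32_t* labels,      // [n_rows], output
                   rmm::cuda_stream_view stream)
{
  namespace kmeans = raft::spatial::knn::detail::kmeans;
  kmeans::predict(handle,
                  centers,
                  n_clusters,
                  dim,
                  dataset,
                  n_rows,
                  labels,
                  raft::distance::DistanceType::L2Expanded,
                  stream,
                  nullptr,  // mr: fall back to the current device memory resource
                  dataset_norm);
}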