From ca2591443f627aeb0eab0a0a804c4fe719cc974c Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 23 Apr 2021 12:10:08 -0400 Subject: [PATCH 01/55] Stubbing out rbc --- .../raft/spatial/knn/detail/ball_cover.cuh | 126 ++++++++++++++++++ 1 file changed, 126 insertions(+) create mode 100644 cpp/include/raft/spatial/knn/detail/ball_cover.cuh diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh new file mode 100644 index 0000000000..f6896e5293 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -0,0 +1,126 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "brute_force_knn.cuh" +#include + +#include +#include +#include + +#include + +#include + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + + +struct NNComp { + template + __host__ __device__ bool operator()(const one &t1, const two &t2) { + // sort first by each sample's reference landmark, + if (thrust::get<0>(t1) < thrust::get<0>(t2)) return true; + if (thrust::get<0>(t1) > thrust::get<0>(t2)) return false; + + // then by closest neighbor, + return thrust::get<1>(t1) < thrust::get<1>(t2); + } +}; + + + +/** + * Random ball cover algorithm uses the triangle inequality + * (which can be used for any valid metric or distance + * that follows the triangle inequality) + * @tparam value_idx + * @tparam value_t + */ +template +void random_ball_cover(const raft::handle_t &handle, const value_t *X, + value_idx m, value_idx n, int k, value_idx *inds, + value_t *dists) { + /** + * 1. Randomly sample sqrt(n) points from X + */ + + rmm::device_uvector R_knn_inds(k * m, handle.get_stream()); + rmm::device_uvector R_knn_dists(k * m, handle.get_stream()); + + value_idx n_samples = int(sqrt(m)); + + rmm::device_uvector R_indices(n_samples, handle.get_stream()); + rmm::device_uvector R(n_samples * n, handle.get_stream()); + raft::random::uniformInt(R_indices.data(), n_samples, 0, m-1, handle.get_stream()); + + raft::matrix::copyRows(X, m, n, R.data(), R_indices.data(), n_samples, + handle.get_stream(), true); + + /** + * 2. Perform knn = bfknn(X, R, k) + */ + brute_force_knn_impl({X}, {m}, n, R.data(), n_samples, R_knn_inds.data(), R_knn_dists.data(), k, + handle.get_device_allocator(), handle.get_stream()); + + /** + * 3. Create L_r = knn[:,0].T (CSR) + * + * Slice closest neighboring R + * Secondary sort by (R_knn_inds, R_knn_dists) + */ + rmm::device_uvector R_1nn_inds(m, handle.get_stream()); + rmm::device_uvector R_1nn_dists(m, handle.get_stream()); + rmm::device_uvector R_1nn_cols(m, handle.get_stream()); + + raft::matrix::sliceMatrix(R_knn_inds.data(), m, k, R_1nn_inds.data(), 0, + 1, m, 2, handle.get_stream()); + raft::matrix::sliceMatrix(R_knn_dists.data(), m, k, R_1nn_dists.data(), 0,https://arxiv.org/search/cs?searchtype=author&query=Domingos%2C+P + 1, m, 2, handle.get_stream()); + + thrust::sequence(thrust::cuda::par.on(handle.get_stream()), R_1nn_cols.data(), + R_1nn_cols.data()+m, 1); + + auto keys = thrust::make_zip_iterator(thrust::make_tuple( + R_1nn_inds.data(), R_1nn_dists.data())); + auto vals = thrust::make_zip_iterator(thrust::make_tuple(R_1nn_cols.data())); + + // group neighborhoods for each reference landmark and sort each group by distance + thrust::sort_by_key(thrust::cuda::par.on(stream), keys, keys + n_rows, vals, + NNComp()); + + // convert to CSR for fast lookup + rmm::device_uvector R_indptr(n_samples, handle.get_stream()); + raft::sparse::convert::sorted_coo_to_csr(R_1nn_inds.data(), m, R_indptr.data(), n_samples+1, + handle.get_device_allocator(), handle.get_stream()); + + /** + * 4. Perform k-select over original KNN, using L_r to filter distances + * + * a. Map 1 row to each warp/block + * b. Add closest R points to heap + * c. Iterate through batches of R, having each thread in the warp load a set + * of distances y from R and marking the distance to be computed between x, y only + * if current knn[k].distance >= d(x_i, R_k) + d(R_k, y) + */ + +} +}; +}; +}; +}; \ No newline at end of file From ae16d0861f0d51f5edd718fded4e2f8e82424958 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 16 Jun 2021 14:03:50 -0400 Subject: [PATCH 02/55] Fixing style --- .../raft/spatial/knn/detail/ball_cover.cuh | 43 ++++++++++--------- 1 file changed, 22 insertions(+), 21 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index f6896e5293..f270fcafce 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -14,11 +14,11 @@ * limitations under the License. */ -#include "brute_force_knn.cuh" #include +#include "brute_force_knn.cuh" -#include #include +#include #include #include @@ -30,7 +30,6 @@ namespace spatial { namespace knn { namespace detail { - struct NNComp { template __host__ __device__ bool operator()(const one &t1, const two &t2) { @@ -43,8 +42,6 @@ struct NNComp { } }; - - /** * Random ball cover algorithm uses the triangle inequality * (which can be used for any valid metric or distance @@ -67,7 +64,8 @@ void random_ball_cover(const raft::handle_t &handle, const value_t *X, rmm::device_uvector R_indices(n_samples, handle.get_stream()); rmm::device_uvector R(n_samples * n, handle.get_stream()); - raft::random::uniformInt(R_indices.data(), n_samples, 0, m-1, handle.get_stream()); + raft::random::uniformInt(R_indices.data(), n_samples, 0, m - 1, + handle.get_stream()); raft::matrix::copyRows(X, m, n, R.data(), R_indices.data(), n_samples, handle.get_stream(), true); @@ -75,8 +73,9 @@ void random_ball_cover(const raft::handle_t &handle, const value_t *X, /** * 2. Perform knn = bfknn(X, R, k) */ - brute_force_knn_impl({X}, {m}, n, R.data(), n_samples, R_knn_inds.data(), R_knn_dists.data(), k, - handle.get_device_allocator(), handle.get_stream()); + brute_force_knn_impl({X}, {m}, n, R.data(), n_samples, R_knn_inds.data(), + R_knn_dists.data(), k, handle.get_device_allocator(), + handle.get_stream()); /** * 3. Create L_r = knn[:,0].T (CSR) @@ -88,16 +87,18 @@ void random_ball_cover(const raft::handle_t &handle, const value_t *X, rmm::device_uvector R_1nn_dists(m, handle.get_stream()); rmm::device_uvector R_1nn_cols(m, handle.get_stream()); - raft::matrix::sliceMatrix(R_knn_inds.data(), m, k, R_1nn_inds.data(), 0, - 1, m, 2, handle.get_stream()); - raft::matrix::sliceMatrix(R_knn_dists.data(), m, k, R_1nn_dists.data(), 0,https://arxiv.org/search/cs?searchtype=author&query=Domingos%2C+P - 1, m, 2, handle.get_stream()); + raft::matrix::sliceMatrix(R_knn_inds.data(), m, k, R_1nn_inds.data(), 0, 1, m, + 2, handle.get_stream()); + raft::matrix::sliceMatrix( + R_knn_dists.data(), m, k, R_1nn_dists.data(), 0, https + : //arxiv.org/search/cs?searchtype=author&query=Domingos%2C+P + 1, m, 2, handle.get_stream()); thrust::sequence(thrust::cuda::par.on(handle.get_stream()), R_1nn_cols.data(), - R_1nn_cols.data()+m, 1); + R_1nn_cols.data() + m, 1); - auto keys = thrust::make_zip_iterator(thrust::make_tuple( - R_1nn_inds.data(), R_1nn_dists.data())); + auto keys = thrust::make_zip_iterator( + thrust::make_tuple(R_1nn_inds.data(), R_1nn_dists.data())); auto vals = thrust::make_zip_iterator(thrust::make_tuple(R_1nn_cols.data())); // group neighborhoods for each reference landmark and sort each group by distance @@ -106,7 +107,8 @@ void random_ball_cover(const raft::handle_t &handle, const value_t *X, // convert to CSR for fast lookup rmm::device_uvector R_indptr(n_samples, handle.get_stream()); - raft::sparse::convert::sorted_coo_to_csr(R_1nn_inds.data(), m, R_indptr.data(), n_samples+1, + raft::sparse::convert::sorted_coo_to_csr( + R_1nn_inds.data(), m, R_indptr.data(), n_samples + 1, handle.get_device_allocator(), handle.get_stream()); /** @@ -118,9 +120,8 @@ void random_ball_cover(const raft::handle_t &handle, const value_t *X, * of distances y from R and marking the distance to be computed between x, y only * if current knn[k].distance >= d(x_i, R_k) + d(R_k, y) */ - } -}; -}; -}; -}; \ No newline at end of file +}; // namespace detail +}; // namespace knn +}; // namespace spatial +}; // namespace raft \ No newline at end of file From af9520fa8fab4273a0502d8a10552c4588d99002 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Fri, 18 Jun 2021 15:35:25 -0400 Subject: [PATCH 03/55] It's coming along. Using warp-select for each closest R --- .../raft/spatial/knn/detail/ball_cover.cuh | 253 ++++++++++++++++-- .../knn/detail/knn_brute_force_faiss.cuh | 22 +- cpp/include/raft/spatial/knn/knn.hpp | 8 + cpp/test/CMakeLists.txt | 1 + cpp/test/spatial/ball_cover.cu | 123 +++++++++ cpp/test/spatial/knn.cu | 9 - 6 files changed, 373 insertions(+), 43 deletions(-) create mode 100644 cpp/test/spatial/ball_cover.cu diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index f270fcafce..11e95dc63c 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -14,14 +14,24 @@ * limitations under the License. */ +#pragma once + #include -#include "brute_force_knn.cuh" +#include "haversine_distance.cuh" +#include "knn_brute_force_faiss.cuh" + +#include #include #include #include #include +#include + +#include +#include +#include #include @@ -42,6 +52,123 @@ struct NNComp { } }; +/** + * Kernel for more narrow data sizes (n_cols <= 32) + * @tparam value_idx + * @tparam value_t + * @tparam warp_q + * @tparam thread_q + * @tparam tpb + * @tparam value_idx + * @tparam value_t + * @param R_knn_inds + * @param R_knn_dists + * @param m + * @param k + * @param R_indptr + * @param R_1nn_cols + * @param R_1nn_dists + */ +template +__global__ void rbc_kernel(const value_t *X, const value_idx n_cols, + const value_idx *R_knn_inds, + const value_t *R_knn_dists, value_idx m, int k, + const value_idx *R_indptr, + const value_idx *R_1nn_cols, + const value_t *R_1nn_dists, value_idx *out_inds, + value_t *out_dists, value_idx *sampled_inds_map) { + int row = blockIdx.x; + + const value_t *x_ptr = X + (n_cols * row); + value_t x1 = x_ptr[0]; + value_t x2 = x_ptr[1]; + + constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; + + // Each warp works on 1 R + faiss::gpu::WarpSelect, warp_q, thread_q, tpb> + heap(faiss::gpu::Limits::getMax(), -1, k); + + // Grid is exactly sized to rows available + value_t min_R_dist = R_knn_dists[row * k]; + value_idx min_R_ind = R_knn_inds[row * k]; + + /** + * First add distances for k closest neighbors of R + * to the heap + */ + // Start iterating through elements of each set from closest R elements, + // determining if the distance could even potentially be in the heap. + + value_idx cur_k = 0; + + // just doing Rs for the closest k for now + value_t cur_R_dist = R_knn_dists[row * k + cur_k]; + value_idx cur_R_ind = R_knn_inds[row * k + cur_k]; + + // The whole warp should iterate through the elements in the current R + value_idx R_start_offset = R_indptr[cur_R_ind]; + value_idx R_stop_offset = R_indptr[cur_R_ind + 1]; + + value_idx R_size = R_stop_offset - R_start_offset; + + int limit = faiss::gpu::utils::roundDown(R_size, faiss::gpu::kWarpSize); + + int i = threadIdx.x; + for (; i < limit; i += tpb) { + value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; + value_idx cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + printf("row=%d, warpKTopDist=%f, numVals=%d\n", row, heap.warpKTop, + heap.numVals); + + if (i < k || heap.warpKTop >= cur_candidate_dist + cur_R_dist) { + const value_t *y_ptr = X + (n_cols * cur_candidate_ind); + value_t y1 = y_ptr[0]; + value_t y2 = y_ptr[1]; + + value_t dist = compute_haversine(x1, y1, x2, y2); + + printf( + "row=%d, R=%d, R_stop_offset=%d, R_start_offset=%d, candidate_ind=%d, " + "dist=%f\n", + row, cur_R_ind, R_stop_offset, R_start_offset, cur_candidate_ind, dist); + + heap.addThreadQ(dist, cur_candidate_ind); + } + + heap.checkThreadQ(); + } + + if (i < R_size) { + value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; + value_idx cur_candidate_dist = R_1nn_dists[R_start_offset + i]; + + printf("row=%d, warpKTopDist=%f, numVals=%d\n", row, heap.warpKTop, + heap.numVals); + + if (i < k || heap.warpKTop >= cur_candidate_dist + cur_R_dist) { + const value_t *y_ptr = X + (n_cols * cur_candidate_ind); + value_t y1 = y_ptr[0]; + value_t y2 = y_ptr[1]; + + value_t dist = compute_haversine(x1, y1, x2, y2); + + printf( + "row=%d, R=%d, R_stop_offset=%d, R_start_offset=%d, candidate_ind=%d, " + "dist=%f\n", + row, cur_R_ind, R_stop_offset, R_start_offset, cur_candidate_ind, dist); + + heap.addThreadQ(dist, cur_candidate_ind); + } + } + + heap.reduce(); + heap.writeOut(out_dists + (row * k), out_inds + (row * k), k); +} + /** * Random ball cover algorithm uses the triangle inequality * (which can be used for any valid metric or distance @@ -49,33 +176,74 @@ struct NNComp { * @tparam value_idx * @tparam value_t */ -template +template void random_ball_cover(const raft::handle_t &handle, const value_t *X, value_idx m, value_idx n, int k, value_idx *inds, - value_t *dists) { + value_t *dists, value_idx n_samples = -1) { + auto exec_policy = rmm::exec_policy(handle.get_stream()); /** * 1. Randomly sample sqrt(n) points from X */ + printf("k=%d\n", k); + rmm::device_uvector R_knn_inds_64b(k * m, handle.get_stream()); + rmm::device_uvector R_knn_inds(k * m, handle.get_stream()); rmm::device_uvector R_knn_dists(k * m, handle.get_stream()); - value_idx n_samples = int(sqrt(m)); + n_samples = n_samples < 1 ? int(sqrt(m)) : n_samples; + + ASSERT(n_samples >= k, "number of landmark samples must be >= k"); + + printf("Sampling %d points\n", n_samples); rmm::device_uvector R_indices(n_samples, handle.get_stream()); rmm::device_uvector R(n_samples * n, handle.get_stream()); - raft::random::uniformInt(R_indices.data(), n_samples, 0, m - 1, - handle.get_stream()); + + rmm::device_uvector R_1nn_cols(m, handle.get_stream()); + rmm::device_uvector R_1nn_ones(m, handle.get_stream()); + + thrust::fill(exec_policy, R_1nn_ones.data(), + R_1nn_ones.data() + R_1nn_ones.size(), 1.0); + + rmm::device_uvector R_1nn_cols2(m, handle.get_stream()); + + thrust::sequence(thrust::cuda::par.on(handle.get_stream()), R_1nn_cols.data(), + R_1nn_cols.data() + m, 0); + + auto rng = raft::random::Rng(12345); + rng.sampleWithoutReplacement(handle, R_indices.data(), R_1nn_cols2.data(), + R_1nn_cols.data(), R_1nn_ones.data(), n_samples, + m, handle.get_stream()); raft::matrix::copyRows(X, m, n, R.data(), R_indices.data(), n_samples, handle.get_stream(), true); + raft::print_device_vector("sampled indices", R_indices.data(), + R_indices.size(), std::cout); + raft::print_device_vector("sampled rows", R.data(), R.size(), std::cout); + /** * 2. Perform knn = bfknn(X, R, k) */ - brute_force_knn_impl({X}, {m}, n, R.data(), n_samples, R_knn_inds.data(), - R_knn_dists.data(), k, handle.get_device_allocator(), - handle.get_stream()); + printf("Performing bfknn of landmarks\n"); + std::vector input = {R.data()}; + std::vector sizes = {n_samples}; + + brute_force_knn_impl( + input, sizes, n, const_cast(X), m, R_knn_inds_64b.data(), + R_knn_dists.data(), k, handle.get_device_allocator(), handle.get_stream(), + nullptr, 0, (bool)true, (bool)true, nullptr, + raft::distance::DistanceType::Haversine); + + thrust::transform(exec_policy, R_knn_inds_64b.data(), + R_knn_inds_64b.data() + R_knn_inds_64b.size(), + R_knn_inds.data(), [] __device__(int64_t i) { return i; }); + + raft::print_device_vector("R_knn_inds", R_knn_inds.data(), R_knn_inds.size(), + std::cout); + raft::print_device_vector("R_knn_dists", R_knn_dists.data(), + R_knn_dists.size(), std::cout); /** * 3. Create L_r = knn[:,0].T (CSR) @@ -83,27 +251,52 @@ void random_ball_cover(const raft::handle_t &handle, const value_t *X, * Slice closest neighboring R * Secondary sort by (R_knn_inds, R_knn_dists) */ + + printf("Building representative lists for landmarks\n"); rmm::device_uvector R_1nn_inds(m, handle.get_stream()); rmm::device_uvector R_1nn_dists(m, handle.get_stream()); - rmm::device_uvector R_1nn_cols(m, handle.get_stream()); - raft::matrix::sliceMatrix(R_knn_inds.data(), m, k, R_1nn_inds.data(), 0, 1, m, - 2, handle.get_stream()); - raft::matrix::sliceMatrix( - R_knn_dists.data(), m, k, R_1nn_dists.data(), 0, https - : //arxiv.org/search/cs?searchtype=author&query=Domingos%2C+P - 1, m, 2, handle.get_stream()); + value_idx *R_1nn_inds_ptr = R_1nn_inds.data(); + value_t *R_1nn_dists_ptr = R_1nn_dists.data(); + value_idx *R_knn_inds_ptr = R_knn_inds.data(); + value_t *R_knn_dists_ptr = R_knn_dists.data(); - thrust::sequence(thrust::cuda::par.on(handle.get_stream()), R_1nn_cols.data(), - R_1nn_cols.data() + m, 1); + auto idxs = thrust::make_counting_iterator(0); + thrust::for_each(exec_policy, idxs, idxs + m, [=] __device__(value_idx i) { + R_1nn_inds_ptr[i] = R_knn_inds_ptr[i * k]; + R_1nn_dists_ptr[i] = R_knn_dists_ptr[i * k]; + }); + // I think this might be assuming col-major? + // raft::matrix::sliceMatrix(R_knn_inds.data(), m, k, R_1nn_inds.data(), 0, 0, + // m+1, 1, handle.get_stream()); + // raft::matrix::sliceMatrix( + // R_knn_dists.data(), m, k, R_1nn_dists.data(), 0, + // 0, m+1, 1, handle.get_stream()); + // auto keys = thrust::make_zip_iterator( thrust::make_tuple(R_1nn_inds.data(), R_1nn_dists.data())); auto vals = thrust::make_zip_iterator(thrust::make_tuple(R_1nn_cols.data())); + raft::print_device_vector("R_1nn_inds", R_1nn_inds.data(), R_1nn_inds.size(), + std::cout); + raft::print_device_vector("R_1nn_cols", R_1nn_cols.data(), R_1nn_cols.size(), + std::cout); + + raft::print_device_vector("R_1nn_dists", R_1nn_dists.data(), + R_1nn_dists.size(), std::cout); + // group neighborhoods for each reference landmark and sort each group by distance - thrust::sort_by_key(thrust::cuda::par.on(stream), keys, keys + n_rows, vals, - NNComp()); + thrust::sort_by_key(thrust::cuda::par.on(handle.get_stream()), keys, + keys + R_1nn_inds.size(), vals, NNComp()); + + raft::print_device_vector("R_1nn_inds_sorted", R_1nn_inds.data(), + R_1nn_inds.size(), std::cout); + raft::print_device_vector("R_1nn_cols_sorted", R_1nn_cols.data(), + R_1nn_cols.size(), std::cout); + + raft::print_device_vector("R_1nn_dists_sorted", R_1nn_dists.data(), + R_1nn_dists.size(), std::cout); // convert to CSR for fast lookup rmm::device_uvector R_indptr(n_samples, handle.get_stream()); @@ -111,16 +304,30 @@ void random_ball_cover(const raft::handle_t &handle, const value_t *X, R_1nn_inds.data(), m, R_indptr.data(), n_samples + 1, handle.get_device_allocator(), handle.get_stream()); + raft::print_device_vector("R_1nn_indptr", R_indptr.data(), R_indptr.size(), + std::cout); + /** * 4. Perform k-select over original KNN, using L_r to filter distances * * a. Map 1 row to each warp/block - * b. Add closest R points to heap + * b. Add closest k R points to heap * c. Iterate through batches of R, having each thread in the warp load a set - * of distances y from R and marking the distance to be computed between x, y only - * if current knn[k].distance >= d(x_i, R_k) + d(R_k, y) + * of distances y from R (only if d(q, r) < 3 * distance to closest r) and marking the distance to be computed between x, y only + * if knn[k].distance >= d(x_i, R_k) + d(R_k, y) */ + + printf("Performing final bfknn\n"); + + rbc_kernel<<>>( + X, n, R_knn_inds.data(), R_knn_dists.data(), m, k, R_indptr.data(), + R_1nn_cols.data(), R_1nn_dists.data(), inds, dists, R_indices.data()); + + // Thoughts: + // For n_cols < 32, we could probably just have each thread compute the distance + // For n_cols >= 32, we could probably have full warps compute the distances } + }; // namespace detail }; // namespace knn }; // namespace spatial diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 75226299e6..3b7c0aacd4 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -226,30 +226,30 @@ inline faiss::MetricType build_faiss_metric( * @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded) * @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm */ -template +template void brute_force_knn_impl(std::vector &input, std::vector &sizes, IntType D, float *search_items, IntType n, - int64_t *res_I, float *res_D, IntType k, + IdxType *res_I, float *res_D, IntType k, std::shared_ptr allocator, cudaStream_t userStream, cudaStream_t *internalStreams = nullptr, int n_int_streams = 0, bool rowMajorIndex = true, bool rowMajorQuery = true, - std::vector *translations = nullptr, + std::vector *translations = nullptr, raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, float metricArg = 0) { ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size"); - std::vector *id_ranges; + std::vector *id_ranges; if (translations == nullptr) { // If we don't have explicit translations // for offsets of the indices, build them // from the local partitions - id_ranges = new std::vector(); - int64_t total_n = 0; - for (size_t i = 0; i < input.size(); i++) { + id_ranges = new std::vector(); + IdxType total_n = 0; + for (IdxType i = 0; i < input.size(); i++) { id_ranges->push_back(total_n); total_n += sizes[i]; } @@ -275,16 +275,16 @@ void brute_force_knn_impl(std::vector &input, std::vector &sizes, int device; CUDA_CHECK(cudaGetDevice(&device)); - raft::mr::device::buffer trans(allocator, userStream, + raft::mr::device::buffer trans(allocator, userStream, id_ranges->size()); raft::update_device(trans.data(), id_ranges->data(), id_ranges->size(), userStream); raft::mr::device::buffer all_D(allocator, userStream, 0); - raft::mr::device::buffer all_I(allocator, userStream, 0); + raft::mr::device::buffer all_I(allocator, userStream, 0); float *out_D = res_D; - int64_t *out_I = res_I; + IdxType *out_I = res_I; if (input.size() > 1) { all_D.resize(input.size() * k * n, userStream); @@ -299,7 +299,7 @@ void brute_force_knn_impl(std::vector &input, std::vector &sizes, for (size_t i = 0; i < input.size(); i++) { float *out_d_ptr = out_D + (i * k * n); - int64_t *out_i_ptr = out_I + (i * k * n); + IdxType *out_i_ptr = out_I + (i * k * n); cudaStream_t stream = raft::select_stream(userStream, internalStreams, n_int_streams, i); diff --git a/cpp/include/raft/spatial/knn/knn.hpp b/cpp/include/raft/spatial/knn/knn.hpp index a3a1972c13..dd6435e67e 100644 --- a/cpp/include/raft/spatial/knn/knn.hpp +++ b/cpp/include/raft/spatial/knn/knn.hpp @@ -16,6 +16,7 @@ #pragma once +#include "detail/ball_cover.cuh" #include "detail/knn_brute_force_faiss.cuh" #include @@ -78,6 +79,13 @@ inline void brute_force_knn( rowMajorQuery, translations, metric, metric_arg); } +template +inline void random_ball_cover(const raft::handle_t &handle, const value_t *X, + value_idx m, value_idx n, int k, value_idx *inds, + value_t *dists) { + detail::random_ball_cover(handle, X, m, n, k, inds, dists); +} + } // namespace knn } // namespace spatial } // namespace raft diff --git a/cpp/test/CMakeLists.txt b/cpp/test/CMakeLists.txt index 6496ac26c6..bd791f8b12 100644 --- a/cpp/test/CMakeLists.txt +++ b/cpp/test/CMakeLists.txt @@ -80,6 +80,7 @@ add_executable(test_raft test/sparse/symmetrize.cu test/spatial/knn.cu test/spatial/haversine.cu + test/spatial/ball_cover.cu test/spectral_matrix.cu test/stats/mean.cu test/stats/mean_center.cu diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu new file mode 100644 index 0000000000..1997221ae4 --- /dev/null +++ b/cpp/test/spatial/ball_cover.cu @@ -0,0 +1,123 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include "../test_utils.h" + +namespace raft { +namespace spatial { +namespace knn { + +template +class BallCoverKNNTest : public ::testing::Test { + protected: + void basicTest() { + auto alloc = std::make_shared(); + + raft::handle_t handle; + + // Allocate input + raft::allocate(d_train_inputs, n * d); + + // Allocate reference arrays + raft::allocate(d_ref_I, n * n); + raft::allocate(d_ref_D, n * n); + + // Allocate predicted arrays + raft::allocate(d_pred_I, n * n); + raft::allocate(d_pred_D, n * n); + + // make testdata on host + std::vector h_train_inputs = { + 0.71113885, -1.29215058, 0.59613176, -2.08048115, + 0.74932804, -1.33634042, 0.51486728, -1.65962873, + 0.53154002, -1.47049808, 0.72891737, -1.54095137}; + + h_train_inputs.resize(n); + raft::update_device(d_train_inputs, h_train_inputs.data(), n * d, 0); + + std::vector h_res_D = { + 0., 0.05041587, 0.18767063, 0.23048252, 0.35749438, 0.62925595, + 0., 0.36575755, 0.44288665, 0.5170737, 0.59501296, 0.62925595, + 0., 0.05041587, 0.152463, 0.2426416, 0.34925285, 0.59501296, + 0., 0.16461092, 0.2345792, 0.34925285, 0.35749438, 0.36575755, + 0., 0.16461092, 0.20535265, 0.23048252, 0.2426416, 0.5170737, + 0., 0.152463, 0.18767063, 0.20535265, 0.2345792, 0.44288665}; + h_res_D.resize(n * n); + raft::update_device(d_ref_D, h_res_D.data(), n * n, 0); + + std::vector h_res_I = {0, 2, 5, 4, 3, 1, 1, 3, 5, 4, 2, 0, + 2, 0, 5, 4, 3, 1, 3, 4, 5, 2, 0, 1, + 4, 3, 5, 0, 2, 1, 5, 2, 0, 4, 3, 1}; + h_res_I.resize(n * n); + raft::update_device(d_ref_I, h_res_I.data(), n * n, 0); + + std::vector input_vec = {d_train_inputs}; + std::vector sizes_vec = {n}; + + raft::spatial::knn::detail::random_ball_cover(handle, d_train_inputs, n, d, + k, d_pred_I, d_pred_D, s); + + CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); + + raft::print_device_vector("inds", d_pred_I, n * k, std::cout); + raft::print_device_vector("dists", d_pred_D, n * k, std::cout); + } + + void SetUp() override { basicTest(); } + + void TearDown() override { + CUDA_CHECK(cudaFree(d_train_inputs)); + CUDA_CHECK(cudaFree(d_pred_I)); + CUDA_CHECK(cudaFree(d_pred_D)); + CUDA_CHECK(cudaFree(d_ref_I)); + CUDA_CHECK(cudaFree(d_ref_D)); + } + + protected: + value_t *d_train_inputs; + + int n = 6; + int d = 2; + + int k = 2; + + int s = 2; + + value_idx *d_pred_I; + value_t *d_pred_D; + + value_idx *d_ref_I; + value_t *d_ref_D; +}; + +typedef BallCoverKNNTest BallCoverKNNTestF; + +TEST_F(BallCoverKNNTestF, Fit) { + ASSERT_TRUE(raft::devArrMatch(d_ref_D, d_pred_D, n * n, + raft::CompareApprox(1e-3))); + ASSERT_TRUE( + raft::devArrMatch(d_ref_I, d_pred_I, n * n, raft::Compare())); +} + +} // namespace knn +} // namespace spatial +} // namespace raft diff --git a/cpp/test/spatial/knn.cu b/cpp/test/spatial/knn.cu index 2b1ef89f7a..a521ce6ace 100644 --- a/cpp/test/spatial/knn.cu +++ b/cpp/test/spatial/knn.cu @@ -81,15 +81,6 @@ class KNNTest : public ::testing::TestWithParam { build_expected_output<<>>( expected_labels_, rows_, k_, search_labels_); - raft::print_device_vector("Output indices: ", indices_, rows_ * k_, - std::cout); - raft::print_device_vector("Output distances: ", distances_, rows_ * k_, - std::cout); - raft::print_device_vector("Output labels: ", actual_labels_, rows_ * k_, - std::cout); - raft::print_device_vector("Expected labels: ", expected_labels_, rows_ * k_, - std::cout); - ASSERT_TRUE(devArrMatch(expected_labels_, actual_labels_, rows_ * k_, raft::Compare())); } From c4b27d6aedbca7e18c5d38a98f85db3efa3d2a50 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Mon, 21 Jun 2021 13:15:24 -0400 Subject: [PATCH 04/55] checkjig in --- .../raft/spatial/knn/detail/ball_cover.cuh | 121 +++++++++--------- .../knn/detail/knn_brute_force_faiss.cuh | 1 + cpp/include/raft/spatial/knn/knn.hpp | 4 +- cpp/test/spatial/ball_cover.cu | 4 +- 4 files changed, 67 insertions(+), 63 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh index 11e95dc63c..5497546a47 100644 --- a/cpp/include/raft/spatial/knn/detail/ball_cover.cuh +++ b/cpp/include/raft/spatial/knn/detail/ball_cover.cuh @@ -78,14 +78,13 @@ __global__ void rbc_kernel(const value_t *X, const value_idx n_cols, const value_idx *R_1nn_cols, const value_t *R_1nn_dists, value_idx *out_inds, value_t *out_dists, value_idx *sampled_inds_map) { - int row = blockIdx.x; + int row = blockIdx.x / k; + int cur_k = blockIdx.x % k; const value_t *x_ptr = X + (n_cols * row); value_t x1 = x_ptr[0]; value_t x2 = x_ptr[1]; - constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; - // Each warp works on 1 R faiss::gpu::WarpSelect, warp_q, thread_q, tpb> @@ -102,8 +101,6 @@ __global__ void rbc_kernel(const value_t *X, const value_idx n_cols, // Start iterating through elements of each set from closest R elements, // determining if the distance could even potentially be in the heap. - value_idx cur_k = 0; - // just doing Rs for the closest k for now value_t cur_R_dist = R_knn_dists[row * k + cur_k]; value_idx cur_R_ind = R_knn_inds[row * k + cur_k]; @@ -121,9 +118,9 @@ __global__ void rbc_kernel(const value_t *X, const value_idx n_cols, value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; value_idx cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - printf("row=%d, warpKTopDist=%f, numVals=%d\n", row, heap.warpKTop, - heap.numVals); - +// printf("row=%d, warpKTopDist=%f, numVals=%d\n", row, heap.warpKTop, +// heap.numVals); +// if (i < k || heap.warpKTop >= cur_candidate_dist + cur_R_dist) { const value_t *y_ptr = X + (n_cols * cur_candidate_ind); value_t y1 = y_ptr[0]; @@ -131,11 +128,11 @@ __global__ void rbc_kernel(const value_t *X, const value_idx n_cols, value_t dist = compute_haversine(x1, y1, x2, y2); - printf( - "row=%d, R=%d, R_stop_offset=%d, R_start_offset=%d, candidate_ind=%d, " - "dist=%f\n", - row, cur_R_ind, R_stop_offset, R_start_offset, cur_candidate_ind, dist); - +// printf( +// "row=%d, R=%d, R_stop_offset=%d, R_start_offset=%d, candidate_ind=%d, " +// "dist=%f\n", +// row, cur_R_ind, R_stop_offset, R_start_offset, cur_candidate_ind, dist); +// heap.addThreadQ(dist, cur_candidate_ind); } @@ -145,9 +142,9 @@ __global__ void rbc_kernel(const value_t *X, const value_idx n_cols, if (i < R_size) { value_idx cur_candidate_ind = R_1nn_cols[R_start_offset + i]; value_idx cur_candidate_dist = R_1nn_dists[R_start_offset + i]; - - printf("row=%d, warpKTopDist=%f, numVals=%d\n", row, heap.warpKTop, - heap.numVals); +// +// printf("row=%d, warpKTopDist=%f, numVals=%d\n", row, heap.warpKTop, +// heap.numVals); if (i < k || heap.warpKTop >= cur_candidate_dist + cur_R_dist) { const value_t *y_ptr = X + (n_cols * cur_candidate_ind); @@ -155,18 +152,20 @@ __global__ void rbc_kernel(const value_t *X, const value_idx n_cols, value_t y2 = y_ptr[1]; value_t dist = compute_haversine(x1, y1, x2, y2); - - printf( - "row=%d, R=%d, R_stop_offset=%d, R_start_offset=%d, candidate_ind=%d, " - "dist=%f\n", - row, cur_R_ind, R_stop_offset, R_start_offset, cur_candidate_ind, dist); +// +// printf( +// "row=%d, R=%d, R_stop_offset=%d, R_start_offset=%d, candidate_ind=%d, " +// "dist=%f\n", +// row, cur_R_ind, R_stop_offset, R_start_offset, cur_candidate_ind, dist); heap.addThreadQ(dist, cur_candidate_ind); } } heap.reduce(); - heap.writeOut(out_dists + (row * k), out_inds + (row * k), k); + + value_idx cur_idx = (row * k * k) + (cur_k * k); + heap.writeOut(out_dists + cur_idx, out_inds + cur_idx, k); } /** @@ -185,17 +184,21 @@ void random_ball_cover(const raft::handle_t &handle, const value_t *X, * 1. Randomly sample sqrt(n) points from X */ - printf("k=%d\n", k); +// printf("k=%d\n", k); rmm::device_uvector R_knn_inds_64b(k * m, handle.get_stream()); rmm::device_uvector R_knn_inds(k * m, handle.get_stream()); rmm::device_uvector R_knn_dists(k * m, handle.get_stream()); + rmm::device_uvector out_inds_full(k * k * m, handle.get_stream()); + rmm::device_uvector out_dists_full(k * k * m, handle.get_stream()); + + n_samples = n_samples < 1 ? int(sqrt(m)) : n_samples; ASSERT(n_samples >= k, "number of landmark samples must be >= k"); - printf("Sampling %d points\n", n_samples); +// printf("Sampling %d points\n", n_samples); rmm::device_uvector R_indices(n_samples, handle.get_stream()); rmm::device_uvector R(n_samples * n, handle.get_stream()); @@ -218,10 +221,10 @@ void random_ball_cover(const raft::handle_t &handle, const value_t *X, raft::matrix::copyRows(X, m, n, R.data(), R_indices.data(), n_samples, handle.get_stream(), true); - - raft::print_device_vector("sampled indices", R_indices.data(), - R_indices.size(), std::cout); - raft::print_device_vector("sampled rows", R.data(), R.size(), std::cout); +// +// raft::print_device_vector("sampled indices", R_indices.data(), +// R_indices.size(), std::cout); +// raft::print_device_vector("sampled rows", R.data(), R.size(), std::cout); /** * 2. Perform knn = bfknn(X, R, k) @@ -240,10 +243,10 @@ void random_ball_cover(const raft::handle_t &handle, const value_t *X, R_knn_inds_64b.data() + R_knn_inds_64b.size(), R_knn_inds.data(), [] __device__(int64_t i) { return i; }); - raft::print_device_vector("R_knn_inds", R_knn_inds.data(), R_knn_inds.size(), - std::cout); - raft::print_device_vector("R_knn_dists", R_knn_dists.data(), - R_knn_dists.size(), std::cout); +// raft::print_device_vector("R_knn_inds", R_knn_inds.data(), R_knn_inds.size(), +// std::cout); +// raft::print_device_vector("R_knn_dists", R_knn_dists.data(), +// R_knn_dists.size(), std::cout); /** * 3. Create L_r = knn[:,0].T (CSR) @@ -252,7 +255,7 @@ void random_ball_cover(const raft::handle_t &handle, const value_t *X, * Secondary sort by (R_knn_inds, R_knn_dists) */ - printf("Building representative lists for landmarks\n"); +// printf("Building representative lists for landmarks\n"); rmm::device_uvector R_1nn_inds(m, handle.get_stream()); rmm::device_uvector R_1nn_dists(m, handle.get_stream()); @@ -267,45 +270,38 @@ void random_ball_cover(const raft::handle_t &handle, const value_t *X, R_1nn_dists_ptr[i] = R_knn_dists_ptr[i * k]; }); - // I think this might be assuming col-major? - // raft::matrix::sliceMatrix(R_knn_inds.data(), m, k, R_1nn_inds.data(), 0, 0, - // m+1, 1, handle.get_stream()); - // raft::matrix::sliceMatrix( - // R_knn_dists.data(), m, k, R_1nn_dists.data(), 0, - // 0, m+1, 1, handle.get_stream()); - // auto keys = thrust::make_zip_iterator( thrust::make_tuple(R_1nn_inds.data(), R_1nn_dists.data())); auto vals = thrust::make_zip_iterator(thrust::make_tuple(R_1nn_cols.data())); - raft::print_device_vector("R_1nn_inds", R_1nn_inds.data(), R_1nn_inds.size(), - std::cout); - raft::print_device_vector("R_1nn_cols", R_1nn_cols.data(), R_1nn_cols.size(), - std::cout); - - raft::print_device_vector("R_1nn_dists", R_1nn_dists.data(), - R_1nn_dists.size(), std::cout); +// raft::print_device_vector("R_1nn_inds", R_1nn_inds.data(), R_1nn_inds.size(), +// std::cout); +// raft::print_device_vector("R_1nn_cols", R_1nn_cols.data(), R_1nn_cols.size(), +// std::cout); +// +// raft::print_device_vector("R_1nn_dists", R_1nn_dists.data(), +// R_1nn_dists.size(), std::cout); // group neighborhoods for each reference landmark and sort each group by distance thrust::sort_by_key(thrust::cuda::par.on(handle.get_stream()), keys, keys + R_1nn_inds.size(), vals, NNComp()); - raft::print_device_vector("R_1nn_inds_sorted", R_1nn_inds.data(), - R_1nn_inds.size(), std::cout); - raft::print_device_vector("R_1nn_cols_sorted", R_1nn_cols.data(), - R_1nn_cols.size(), std::cout); - - raft::print_device_vector("R_1nn_dists_sorted", R_1nn_dists.data(), - R_1nn_dists.size(), std::cout); +// raft::print_device_vector("R_1nn_inds_sorted", R_1nn_inds.data(), +// R_1nn_inds.size(), std::cout); +// raft::print_device_vector("R_1nn_cols_sorted", R_1nn_cols.data(), +// R_1nn_cols.size(), std::cout); +// +// raft::print_device_vector("R_1nn_dists_sorted", R_1nn_dists.data(), +// R_1nn_dists.size(), std::cout); // convert to CSR for fast lookup rmm::device_uvector R_indptr(n_samples, handle.get_stream()); raft::sparse::convert::sorted_coo_to_csr( R_1nn_inds.data(), m, R_indptr.data(), n_samples + 1, handle.get_device_allocator(), handle.get_stream()); - - raft::print_device_vector("R_1nn_indptr", R_indptr.data(), R_indptr.size(), - std::cout); +// +// raft::print_device_vector("R_1nn_indptr", R_indptr.data(), R_indptr.size(), +// std::cout); /** * 4. Perform k-select over original KNN, using L_r to filter distances @@ -319,9 +315,16 @@ void random_ball_cover(const raft::handle_t &handle, const value_t *X, printf("Performing final bfknn\n"); - rbc_kernel<<>>( + /** + * Compute nearest k for each nearest landmark R and reduce them to a final + * k + */ + rbc_kernel<<>>( X, n, R_knn_inds.data(), R_knn_dists.data(), m, k, R_indptr.data(), - R_1nn_cols.data(), R_1nn_dists.data(), inds, dists, R_indices.data()); + R_1nn_cols.data(), R_1nn_dists.data(), out_inds_full.data(), out_dists_full.data(), R_indices.data()); +// +// raft::print_device_vector("out_inds", out_inds_full.data(), out_inds_full.size(), std::cout); +// raft::print_device_vector("out_dists", out_dists_full.data(), out_dists_full.size(), std::cout); // Thoughts: // For n_cols < 32, we could probably just have each thread compute the distance diff --git a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh index 3b7c0aacd4..33a311884e 100644 --- a/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh @@ -34,6 +34,7 @@ #include #include "haversine_distance.cuh" +#include "ball_cover.cuh" #include "processing.hpp" namespace raft { diff --git a/cpp/include/raft/spatial/knn/knn.hpp b/cpp/include/raft/spatial/knn/knn.hpp index dd6435e67e..5d966884c6 100644 --- a/cpp/include/raft/spatial/knn/knn.hpp +++ b/cpp/include/raft/spatial/knn/knn.hpp @@ -82,8 +82,8 @@ inline void brute_force_knn( template inline void random_ball_cover(const raft::handle_t &handle, const value_t *X, value_idx m, value_idx n, int k, value_idx *inds, - value_t *dists) { - detail::random_ball_cover(handle, X, m, n, k, inds, dists); + value_t *dists, value_idx n_samples=-1) { + detail::random_ball_cover(handle, X, m, n, k, inds, dists, n_samples); } } // namespace knn diff --git a/cpp/test/spatial/ball_cover.cu b/cpp/test/spatial/ball_cover.cu index 1997221ae4..e28baae726 100644 --- a/cpp/test/spatial/ball_cover.cu +++ b/cpp/test/spatial/ball_cover.cu @@ -17,7 +17,7 @@ #include #include #include -#include +#include #include #include #include "../test_utils.h" @@ -73,7 +73,7 @@ class BallCoverKNNTest : public ::testing::Test { std::vector input_vec = {d_train_inputs}; std::vector sizes_vec = {n}; - raft::spatial::knn::detail::random_ball_cover(handle, d_train_inputs, n, d, + raft::spatial::knn::random_ball_cover(handle, d_train_inputs, n, d, k, d_pred_I, d_pred_D, s); CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); From 81c11df7f0f8ae746cdf42f0edced54abe8e7884 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Tue, 22 Jun 2021 13:02:31 -0400 Subject: [PATCH 05/55] Prototype of random ball cover so far. --- cpp/include/raft/cache/cache_util.cuh | 6 +- cpp/include/raft/sparse/linalg/degree.cuh | 10 +- cpp/include/raft/sparse/selection/knn.cuh | 7 +- .../raft/sparse/selection/knn_graph.cuh | 3 +- .../knn/detail/ann_quantized_faiss.cuh | 2 + .../raft/spatial/knn/detail/ball_cover.cuh | 117 +++++------------- .../knn/detail/knn_brute_force_faiss.cuh | 1 - .../knn/detail/selection_faiss.cuh} | 21 +--- cpp/include/raft/spatial/knn/knn.hpp | 75 ++++++++++- cpp/test/spatial/ball_cover.cu | 8 +- 10 files changed, 127 insertions(+), 123 deletions(-) rename cpp/include/raft/{sparse/selection/selection.cuh => spatial/knn/detail/selection_faiss.cuh} (93%) diff --git a/cpp/include/raft/cache/cache_util.cuh b/cpp/include/raft/cache/cache_util.cuh index ce8ef9a095..a65227c402 100644 --- a/cpp/include/raft/cache/cache_util.cuh +++ b/cpp/include/raft/cache/cache_util.cuh @@ -41,9 +41,9 @@ namespace cache { * @param [in] n the number of elements that need to be collected * @param [out] out vectors collected from the cache, size [n_vec * n] */ -template -__global__ void get_vecs(const math_t *cache, int n_vec, const int *cache_idx, - int n, math_t *out) { +template +__global__ void get_vecs(const math_t *cache, int_t n_vec, + const idx_t *cache_idx, int_t n, math_t *out) { int tid = threadIdx.x + blockIdx.x * blockDim.x; int row = tid % n_vec; // row idx if (tid < n_vec * n) { diff --git a/cpp/include/raft/sparse/linalg/degree.cuh b/cpp/include/raft/sparse/linalg/degree.cuh index 9bd322c90a..ef6a067c39 100644 --- a/cpp/include/raft/sparse/linalg/degree.cuh +++ b/cpp/include/raft/sparse/linalg/degree.cuh @@ -43,11 +43,11 @@ namespace linalg { * @param nnz the size of the rows array * @param results array to place results */ -template -__global__ void coo_degree_kernel(const int *rows, int nnz, int *results) { +template +__global__ void coo_degree_kernel(const T *rows, int nnz, T *results) { int row = (blockIdx.x * TPB_X) + threadIdx.x; if (row < nnz) { - raft::myAtomicAdd(results + rows[row], 1); + atomicAdd(results + rows[row], (T)1); } } @@ -59,8 +59,8 @@ __global__ void coo_degree_kernel(const int *rows, int nnz, int *results) { * @param results: output result array * @param stream: cuda stream to use */ -template -void coo_degree(const int *rows, int nnz, int *results, cudaStream_t stream) { +template +void coo_degree(const T *rows, int nnz, T *results, cudaStream_t stream) { dim3 grid_rc(raft::ceildiv(nnz, TPB_X), 1, 1); dim3 blk_rc(TPB_X, 1, 1); diff --git a/cpp/include/raft/sparse/selection/knn.cuh b/cpp/include/raft/sparse/selection/knn.cuh index e327386d13..3b53ed71d0 100644 --- a/cpp/include/raft/sparse/selection/knn.cuh +++ b/cpp/include/raft/sparse/selection/knn.cuh @@ -32,8 +32,6 @@ #include #include #include -#include - #include #include @@ -336,8 +334,9 @@ class sparse_knn_t { if (metric == raft::distance::DistanceType::InnerProduct) ascending = false; // kernel to slice first (min) k cols and copy into batched merge buffer - select_k(batch_dists, batch_indices, batch_rows, batch_cols, out_dists, - out_indices, ascending, n_neighbors, stream); + spatial::knn::select_k(batch_dists, batch_indices, batch_rows, batch_cols, + out_dists, out_indices, ascending, n_neighbors, + stream); } void compute_distances(csr_batcher_t &idx_batcher, diff --git a/cpp/include/raft/sparse/selection/knn_graph.cuh b/cpp/include/raft/sparse/selection/knn_graph.cuh index 1cf225087a..55edd2b721 100644 --- a/cpp/include/raft/sparse/selection/knn_graph.cuh +++ b/cpp/include/raft/sparse/selection/knn_graph.cuh @@ -88,11 +88,12 @@ void conv_indices(in_t *inds, out_t *out, size_t size, cudaStream_t stream) { * @param[in] n number of observations (columns) in X * @param[in] metric distance metric to use when constructing neighborhoods * @param[out] out output edge list + * @param[out] out output edge list * @param c */ template void knn_graph(const handle_t &handle, const value_t *X, size_t m, size_t n, - distance::DistanceType metric, + raft::distance::DistanceType metric, raft::sparse::COO &out, int c = 15) { int k = build_k(m, c); diff --git a/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh b/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh index bb37fd03db..3d8c6dfdde 100644 --- a/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh +++ b/cpp/include/raft/spatial/knn/detail/ann_quantized_faiss.cuh @@ -17,9 +17,11 @@ #pragma once #include "../ann_common.h" +#include "knn_brute_force_faiss.cuh" #include #include +#include "processing.hpp" #include