From 3d75b4e27b718f7a6e4f22502c04109bf53db931 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Mon, 15 Mar 2021 12:59:02 +0000 Subject: [PATCH 1/4] brute_force_knn update --- .../spatial/knn/detail/brute_force_knn.hpp | 247 ++++++++++++++---- cpp/include/raft/spatial/knn/knn.hpp | 12 +- 2 files changed, 195 insertions(+), 64 deletions(-) diff --git a/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp b/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp index e686fff587..3708925354 100644 --- a/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp +++ b/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp @@ -166,27 +166,140 @@ inline void knn_merge_parts(value_t *inK, value_idx *inV, value_t *outK, inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); } -inline faiss::MetricType build_faiss_metric(distance::DistanceType metric) { +inline faiss::MetricType build_faiss_metric( + raft::distance::DistanceType metric) { switch (metric) { - case distance::DistanceType::L2Unexpanded: + case raft::distance::DistanceType::CosineExpanded: + return faiss::MetricType::METRIC_INNER_PRODUCT; + case raft::distance::DistanceType::CorrelationExpanded: + return faiss::MetricType::METRIC_INNER_PRODUCT; + case raft::distance::DistanceType::L2Expanded: + return faiss::MetricType::METRIC_L2; + case raft::distance::DistanceType::L2Unexpanded: + return faiss::MetricType::METRIC_L2; + case raft::distance::DistanceType::L2SqrtExpanded: return faiss::MetricType::METRIC_L2; - case distance::DistanceType::L1: + case raft::distance::DistanceType::L2SqrtUnexpanded: + return faiss::MetricType::METRIC_L2; + case raft::distance::DistanceType::L1: return faiss::MetricType::METRIC_L1; - case distance::DistanceType::Linf: - return faiss::MetricType::METRIC_Linf; - case distance::DistanceType::LpUnexpanded: + case raft::distance::DistanceType::InnerProduct: + return faiss::MetricType::METRIC_INNER_PRODUCT; + case raft::distance::DistanceType::LpUnexpanded: return faiss::MetricType::METRIC_Lp; - case distance::DistanceType::Canberra: + case raft::distance::DistanceType::Linf: + return faiss::MetricType::METRIC_Linf; + case raft::distance::DistanceType::Canberra: return faiss::MetricType::METRIC_Canberra; - case distance::DistanceType::BrayCurtis: + case raft::distance::DistanceType::BrayCurtis: return faiss::MetricType::METRIC_BrayCurtis; - case distance::DistanceType::JensenShannon: + case raft::distance::DistanceType::JensenShannon: return faiss::MetricType::METRIC_JensenShannon; default: - return faiss::MetricType::METRIC_INNER_PRODUCT; + THROW("MetricType not supported: %d", metric); } } +template +DI value_t compute_haversine(value_t x1, value_t y1, value_t x2, value_t y2) { + value_t sin_0 = sin(0.5 * (x1 - y1)); + value_t sin_1 = sin(0.5 * (x2 - y2)); + value_t rdist = sin_0 * sin_0 + cos(x1) * cos(y1) * sin_1 * sin_1; + + return 2 * asin(sqrt(rdist)); +} + +/** + * @tparam value_idx data type of indices + * @tparam value_t data type of values and distances + * @tparam warp_q + * @tparam thread_q + * @tparam tpb + * @param[out] out_inds output indices + * @param[out] out_dists output distances + * @param[in] index index array + * @param[in] query query array + * @param[in] n_index_rows number of rows in index array + * @param[in] k number of closest neighbors to return + */ +template +__global__ void haversine_knn_kernel(value_idx *out_inds, value_t *out_dists, + const value_t *index, const value_t *query, + size_t n_index_rows, int k) { + constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; + + __shared__ value_t smemK[kNumWarps * warp_q]; + __shared__ value_idx smemV[kNumWarps * warp_q]; + + faiss::gpu::BlockSelect, warp_q, thread_q, + tpb> + heap(faiss::gpu::Limits::getMax(), -1, smemK, smemV, k); + + // Grid is exactly sized to rows available + int limit = faiss::gpu::utils::roundDown(n_index_rows, faiss::gpu::kWarpSize); + + const value_t *query_ptr = query + (blockIdx.x * 2); + value_t x1 = query_ptr[0]; + value_t x2 = query_ptr[1]; + + int i = threadIdx.x; + + for (; i < limit; i += tpb) { + const value_t *idx_ptr = index + (i * 2); + value_t y1 = idx_ptr[0]; + value_t y2 = idx_ptr[1]; + + value_t dist = compute_haversine(x1, y1, x2, y2); + + heap.add(dist, i); + } + + // Handle last remainder fraction of a warp of elements + if (i < n_index_rows) { + const value_t *idx_ptr = index + (i * 2); + value_t y1 = idx_ptr[0]; + value_t y2 = idx_ptr[1]; + + value_t dist = compute_haversine(x1, y1, x2, y2); + + heap.addThreadQ(dist, i); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += tpb) { + out_dists[blockIdx.x * k + i] = smemK[i]; + out_inds[blockIdx.x * k + i] = smemV[i]; + } +} + +/** + * Conmpute the k-nearest neighbors using the Haversine + * (great circle arc) distance. Input is assumed to have + * 2 dimensions (latitude, longitude) in radians. + + * @tparam value_idx + * @tparam value_t + * @param[out] out_inds output indices array on device (size n_query_rows * k) + * @param[out] out_dists output dists array on device (size n_query_rows * k) + * @param[in] index input index array on device (size n_index_rows * 2) + * @param[in] query input query array on device (size n_query_rows * 2) + * @param[in] n_index_rows number of rows in index array + * @param[in] n_query_rows number of rows in query array + * @param[in] k number of closest neighbors to return + * @param[in] stream stream to order kernel launch + */ +template +void haversine_knn(value_idx *out_inds, value_t *out_dists, + const value_t *index, const value_t *query, + size_t n_index_rows, size_t n_query_rows, int k, + cudaStream_t stream) { + haversine_knn_kernel<<>>( + out_inds, out_dists, index, query, n_index_rows, k); +} + /** * Search the kNN for the k-nearest neighbors of a set of query vectors * @param[in] input vector of device device memory array pointers to search @@ -209,25 +322,25 @@ inline faiss::MetricType build_faiss_metric(distance::DistanceType metric) { * @param[in] rowMajorQuery are the query array in row-major layout? * @param[in] translations translation ids for indices when index rows represent * non-contiguous partitions - * @param[in] metric corresponds to the FAISS::metricType enum (default is euclidean) + * @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded) * @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm - * @param[in] expanded_form whether or not lp variants should be reduced w/ lp-root */ template -void brute_force_knn_impl( - std::vector &input, std::vector &sizes, IntType D, - float *search_items, IntType n, int64_t *res_I, float *res_D, IntType k, - std::shared_ptr allocator, - cudaStream_t userStream, cudaStream_t *internalStreams = nullptr, - int n_int_streams = 0, bool rowMajorIndex = true, bool rowMajorQuery = true, - std::vector *translations = nullptr, - distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - float metricArg = 2.0, bool expanded_form = false) { +void brute_force_knn_impl(std::vector &input, std::vector &sizes, + IntType D, float *search_items, IntType n, + int64_t *res_I, float *res_D, IntType k, + std::shared_ptr allocator, + cudaStream_t userStream, + cudaStream_t *internalStreams = nullptr, + int n_int_streams = 0, bool rowMajorIndex = true, + bool rowMajorQuery = true, + std::vector *translations = nullptr, + raft::distance::DistanceType metric = + raft::distance::DistanceType::L2Expanded, + float metricArg = 0) { ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size"); - faiss::MetricType m = detail::build_faiss_metric(metric); - std::vector *id_ranges; if (translations == nullptr) { // If we don't have explicit translations @@ -235,7 +348,7 @@ void brute_force_knn_impl( // from the local partitions id_ranges = new std::vector(); int64_t total_n = 0; - for (size_t i = 0; i < input.size(); i++) { + for (int i = 0; i < input.size(); i++) { id_ranges->push_back(total_n); total_n += sizes[i]; } @@ -252,7 +365,7 @@ void brute_force_knn_impl( std::vector>> metric_processors( input.size()); - for (size_t i = 0; i < input.size(); i++) { + for (int i = 0; i < input.size(); i++) { metric_processors[i] = create_processor( metric, sizes[i], D, k, rowMajorQuery, userStream, allocator); metric_processors[i]->preprocess(input[i]); @@ -283,35 +396,52 @@ void brute_force_knn_impl( // Sync user stream only if using other streams to parallelize query if (n_int_streams > 0) CUDA_CHECK(cudaStreamSynchronize(userStream)); - for (size_t i = 0; i < input.size(); i++) { - faiss::gpu::StandardGpuResources gpu_res; + for (int i = 0; i < input.size(); i++) { + float *out_d_ptr = out_D + (i * k * n); + int64_t *out_i_ptr = out_I + (i * k * n); cudaStream_t stream = raft::select_stream(userStream, internalStreams, n_int_streams, i); - gpu_res.noTempMemory(); - gpu_res.setDefaultStream(device, stream); - - faiss::gpu::GpuDistanceParams args; - args.metric = m; - args.metricArg = metricArg; - args.k = k; - args.dims = D; - args.vectors = input[i]; - args.vectorsRowMajor = rowMajorIndex; - args.numVectors = sizes[i]; - args.queries = search_items; - args.queriesRowMajor = rowMajorQuery; - args.numQueries = n; - args.outDistances = out_D + (i * k * n); - args.outIndices = out_I + (i * k * n); - - /** - * @todo: Until FAISS supports pluggable allocation strategies, - * we will not reap the benefits of the pool allocator for - * avoiding device-wide synchronizations from cudaMalloc/cudaFree - */ - bfKnn(&gpu_res, args); + switch (metric) { + case raft::distance::DistanceType::Haversine: + + ASSERT(D == 2, + "Haversine distance requires 2 dimensions " + "(latitude / longitude)."); + + haversine_knn(out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, + k, stream); + break; + default: + faiss::MetricType m = build_faiss_metric(metric); + + faiss::gpu::StandardGpuResources gpu_res; + + gpu_res.noTempMemory(); + gpu_res.setDefaultStream(device, stream); + + faiss::gpu::GpuDistanceParams args; + args.metric = m; + args.metricArg = metricArg; + args.k = k; + args.dims = D; + args.vectors = input[i]; + args.vectorsRowMajor = rowMajorIndex; + args.numVectors = sizes[i]; + args.queries = search_items; + args.queriesRowMajor = rowMajorQuery; + args.numQueries = n; + args.outDistances = out_d_ptr; + args.outIndices = out_i_ptr; + + /** + * @todo: Until FAISS supports pluggable allocation strategies, + * we will not reap the benefits of the pool allocator for + * avoiding device-wide synchronizations from cudaMalloc/cudaFree + */ + bfKnn(&gpu_res, args); + } CUDA_CHECK(cudaPeekAtLastError()); } @@ -326,19 +456,20 @@ void brute_force_knn_impl( if (input.size() > 1 || translations != nullptr) { // This is necessary for proper index translations. If there are // no translations or partitions to combine, it can be skipped. - detail::knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, - userStream, trans.data()); + knn_merge_parts(out_D, out_I, res_D, res_I, n, input.size(), k, userStream, + trans.data()); } // Perform necessary post-processing - if ((m == faiss::MetricType::METRIC_L2 || - m == faiss::MetricType::METRIC_Lp) && - !expanded_form) { + if (metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::L2SqrtUnexpanded || + metric == raft::distance::DistanceType::LpUnexpanded) { /** * post-processing */ float p = 0.5; // standard l2 - if (m == faiss::MetricType::METRIC_Lp) p = 1.0 / metricArg; + if (metric == raft::distance::DistanceType::LpUnexpanded) + p = 1.0 / metricArg; raft::linalg::unaryOp( res_D, res_D, n * k, [p] __device__(float input) { return powf(input, p); }, userStream); @@ -346,12 +477,12 @@ void brute_force_knn_impl( query_metric_processor->revert(search_items); query_metric_processor->postprocess(out_D); - for (size_t i = 0; i < input.size(); i++) { + for (int i = 0; i < input.size(); i++) { metric_processors[i]->revert(input[i]); } if (translations == nullptr) delete id_ranges; -} +}; } // namespace detail } // namespace knn diff --git a/cpp/include/raft/spatial/knn/knn.hpp b/cpp/include/raft/spatial/knn/knn.hpp index e1ab413b15..2ca3d29fe6 100644 --- a/cpp/include/raft/spatial/knn/knn.hpp +++ b/cpp/include/raft/spatial/knn/knn.hpp @@ -56,17 +56,17 @@ inline void brute_force_knn( float *res_D, int k, bool rowMajorIndex = false, bool rowMajorQuery = false, std::vector *translations = nullptr, distance::DistanceType metric = distance::DistanceType::L2Unexpanded, - float metric_arg = 2.0f, bool expanded = false) { + float metric_arg = 2.0f) { ASSERT(input.size() == sizes.size(), "input and sizes vectors must be the same size"); std::vector int_streams = handle.get_internal_streams(); - detail::brute_force_knn_impl( - input, sizes, D, search_items, n, res_I, res_D, k, - handle.get_device_allocator(), handle.get_stream(), int_streams.data(), - handle.get_num_internal_streams(), rowMajorIndex, rowMajorQuery, - translations, metric, metric_arg, expanded); + detail::brute_force_knn_impl(input, sizes, D, search_items, n, res_I, res_D, + k, handle.get_device_allocator(), + handle.get_stream(), int_streams.data(), + handle.get_num_internal_streams(), rowMajorIndex, + rowMajorQuery, translations, metric, metric_arg); } } // namespace knn From ab5a40a5c930a6ce419204d7f62b2c16c340dbbc Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 16 Mar 2021 09:52:32 +0000 Subject: [PATCH 2/4] Moving haversine --- .../spatial/knn/detail/brute_force_knn.hpp | 101 +------------ .../spatial/knn/detail/haversine_distance.hpp | 140 ++++++++++++++++++ 2 files changed, 141 insertions(+), 100 deletions(-) create mode 100644 cpp/include/raft/spatial/knn/detail/haversine_distance.hpp diff --git a/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp b/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp index 3708925354..babb7b9eed 100644 --- a/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp +++ b/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp @@ -33,6 +33,7 @@ #include #include +#include "haversine_distance.hpp" #include "processing.hpp" namespace raft { @@ -200,106 +201,6 @@ inline faiss::MetricType build_faiss_metric( } } -template -DI value_t compute_haversine(value_t x1, value_t y1, value_t x2, value_t y2) { - value_t sin_0 = sin(0.5 * (x1 - y1)); - value_t sin_1 = sin(0.5 * (x2 - y2)); - value_t rdist = sin_0 * sin_0 + cos(x1) * cos(y1) * sin_1 * sin_1; - - return 2 * asin(sqrt(rdist)); -} - -/** - * @tparam value_idx data type of indices - * @tparam value_t data type of values and distances - * @tparam warp_q - * @tparam thread_q - * @tparam tpb - * @param[out] out_inds output indices - * @param[out] out_dists output distances - * @param[in] index index array - * @param[in] query query array - * @param[in] n_index_rows number of rows in index array - * @param[in] k number of closest neighbors to return - */ -template -__global__ void haversine_knn_kernel(value_idx *out_inds, value_t *out_dists, - const value_t *index, const value_t *query, - size_t n_index_rows, int k) { - constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; - - __shared__ value_t smemK[kNumWarps * warp_q]; - __shared__ value_idx smemV[kNumWarps * warp_q]; - - faiss::gpu::BlockSelect, warp_q, thread_q, - tpb> - heap(faiss::gpu::Limits::getMax(), -1, smemK, smemV, k); - - // Grid is exactly sized to rows available - int limit = faiss::gpu::utils::roundDown(n_index_rows, faiss::gpu::kWarpSize); - - const value_t *query_ptr = query + (blockIdx.x * 2); - value_t x1 = query_ptr[0]; - value_t x2 = query_ptr[1]; - - int i = threadIdx.x; - - for (; i < limit; i += tpb) { - const value_t *idx_ptr = index + (i * 2); - value_t y1 = idx_ptr[0]; - value_t y2 = idx_ptr[1]; - - value_t dist = compute_haversine(x1, y1, x2, y2); - - heap.add(dist, i); - } - - // Handle last remainder fraction of a warp of elements - if (i < n_index_rows) { - const value_t *idx_ptr = index + (i * 2); - value_t y1 = idx_ptr[0]; - value_t y2 = idx_ptr[1]; - - value_t dist = compute_haversine(x1, y1, x2, y2); - - heap.addThreadQ(dist, i); - } - - heap.reduce(); - - for (int i = threadIdx.x; i < k; i += tpb) { - out_dists[blockIdx.x * k + i] = smemK[i]; - out_inds[blockIdx.x * k + i] = smemV[i]; - } -} - -/** - * Conmpute the k-nearest neighbors using the Haversine - * (great circle arc) distance. Input is assumed to have - * 2 dimensions (latitude, longitude) in radians. - - * @tparam value_idx - * @tparam value_t - * @param[out] out_inds output indices array on device (size n_query_rows * k) - * @param[out] out_dists output dists array on device (size n_query_rows * k) - * @param[in] index input index array on device (size n_index_rows * 2) - * @param[in] query input query array on device (size n_query_rows * 2) - * @param[in] n_index_rows number of rows in index array - * @param[in] n_query_rows number of rows in query array - * @param[in] k number of closest neighbors to return - * @param[in] stream stream to order kernel launch - */ -template -void haversine_knn(value_idx *out_inds, value_t *out_dists, - const value_t *index, const value_t *query, - size_t n_index_rows, size_t n_query_rows, int k, - cudaStream_t stream) { - haversine_knn_kernel<<>>( - out_inds, out_dists, index, query, n_index_rows, k); -} - /** * Search the kNN for the k-nearest neighbors of a set of query vectors * @param[in] input vector of device device memory array pointers to search diff --git a/cpp/include/raft/spatial/knn/detail/haversine_distance.hpp b/cpp/include/raft/spatial/knn/detail/haversine_distance.hpp new file mode 100644 index 0000000000..7d87254cb6 --- /dev/null +++ b/cpp/include/raft/spatial/knn/detail/haversine_distance.hpp @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2020-2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace raft { +namespace spatial { +namespace knn { +namespace detail { + +template +DI value_t compute_haversine(value_t x1, value_t y1, value_t x2, value_t y2) { + value_t sin_0 = sin(0.5 * (x1 - y1)); + value_t sin_1 = sin(0.5 * (x2 - y2)); + value_t rdist = sin_0 * sin_0 + cos(x1) * cos(y1) * sin_1 * sin_1; + + return 2 * asin(sqrt(rdist)); +} + +/** + * @tparam value_idx data type of indices + * @tparam value_t data type of values and distances + * @tparam warp_q + * @tparam thread_q + * @tparam tpb + * @param[out] out_inds output indices + * @param[out] out_dists output distances + * @param[in] index index array + * @param[in] query query array + * @param[in] n_index_rows number of rows in index array + * @param[in] k number of closest neighbors to return + */ +template +__global__ void haversine_knn_kernel(value_idx *out_inds, value_t *out_dists, + const value_t *index, const value_t *query, + size_t n_index_rows, int k) { + constexpr int kNumWarps = tpb / faiss::gpu::kWarpSize; + + __shared__ value_t smemK[kNumWarps * warp_q]; + __shared__ value_idx smemV[kNumWarps * warp_q]; + + faiss::gpu::BlockSelect, warp_q, thread_q, + tpb> + heap(faiss::gpu::Limits::getMax(), -1, smemK, smemV, k); + + // Grid is exactly sized to rows available + int limit = faiss::gpu::utils::roundDown(n_index_rows, faiss::gpu::kWarpSize); + + const value_t *query_ptr = query + (blockIdx.x * 2); + value_t x1 = query_ptr[0]; + value_t x2 = query_ptr[1]; + + int i = threadIdx.x; + + for (; i < limit; i += tpb) { + const value_t *idx_ptr = index + (i * 2); + value_t y1 = idx_ptr[0]; + value_t y2 = idx_ptr[1]; + + value_t dist = compute_haversine(x1, y1, x2, y2); + + heap.add(dist, i); + } + + // Handle last remainder fraction of a warp of elements + if (i < n_index_rows) { + const value_t *idx_ptr = index + (i * 2); + value_t y1 = idx_ptr[0]; + value_t y2 = idx_ptr[1]; + + value_t dist = compute_haversine(x1, y1, x2, y2); + + heap.addThreadQ(dist, i); + } + + heap.reduce(); + + for (int i = threadIdx.x; i < k; i += tpb) { + out_dists[blockIdx.x * k + i] = smemK[i]; + out_inds[blockIdx.x * k + i] = smemV[i]; + } +} + +/** + * Conmpute the k-nearest neighbors using the Haversine + * (great circle arc) distance. Input is assumed to have + * 2 dimensions (latitude, longitude) in radians. + + * @tparam value_idx + * @tparam value_t + * @param[out] out_inds output indices array on device (size n_query_rows * k) + * @param[out] out_dists output dists array on device (size n_query_rows * k) + * @param[in] index input index array on device (size n_index_rows * 2) + * @param[in] query input query array on device (size n_query_rows * 2) + * @param[in] n_index_rows number of rows in index array + * @param[in] n_query_rows number of rows in query array + * @param[in] k number of closest neighbors to return + * @param[in] stream stream to order kernel launch + */ +template +void haversine_knn(value_idx *out_inds, value_t *out_dists, + const value_t *index, const value_t *query, + size_t n_index_rows, size_t n_query_rows, int k, + cudaStream_t stream) { + haversine_knn_kernel<<>>( + out_inds, out_dists, index, query, n_index_rows, k); +} + +} // namespace detail +} // namespace knn +} // namespace spatial +} // namespace raft From 9f3ab42d65cca8713d67db945c41ffd15aa9bfe3 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 16 Mar 2021 10:28:10 +0000 Subject: [PATCH 3/4] Haversine testing --- cpp/CMakeLists.txt | 1 + cpp/test/spatial/haversine.cu | 119 ++++++++++++++++++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 cpp/test/spatial/haversine.cu diff --git a/cpp/CMakeLists.txt b/cpp/CMakeLists.txt index 5d01683f95..7658eb5bc6 100644 --- a/cpp/CMakeLists.txt +++ b/cpp/CMakeLists.txt @@ -297,6 +297,7 @@ if(BUILD_RAFT_TESTS) test/sparse/sort.cu test/sparse/symmetrize.cu test/spatial/knn.cu + test/spatial/haversine.cu test/stats/mean.cu test/stats/mean_center.cu test/stats/stddev.cu diff --git a/cpp/test/spatial/haversine.cu b/cpp/test/spatial/haversine.cu new file mode 100644 index 0000000000..44de78ed01 --- /dev/null +++ b/cpp/test/spatial/haversine.cu @@ -0,0 +1,119 @@ +/* + * Copyright (c) 2021, NVIDIA CORPORATION. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include +#include +#include +#include +#include +#include "../test_utils.h" + +namespace raft { +namespace spatial { +namespace knn { + +template +class HaversineKNNTest : public ::testing::Test { + protected: + void basicTest() { + auto alloc = std::make_shared(); + + // Allocate input + raft::allocate(d_train_inputs, n * d); + + // Allocate reference arrays + raft::allocate(d_ref_I, n * n); + raft::allocate(d_ref_D, n * n); + + // Allocate predicted arrays + raft::allocate(d_pred_I, n * n); + raft::allocate(d_pred_D, n * n); + + // make testdata on host + std::vector h_train_inputs = { + 0.71113885, -1.29215058, 0.59613176, -2.08048115, + 0.74932804, -1.33634042, 0.51486728, -1.65962873, + 0.53154002, -1.47049808, 0.72891737, -1.54095137}; + + h_train_inputs.resize(n); + raft::update_device(d_train_inputs, h_train_inputs.data(), n * d, 0); + + std::vector h_res_D = { + 0., 0.05041587, 0.18767063, 0.23048252, 0.35749438, 0.62925595, + 0., 0.36575755, 0.44288665, 0.5170737, 0.59501296, 0.62925595, + 0., 0.05041587, 0.152463, 0.2426416, 0.34925285, 0.59501296, + 0., 0.16461092, 0.2345792, 0.34925285, 0.35749438, 0.36575755, + 0., 0.16461092, 0.20535265, 0.23048252, 0.2426416, 0.5170737, + 0., 0.152463, 0.18767063, 0.20535265, 0.2345792, 0.44288665}; + h_res_D.resize(n * n); + raft::update_device(d_ref_D, h_res_D.data(), n * n, 0); + + std::vector h_res_I = {0, 2, 5, 4, 3, 1, 1, 3, 5, 4, 2, 0, + 2, 0, 5, 4, 3, 1, 3, 4, 5, 2, 0, 1, + 4, 3, 5, 0, 2, 1, 5, 2, 0, 4, 3, 1}; + h_res_I.resize(n * n); + raft::update_device(d_ref_I, h_res_I.data(), n * n, 0); + + std::vector input_vec = {d_train_inputs}; + std::vector sizes_vec = {n}; + + cudaStream_t stream; + CUDA_CHECK(cudaStreamCreate(&stream)); + + raft::spatial::knn::detail::haversine_knn( + d_pred_I, d_pred_D, d_train_inputs, d_train_inputs, n, n, k, stream); + + CUDA_CHECK(cudaStreamDestroy(stream)); + } + + void SetUp() override { basicTest(); } + + void TearDown() override { + CUDA_CHECK(cudaFree(d_train_inputs)); + CUDA_CHECK(cudaFree(d_pred_I)); + CUDA_CHECK(cudaFree(d_pred_D)); + CUDA_CHECK(cudaFree(d_ref_I)); + CUDA_CHECK(cudaFree(d_ref_D)); + } + + protected: + value_t *d_train_inputs; + + int n = 6; + int d = 2; + + int k = 6; + + value_idx *d_pred_I; + value_t *d_pred_D; + + value_idx *d_ref_I; + value_t *d_ref_D; +}; + +typedef HaversineKNNTest HaversineKNNTestF; + +TEST_F(HaversineKNNTestF, Fit) { + ASSERT_TRUE(raft::devArrMatch(d_ref_D, d_pred_D, n * n, + raft::CompareApprox(1e-3))); + ASSERT_TRUE( + raft::devArrMatch(d_ref_I, d_pred_I, n * n, raft::Compare())); +} + +} // namespace knn +} // namespace spatial +} // namespace raft From bba8495487557a148f18251b4e8ddce1a805c970 Mon Sep 17 00:00:00 2001 From: viclafargue Date: Tue, 16 Mar 2021 18:27:57 +0000 Subject: [PATCH 4/4] Renaming files --- cpp/include/raft/sparse/selection/knn.cuh | 2 +- .../knn/detail/{brute_force_knn.hpp => brute_force_knn.cuh} | 2 +- .../detail/{haversine_distance.hpp => haversine_distance.cuh} | 0 cpp/include/raft/spatial/knn/knn.hpp | 2 +- cpp/test/spatial/haversine.cu | 2 +- 5 files changed, 4 insertions(+), 4 deletions(-) rename cpp/include/raft/spatial/knn/detail/{brute_force_knn.hpp => brute_force_knn.cuh} (99%) rename cpp/include/raft/spatial/knn/detail/{haversine_distance.hpp => haversine_distance.cuh} (100%) diff --git a/cpp/include/raft/sparse/selection/knn.cuh b/cpp/include/raft/sparse/selection/knn.cuh index b309840d80..93b2a83bbd 100644 --- a/cpp/include/raft/sparse/selection/knn.cuh +++ b/cpp/include/raft/sparse/selection/knn.cuh @@ -33,7 +33,7 @@ #include #include #include -#include +#include #include #include diff --git a/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp b/cpp/include/raft/spatial/knn/detail/brute_force_knn.cuh similarity index 99% rename from cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp rename to cpp/include/raft/spatial/knn/detail/brute_force_knn.cuh index babb7b9eed..ebe17837cc 100644 --- a/cpp/include/raft/spatial/knn/detail/brute_force_knn.hpp +++ b/cpp/include/raft/spatial/knn/detail/brute_force_knn.cuh @@ -33,7 +33,7 @@ #include #include -#include "haversine_distance.hpp" +#include "haversine_distance.cuh" #include "processing.hpp" namespace raft { diff --git a/cpp/include/raft/spatial/knn/detail/haversine_distance.hpp b/cpp/include/raft/spatial/knn/detail/haversine_distance.cuh similarity index 100% rename from cpp/include/raft/spatial/knn/detail/haversine_distance.hpp rename to cpp/include/raft/spatial/knn/detail/haversine_distance.cuh diff --git a/cpp/include/raft/spatial/knn/knn.hpp b/cpp/include/raft/spatial/knn/knn.hpp index 2ca3d29fe6..066d1cff3a 100644 --- a/cpp/include/raft/spatial/knn/knn.hpp +++ b/cpp/include/raft/spatial/knn/knn.hpp @@ -16,7 +16,7 @@ #pragma once -#include "detail/brute_force_knn.hpp" +#include "detail/brute_force_knn.cuh" #include #include diff --git a/cpp/test/spatial/haversine.cu b/cpp/test/spatial/haversine.cu index 44de78ed01..def1f1685b 100644 --- a/cpp/test/spatial/haversine.cu +++ b/cpp/test/spatial/haversine.cu @@ -17,7 +17,7 @@ #include #include #include -#include +#include #include #include #include "../test_utils.h"