diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index 1559e9776f..92492dc37a 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -42,6 +43,20 @@ namespace raft { namespace sparse { namespace distance { +static const std::unordered_set supportedDistance{ + raft::distance::DistanceType::L2Expanded, + raft::distance::DistanceType::L2Unexpanded, + raft::distance::DistanceType::L2SqrtExpanded, + raft::distance::DistanceType::L2SqrtUnexpanded, + raft::distance::DistanceType::InnerProduct, + raft::distance::DistanceType::L1, + raft::distance::DistanceType::Canberra, + raft::distance::DistanceType::Linf, + raft::distance::DistanceType::LpUnexpanded, + raft::distance::DistanceType::JaccardExpanded, + raft::distance::DistanceType::CosineExpanded, + raft::distance::DistanceType::HellingerExpanded}; + /** * Compute pairwise distances between A and B, using the provided * input configuration and distance function. 
@@ -60,12 +75,20 @@ void pairwiseDistance(value_t *out, case raft::distance::DistanceType::L2Expanded: l2_expanded_distances_t(input_config).compute(out); break; + case raft::distance::DistanceType::L2SqrtExpanded: + l2_sqrt_expanded_distances_t(input_config) + .compute(out); + break; case raft::distance::DistanceType::InnerProduct: ip_distances_t(input_config).compute(out); break; case raft::distance::DistanceType::L2Unexpanded: l2_unexpanded_distances_t(input_config).compute(out); break; + case raft::distance::DistanceType::L2SqrtUnexpanded: + l2_sqrt_unexpanded_distances_t(input_config) + .compute(out); + break; case raft::distance::DistanceType::L1: l1_unexpanded_distances_t(input_config).compute(out); break; diff --git a/cpp/include/raft/sparse/distance/ip_distance.cuh b/cpp/include/raft/sparse/distance/ip_distance.cuh index a832c2b6a9..90717bfc5f 100644 --- a/cpp/include/raft/sparse/distance/ip_distance.cuh +++ b/cpp/include/raft/sparse/distance/ip_distance.cuh @@ -190,18 +190,12 @@ class ip_distances_gemm_t : public ip_trans_getters_t { value_t *csr_out_data) { value_idx m = config_->a_nrows, n = config_->b_nrows, k = config_->a_ncols; - int start = raft::curTimeMillis(); - - CUDA_CHECK(cudaStreamSynchronize(config_->stream)); - CUSPARSE_CHECK(raft::sparse::cusparsecsrgemm2( config_->handle, m, n, k, &alpha, matA, config_->a_nnz, config_->a_data, config_->a_indptr, config_->a_indices, matB, config_->b_nnz, csc_data.data(), csc_indptr.data(), csc_indices.data(), NULL, matD, 0, NULL, NULL, NULL, matC, csr_out_data, csr_out_indptr, csr_out_indices, info, workspace.data(), config_->stream)); - - CUDA_CHECK(cudaStreamSynchronize(config_->stream)); } void transpose_b() { diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh b/cpp/include/raft/sparse/distance/l2_distance.cuh index 9d481e34ef..98187576fa 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/l2_distance.cuh @@ -16,18 +16,13 @@ #pragma once 
-#include -#include - -#include -#include +#include #include #include #include #include #include - #include #include @@ -149,12 +144,39 @@ class l2_expanded_distances_t : public distances_t { ~l2_expanded_distances_t() = default; - private: + protected: const distances_config_t *config_; raft::mr::device::buffer workspace; ip_distances_t ip_dists; }; +/** + * L2 sqrt distance: runs the expanded L2 computation of the parent class, then takes an element-wise square root of the a_nrows * b_nrows output values in place. + * The expanded form is more efficient for sparse data. + */ +template +class l2_sqrt_expanded_distances_t + : public l2_expanded_distances_t { + public: + explicit l2_sqrt_expanded_distances_t( + const distances_config_t &config) + : l2_expanded_distances_t(config) {} + + void compute(value_t *out_dists) override { + l2_expanded_distances_t::compute(out_dists); + // Signed sqrt post-processing: root of the magnitude with the sign re-applied afterwards, so a negative input never reaches sqrt() as a negative argument (NaN) + raft::linalg::unaryOp( + out_dists, out_dists, this->config_->a_nrows * this->config_->b_nrows, + [] __device__(value_t input) { + int neg = input < 0 ? -1 : 1; + return sqrt(abs(input)) * neg; + }, + this->config_->stream); + } + + ~l2_sqrt_expanded_distances_t() = default; +}; + /** * Cosine distance using the expanded form: 1 - ( sum(x_k * y_k) / (sqrt(sum(x_k)^2) * sqrt(sum(y_k)^2))) * The expanded form is more efficient for sparse data. 
diff --git a/cpp/include/raft/sparse/distance/lp_distance.cuh b/cpp/include/raft/sparse/distance/lp_distance.cuh index e991224f1b..e524d87b7c 100644 --- a/cpp/include/raft/sparse/distance/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/lp_distance.cuh @@ -124,10 +124,35 @@ class l2_unexpanded_distances_t : public distances_t { Sum(), AtomicAdd()); } - private: + protected: const distances_config_t *config_; }; +/** + * L2 sqrt distance in unexpanded form: runs the unexpanded L2 computation of the parent class, + * then takes an element-wise square root of the a_nrows * b_nrows output values in place. + */ +template +class l2_sqrt_unexpanded_distances_t + : public l2_unexpanded_distances_t { + public: + explicit l2_sqrt_unexpanded_distances_t( + const distances_config_t &config) + : l2_unexpanded_distances_t(config) {} + + void compute(value_t *out_dists) { + l2_unexpanded_distances_t::compute(out_dists); + // Signed sqrt post-processing: root of the magnitude with the sign re-applied afterwards, so a negative input never reaches sqrt() as a negative argument (NaN) + raft::linalg::unaryOp( + out_dists, out_dists, this->config_->a_nrows * this->config_->b_nrows, + [] __device__(value_t input) { + int neg = input < 0 ? -1 : 1; + return sqrt(abs(input)) * neg; + }, + this->config_->stream); + } +}; + template class linf_unexpanded_distances_t : public distances_t { public: diff --git a/cpp/include/raft/sparse/selection/knn.cuh b/cpp/include/raft/sparse/selection/knn.cuh index 3e8fa2bd6f..b309840d80 100644 --- a/cpp/include/raft/sparse/selection/knn.cuh +++ b/cpp/include/raft/sparse/selection/knn.cuh @@ -16,6 +16,8 @@ #pragma once +#include + #include #include #include @@ -34,13 +36,10 @@ #include #include -#include - #include - -#include - +#include #include +#include #include @@ -128,7 +127,7 @@ class sparse_knn_t { size_t batch_size_query_ = 2 << 14, raft::distance::DistanceType metric_ = raft::distance::DistanceType::L2Expanded, - float metricArg_ = 0, bool expanded_form_ = false) + float metricArg_ = 0) : idxIndptr(idxIndptr_), idxIndices(idxIndices_), idxData(idxData_), @@ -150,8 +149,7 @@ class sparse_knn_t { batch_size_index(batch_size_index_), batch_size_query(batch_size_query_), metric(metric_), - metricArg(metricArg_), - expanded_form(expanded_form_) {} + metricArg(metricArg_) {} void run() 
{ using namespace raft::sparse; @@ -172,26 +170,23 @@ class sparse_knn_t { * Slice CSR to rows in batch */ - raft::mr::device::buffer query_batch_indptr( - allocator, stream, query_batcher.batch_rows() + 1); + rmm::device_uvector query_batch_indptr( + query_batcher.batch_rows() + 1, stream); value_idx n_query_batch_nnz = query_batcher.get_batch_csr_indptr_nnz( query_batch_indptr.data(), stream); - raft::mr::device::buffer query_batch_indices( - allocator, stream, n_query_batch_nnz); - raft::mr::device::buffer query_batch_data(allocator, stream, - n_query_batch_nnz); + rmm::device_uvector query_batch_indices(n_query_batch_nnz, + stream); + rmm::device_uvector query_batch_data(n_query_batch_nnz, stream); query_batcher.get_batch_csr_indices_data(query_batch_indices.data(), query_batch_data.data(), stream); // A 3-partition temporary merge space to scale the batching. 2 parts for subsequent // batches and 1 space for the results of the merge, which get copied back to the top - raft::mr::device::buffer merge_buffer_indices(allocator, - stream, 0); - raft::mr::device::buffer merge_buffer_dists(allocator, stream, - 0); + rmm::device_uvector merge_buffer_indices(0, stream); + rmm::device_uvector merge_buffer_dists(0, stream); value_t *dists_merge_buffer_ptr; value_idx *indices_merge_buffer_ptr; @@ -209,11 +204,10 @@ class sparse_knn_t { /** * Slice CSR to rows in batch */ - raft::mr::device::buffer idx_batch_indptr( - allocator, stream, idx_batcher.batch_rows() + 1); - raft::mr::device::buffer idx_batch_indices(allocator, stream, - 0); - raft::mr::device::buffer idx_batch_data(allocator, stream, 0); + rmm::device_uvector idx_batch_indptr( + idx_batcher.batch_rows() + 1, stream); + rmm::device_uvector idx_batch_indices(0, stream); + rmm::device_uvector idx_batch_data(0, stream); value_idx idx_batch_nnz = idx_batcher.get_batch_csr_indptr_nnz(idx_batch_indptr.data(), stream); @@ -229,8 +223,7 @@ class sparse_knn_t { */ size_t dense_size = idx_batcher.batch_rows() * 
query_batcher.batch_rows(); - raft::mr::device::buffer batch_dists(allocator, stream, - dense_size); + rmm::device_uvector batch_dists(dense_size, stream); CUDA_CHECK(cudaMemset(batch_dists.data(), 0, batch_dists.size() * sizeof(value_t))); @@ -241,13 +234,9 @@ class sparse_knn_t { query_batch_indptr.data(), query_batch_indices.data(), query_batch_data.data(), batch_dists.data()); - idx_batch_indptr.release(stream); - idx_batch_indices.release(stream); - idx_batch_data.release(stream); - // Build batch indices array - raft::mr::device::buffer batch_indices(allocator, stream, - batch_dists.size()); + rmm::device_uvector batch_indices(batch_dists.size(), + stream); // populate batch indices array value_idx batch_rows = query_batcher.batch_rows(), @@ -268,8 +257,6 @@ class sparse_knn_t { batch_indices.data(), dists_merge_buffer_ptr, indices_merge_buffer_ptr); - perform_postprocessing(dists_merge_buffer_ptr, batch_rows); - value_t *dists_merge_buffer_tmp_ptr = dists_merge_buffer_ptr; value_idx *indices_merge_buffer_tmp_ptr = indices_merge_buffer_ptr; @@ -307,23 +294,6 @@ class sparse_knn_t { } } - void perform_postprocessing(value_t *dists, size_t batch_rows) { - // Perform necessary post-processing - if (metric == raft::distance::DistanceType::L2Expanded && !expanded_form) { - /** - * post-processing - */ - value_t p = 0.5; // standard l2 - raft::linalg::unaryOp( - dists, dists, batch_rows * k, - [p] __device__(value_t input) { - int neg = input < 0 ? 
-1 : 1; - return powf(fabs(input), p) * neg; - }, - stream); - } - } - private: void merge_batches(csr_batcher_t &idx_batcher, csr_batcher_t &query_batcher, @@ -335,8 +305,7 @@ class sparse_knn_t { id_ranges.push_back(0); id_ranges.push_back(idx_batcher.batch_start()); - raft::mr::device::buffer trans(allocator, stream, - id_ranges.size()); + rmm::device_uvector trans(id_ranges.size(), stream); raft::update_device(trans.data(), id_ranges.data(), id_ranges.size(), stream); @@ -403,6 +372,10 @@ class sparse_knn_t { dist_config.allocator = allocator; dist_config.stream = stream; + if (raft::sparse::distance::supportedDistance.find(metric) == + raft::sparse::distance::supportedDistance.end()) + THROW("DistanceType not supported: %d", metric); + raft::sparse::distance::pairwiseDistance(batch_dists, dist_config, metric, metricArg); } @@ -418,8 +391,6 @@ class sparse_knn_t { float metricArg; - bool expanded_form; - int n_idx_rows, n_idx_cols, n_query_rows, n_query_cols, k; cusparseHandle_t cusparseHandle; @@ -453,7 +424,6 @@ class sparse_knn_t { * @param[in] batch_size_query maximum number of rows to use from query matrix per batch * @param[in] metric distance metric/measure to use * @param[in] metricArg potential argument for metric (currently unused) - * @param[in] expanded_form whether or not Lp variants should be reduced by the pth-root */ template void brute_force_knn(const value_idx *idxIndptr, const value_idx *idxIndices, @@ -469,12 +439,12 @@ void brute_force_knn(const value_idx *idxIndptr, const value_idx *idxIndices, size_t batch_size_query = 2 << 14, raft::distance::DistanceType metric = raft::distance::DistanceType::L2Expanded, - float metricArg = 0, bool expanded_form = false) { + float metricArg = 0) { sparse_knn_t( idxIndptr, idxIndices, idxData, idxNNZ, n_idx_rows, n_idx_cols, queryIndptr, queryIndices, queryData, queryNNZ, n_query_rows, n_query_cols, output_indices, output_dists, k, cusparseHandle, allocator, stream, - batch_size_index, batch_size_query, 
metric, metricArg, expanded_form) + batch_size_index, batch_size_query, metric, metricArg) .run(); } diff --git a/cpp/test/sparse/dist_coo_spmv.cu b/cpp/test/sparse/dist_coo_spmv.cu index a841da661d..6e3f3b5038 100644 --- a/cpp/test/sparse/dist_coo_spmv.cu +++ b/cpp/test/sparse/dist_coo_spmv.cu @@ -155,6 +155,7 @@ class SparseDistanceCOOSPMVTest CUDA_CHECK(cudaStreamCreate(&stream)); CUSPARSE_CHECK(cusparseCreate(&cusparseHandle)); + CUSPARSE_CHECK(cusparseSetStream(cusparseHandle, stream)); make_data(); @@ -225,7 +226,8 @@ const std::vector> inputs_i32_f = { {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}, {5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0}, - raft::distance::DistanceType::InnerProduct}, + raft::distance::DistanceType::InnerProduct, + 0.0}, {2, {0, 2, 4, 6, 8}, {0, 1, 0, 1, 0, 1, 0, 1}, // indices @@ -249,7 +251,8 @@ const std::vector> inputs_i32_f = { 1832.0, 0.0, }, - raft::distance::DistanceType::L2Unexpanded}, + raft::distance::DistanceType::L2Unexpanded, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -362,7 +365,8 @@ const std::vector> inputs_i32_f = { 6.903282911791188, 7.0, 0.0}, - raft::distance::DistanceType::Canberra}, + raft::distance::DistanceType::Canberra, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -589,7 +593,8 @@ const std::vector> inputs_i32_f = { 0.5079750812968089, 0.8429599432532096, 0.0}, - raft::distance::DistanceType::Linf}, + raft::distance::DistanceType::Linf, + 0.0}, {4, {0, 1, 1, 2, 4}, @@ -614,7 +619,8 @@ const std::vector> inputs_i32_f = { 0.84454, 0.0, }, - raft::distance::DistanceType::L1} + raft::distance::DistanceType::L1, + 0.0} }; diff --git a/cpp/test/sparse/dist_csr_spmv.cu b/cpp/test/sparse/dist_csr_spmv.cu index 2405909c40..c32748a04e 100644 --- a/cpp/test/sparse/dist_csr_spmv.cu +++ b/cpp/test/sparse/dist_csr_spmv.cu @@ -229,7 +229,8 @@ const std::vector> inputs_i32_f = { 1832.0, 0.0, }, - raft::distance::DistanceType::L2Unexpanded}, + 
raft::distance::DistanceType::L2Unexpanded, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -342,7 +343,8 @@ const std::vector> inputs_i32_f = { 6.903282911791188, 7.0, 0.0}, - raft::distance::DistanceType::Canberra}, + raft::distance::DistanceType::Canberra, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -569,7 +571,8 @@ const std::vector> inputs_i32_f = { 0.5079750812968089, 0.8429599432532096, 0.0}, - raft::distance::DistanceType::Linf}, + raft::distance::DistanceType::Linf, + 0.0}, {4, {0, 1, 1, 2, 4}, @@ -594,7 +597,8 @@ const std::vector> inputs_i32_f = { 0.84454, 0.0, }, - raft::distance::DistanceType::L1} + raft::distance::DistanceType::L1, + 0.0} }; diff --git a/cpp/test/sparse/distance.cu b/cpp/test/sparse/distance.cu index 53e8838b65..b103486b96 100644 --- a/cpp/test/sparse/distance.cu +++ b/cpp/test/sparse/distance.cu @@ -88,6 +88,7 @@ class SparseDistanceTest CUDA_CHECK(cudaStreamCreate(&stream)); CUSPARSE_CHECK(cusparseCreate(&cusparseHandle)); + CUSPARSE_CHECK(cusparseSetStream(cusparseHandle, stream)); make_data(); @@ -127,15 +128,9 @@ class SparseDistanceTest } void compare() { - // skip Hellinger test due to sporadic CI issue - // https://github.com/rapidsai/cuml/issues/3477 - if (params.metric == raft::distance::DistanceType::HellingerExpanded) { - GTEST_SKIP(); - } else { - ASSERT_TRUE(devArrMatch(out_dists_ref, out_dists, - params.out_dists_ref_h.size(), - CompareApprox(1e-3))); - } + ASSERT_TRUE(devArrMatch(out_dists_ref, out_dists, + params.out_dists_ref_h.size(), + CompareApprox(1e-3))); } protected: @@ -176,14 +171,16 @@ const std::vector> inputs_i32_f = { 1832.0, 0.0, }, - raft::distance::DistanceType::L2Expanded}, + raft::distance::DistanceType::L2Expanded, + 0.0}, {2, {0, 2, 4, 6, 8}, {0, 1, 0, 1, 0, 1, 0, 1}, {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}, {5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0}, - raft::distance::DistanceType::InnerProduct}, + 
raft::distance::DistanceType::InnerProduct, + 0.0}, {2, {0, 2, 4, 6, 8}, {0, 1, 0, 1, 0, 1, 0, 1}, // indices @@ -207,7 +204,8 @@ const std::vector> inputs_i32_f = { 1832.0, 0.0, }, - raft::distance::DistanceType::L2Unexpanded}, + raft::distance::DistanceType::L2Unexpanded, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -237,7 +235,8 @@ const std::vector> inputs_i32_f = { 0.67676228, 0.24558392, 0.76064776, 0.51360432, 0., 1., 0.76978799, 0.78021386, 1., 0.84923694, 0.73155632, 0.99166225, 0.61547536, 0.68185144, 1., 0.}, - raft::distance::DistanceType::CosineExpanded}, + raft::distance::DistanceType::CosineExpanded, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -347,7 +346,8 @@ const std::vector> inputs_i32_f = { 0.75, 1.0, 0.0}, - raft::distance::DistanceType::JaccardExpanded}, + raft::distance::DistanceType::JaccardExpanded, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -460,7 +460,8 @@ const std::vector> inputs_i32_f = { 6.903282911791188, 7.0, 0.0}, - raft::distance::DistanceType::Canberra}, + raft::distance::DistanceType::Canberra, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -687,33 +688,10 @@ const std::vector> inputs_i32_f = { 0.5079750812968089, 0.8429599432532096, 0.0}, - raft::distance::DistanceType::Linf}, + raft::distance::DistanceType::Linf, + 0.0}, - {4, - {0, 1, 1, 2, 4}, - {3, 2, 0, 1}, // indices - {0.99296, 0.42180, 0.11687, 0.305869}, - { - // dense output - 0.0, - 0.99296, - 1.41476, - 1.415707, - 0.99296, - 0.0, - 0.42180, - 0.42274, - 1.41476, - 0.42180, - 0.0, - 0.84454, - 1.41570, - 0.42274, - 0.84454, - 0.0, - }, - raft::distance::DistanceType::L1}, - {10, + {15, {0, 5, 8, 9, 15, 20, 26, 31, 34, 38, 45}, {0, 1, 5, 6, 9, 1, 4, 14, 7, 3, 4, 7, 9, 11, 14, 0, 3, 7, 8, 12, 0, 2, 5, 7, 8, 14, 4, 9, 10, 11, @@ -752,7 +730,34 @@ const std::vector> inputs_i32_f = { 1.00000000e+00, 8.05419635e-01, 9.53789212e-01, 8.07933016e-01, 7.40428532e-01, 7.95485011e-01, 8.51370877e-01, 1.49011612e-08}, 
// Dataset is L1 normalized into pdfs - raft::distance::DistanceType::HellingerExpanded}}; + raft::distance::DistanceType::HellingerExpanded, + 0.0}, + + {4, + {0, 1, 1, 2, 4}, + {3, 2, 0, 1}, // indices + {0.99296, 0.42180, 0.11687, 0.305869}, + { + // dense output + 0.0, + 0.99296, + 1.41476, + 1.415707, + 0.99296, + 0.0, + 0.42180, + 0.42274, + 1.41476, + 0.42180, + 0.0, + 0.84454, + 1.41570, + 0.42274, + 0.84454, + 0.0, + }, + raft::distance::DistanceType::L1, + 0.0}}; typedef SparseDistanceTest SparseDistanceTestF; TEST_P(SparseDistanceTestF, Result) { compare(); } diff --git a/cpp/test/sparse/knn.cu b/cpp/test/sparse/knn.cu index 0f773b9fee..4759eebe4b 100644 --- a/cpp/test/sparse/knn.cu +++ b/cpp/test/sparse/knn.cu @@ -17,6 +17,8 @@ #include #include +#include +#include #include #include "../test_utils.h" @@ -49,7 +51,7 @@ struct SparseKNNInputs { int batch_size_query = 2; raft::distance::DistanceType metric = - raft::distance::DistanceType::L2Expanded; + raft::distance::DistanceType::L2SqrtExpanded; }; template @@ -67,14 +69,10 @@ class SparseKNNTest std::vector indices_h = params.indices_h; std::vector data_h = params.data_h; - printf("Allocating input\n"); - allocate(indptr, indptr_h.size()); allocate(indices, indices_h.size()); allocate(data, data_h.size()); - printf("Updating device\n"); - update_device(indptr, indptr_h.data(), indptr_h.size(), stream); update_device(indices, indices_h.data(), indices_h.size(), stream); update_device(data, data_h.data(), data_h.size(), stream); @@ -82,23 +80,16 @@ class SparseKNNTest std::vector out_dists_ref_h = params.out_dists_ref_h; std::vector out_indices_ref_h = params.out_indices_ref_h; - printf("Allocating ref output\n"); allocate(out_indices_ref, out_indices_ref_h.size()); allocate(out_dists_ref, out_dists_ref_h.size()); - printf("Updating device\n"); - update_device(out_indices_ref, out_indices_ref_h.data(), out_indices_ref_h.size(), stream); update_device(out_dists_ref, out_dists_ref_h.data(), 
out_dists_ref_h.size(), stream); - printf("Allocating final output\n"); - allocate(out_dists, n_rows * k); allocate(out_indices, n_rows * k); - - printf("Done.\n"); } void SetUp() override { @@ -106,7 +97,6 @@ class SparseKNNTest ::testing::TestWithParam>::GetParam(); std::shared_ptr alloc( new raft::mr::device::default_allocator); - CUDA_CHECK(cudaStreamCreate(&stream)); CUSPARSE_CHECK(cusparseCreate(&cusparseHandle)); @@ -115,12 +105,8 @@ class SparseKNNTest nnz = params.indices_h.size(); k = params.k; - printf("Making data\n"); - make_data(); - printf("About to run kselect\n"); - raft::sparse::selection::brute_force_knn( indptr, indices, data, nnz, n_rows, params.n_cols, indptr, indices, data, nnz, n_rows, params.n_cols, out_indices, out_dists, k, cusparseHandle, @@ -128,11 +114,11 @@ class SparseKNNTest params.metric); CUDA_CHECK(cudaStreamSynchronize(stream)); - - printf("Executed k-select"); } void TearDown() override { + CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaFree(indptr)); CUDA_CHECK(cudaFree(indices)); CUDA_CHECK(cudaFree(data)); @@ -181,7 +167,7 @@ const std::vector> inputs_i32_f = { 2, 2, 2, - raft::distance::DistanceType::L2Expanded}}; + raft::distance::DistanceType::L2SqrtExpanded}}; typedef SparseKNNTest SparseKNNTestF; TEST_P(SparseKNNTestF, Result) { compare(); } INSTANTIATE_TEST_CASE_P(SparseKNNTest, SparseKNNTestF,