From fc155c06d3d7055158a640d6639bd4800d5264c4 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 3 Mar 2021 12:11:07 -0500 Subject: [PATCH 1/6] Moving Mickael's updates over to raft --- cpp/include/raft/sparse/distance/distance.cuh | 25 +++++- .../raft/sparse/distance/l2_distance.cuh | 50 +++++++++-- .../raft/sparse/distance/lp_distance.cuh | 29 ++++++- cpp/include/raft/sparse/selection/knn.cuh | 86 +++++++------------ 4 files changed, 121 insertions(+), 69 deletions(-) diff --git a/cpp/include/raft/sparse/distance/distance.cuh b/cpp/include/raft/sparse/distance/distance.cuh index 1559e9776f..92492dc37a 100644 --- a/cpp/include/raft/sparse/distance/distance.cuh +++ b/cpp/include/raft/sparse/distance/distance.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, NVIDIA CORPORATION. + * Copyright (c) 2020-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -42,6 +43,20 @@ namespace raft { namespace sparse { namespace distance { +static const std::unordered_set supportedDistance{ + raft::distance::DistanceType::L2Expanded, + raft::distance::DistanceType::L2Unexpanded, + raft::distance::DistanceType::L2SqrtExpanded, + raft::distance::DistanceType::L2SqrtUnexpanded, + raft::distance::DistanceType::InnerProduct, + raft::distance::DistanceType::L1, + raft::distance::DistanceType::Canberra, + raft::distance::DistanceType::Linf, + raft::distance::DistanceType::LpUnexpanded, + raft::distance::DistanceType::JaccardExpanded, + raft::distance::DistanceType::CosineExpanded, + raft::distance::DistanceType::HellingerExpanded}; + /** * Compute pairwise distances between A and B, using the provided * input configuration and distance function. @@ -60,12 +75,20 @@ void pairwiseDistance(value_t *out, case raft::distance::DistanceType::L2Expanded: l2_expanded_distances_t(input_config).compute(out); break; + case raft::distance::DistanceType::L2SqrtExpanded: + l2_sqrt_expanded_distances_t(input_config) + .compute(out); + break; case raft::distance::DistanceType::InnerProduct: ip_distances_t(input_config).compute(out); break; case raft::distance::DistanceType::L2Unexpanded: l2_unexpanded_distances_t(input_config).compute(out); break; + case raft::distance::DistanceType::L2SqrtUnexpanded: + l2_sqrt_unexpanded_distances_t(input_config) + .compute(out); + break; case raft::distance::DistanceType::L1: l1_unexpanded_distances_t(input_config).compute(out); break; diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh b/cpp/include/raft/sparse/distance/l2_distance.cuh index 9d481e34ef..3947af1114 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/l2_distance.cuh @@ -16,18 +16,13 @@ #pragma once -#include -#include - -#include -#include +#include #include #include #include #include #include - #include #include @@ -78,7 +73,7 @@ __global__ void compute_euclidean_warp_kernel( } template + typename expansion_f> void compute_euclidean(value_t *C, const value_t *Q_sq_norms, const value_t *R_sq_norms, value_idx n_rows, value_idx n_cols, cudaStream_t stream, @@ -89,7 +84,7 @@ void compute_euclidean(value_t *C, const value_t *Q_sq_norms, } template + typename expansion_f> void compute_l2(value_t *out, const value_idx *Q_coo_rows, const value_t *Q_data, value_idx Q_nnz, const value_idx *R_coo_rows, const value_t *R_data, @@ -127,17 +122,20 @@ class l2_expanded_distances_t : public distances_t { ip_dists(config) {} void compute(value_t *out_dists) { + CUML_LOG_DEBUG("Computing inner products"); ip_dists.compute(out_dists); value_idx *b_indices = ip_dists.b_rows_coo(); value_t *b_data = ip_dists.b_data_coo(); + CUML_LOG_DEBUG("Computing COO row index array"); raft::mr::device::buffer search_coo_rows( config_->allocator, config_->stream, config_->a_nnz); raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows, search_coo_rows.data(), config_->a_nnz, config_->stream); + CUML_LOG_DEBUG("Computing L2"); compute_l2( out_dists, search_coo_rows.data(), config_->a_data, config_->a_nnz, b_indices, b_data, config_->b_nnz, config_->a_nrows, config_->b_nrows, @@ -149,12 +147,41 @@ class l2_expanded_distances_t : public distances_t { ~l2_expanded_distances_t() = default; - private: + protected: const distances_config_t *config_; raft::mr::device::buffer workspace; ip_distances_t ip_dists; }; +/** + * L2 sqrt distance performing the sqrt operation after the distance computation + * The expanded form is more efficient for sparse data. + */ +template +class l2_sqrt_expanded_distances_t + : public l2_expanded_distances_t { + public: + explicit l2_sqrt_expanded_distances_t( + const distances_config_t &config) + : l2_expanded_distances_t(config) {} + + void compute(value_t *out_dists) override { + l2_expanded_distances_t::compute(out_dists); + CUML_LOG_DEBUG("Computing Sqrt"); + // Sqrt Post-processing + value_t p = 0.5; // standard l2 + raft::linalg::unaryOp( + out_dists, out_dists, this->config_->a_nrows * this->config_->b_nrows, + [p] __device__(value_t input) { + int neg = input < 0 ? -1 : 1; + return powf(fabs(input), p) * neg; + }, + this->config_->stream); + } + + ~l2_sqrt_expanded_distances_t() = default; +}; + /** * Cosine distance using the expanded form: 1 - ( sum(x_k * y_k) / (sqrt(sum(x_k)^2) * sqrt(sum(y_k)^2))) * The expanded form is more efficient for sparse data. @@ -169,17 +196,20 @@ class cosine_expanded_distances_t : public distances_t { ip_dists(config) {} void compute(value_t *out_dists) { + CUML_LOG_DEBUG("Computing inner products"); ip_dists.compute(out_dists); value_idx *b_indices = ip_dists.b_rows_coo(); value_t *b_data = ip_dists.b_data_coo(); + CUML_LOG_DEBUG("Computing COO row index array"); raft::mr::device::buffer search_coo_rows( config_->allocator, config_->stream, config_->a_nnz); raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows, search_coo_rows.data(), config_->a_nnz, config_->stream); + CUML_LOG_DEBUG("Computing L2"); compute_l2( out_dists, search_coo_rows.data(), config_->a_data, config_->a_nnz, b_indices, b_data, config_->b_nnz, config_->a_nrows, config_->b_nrows, @@ -219,6 +249,8 @@ class hellinger_expanded_distances_t : public distances_t { ip_dists(config) {} void compute(value_t *out_dists) { + CUML_LOG_DEBUG("Computing Hellinger Distance"); + // First sqrt A and B raft::linalg::unaryOp( config_->a_data, config_->a_data, config_->a_nnz, diff --git a/cpp/include/raft/sparse/distance/lp_distance.cuh b/cpp/include/raft/sparse/distance/lp_distance.cuh index e991224f1b..10daff1c83 100644 --- a/cpp/include/raft/sparse/distance/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/lp_distance.cuh @@ -41,7 +41,7 @@ namespace sparse { namespace distance { template + typename product_f, typename accum_f, typename write_f> void unexpanded_lp_distances( value_t *out_dists, const distances_config_t *config_, @@ -104,6 +104,8 @@ class l1_unexpanded_distances_t : public distances_t { : config_(&config) {} void compute(value_t *out_dists) { + CUML_LOG_DEBUG("Running l1 dists"); + unexpanded_lp_distances(out_dists, config_, AbsDiff(), Sum(), AtomicAdd()); } @@ -124,10 +126,33 @@ class l2_unexpanded_distances_t : public distances_t { Sum(), AtomicAdd()); } - private: + protected: const distances_config_t *config_; }; +template +class l2_sqrt_unexpanded_distances_t + : public l2_unexpanded_distances_t { + public: + l2_sqrt_unexpanded_distances_t( + const distances_config_t &config) + : l2_unexpanded_distances_t(config) {} + + void compute(value_t *out_dists) { + l2_unexpanded_distances_t::compute(out_dists); + CUML_LOG_DEBUG("Computing Sqrt"); + // Sqrt Post-processing + value_t p = 0.5; // standard l2 + raft::linalg::unaryOp( + out_dists, out_dists, this->config_->a_nrows * this->config_->b_nrows, + [p] __device__(value_t input) { + int neg = input < 0 ? -1 : 1; + return powf(fabs(input), p) * neg; + }, + this->config_->stream); + } +}; + template class linf_unexpanded_distances_t : public distances_t { public: diff --git a/cpp/include/raft/sparse/selection/knn.cuh b/cpp/include/raft/sparse/selection/knn.cuh index 3e8fa2bd6f..49eb0e5be6 100644 --- a/cpp/include/raft/sparse/selection/knn.cuh +++ b/cpp/include/raft/sparse/selection/knn.cuh @@ -127,8 +127,8 @@ class sparse_knn_t { size_t batch_size_index_ = 2 << 14, // approx 1M size_t batch_size_query_ = 2 << 14, raft::distance::DistanceType metric_ = - raft::distance::DistanceType::L2Expanded, - float metricArg_ = 0, bool expanded_form_ = false) + raft::distance::DistanceType::L2Expanded, + float metricArg_ = 0) : idxIndptr(idxIndptr_), idxIndices(idxIndices_), idxData(idxData_), @@ -150,8 +150,7 @@ class sparse_knn_t { batch_size_index(batch_size_index_), batch_size_query(batch_size_query_), metric(metric_), - metricArg(metricArg_), - expanded_form(expanded_form_) {} + metricArg(metricArg_) {} void run() { using namespace raft::sparse; @@ -172,26 +171,23 @@ class sparse_knn_t { * Slice CSR to rows in batch */ - raft::mr::device::buffer query_batch_indptr( - allocator, stream, query_batcher.batch_rows() + 1); + rmm::device_uvector query_batch_indptr( + query_batcher.batch_rows() + 1, stream); value_idx n_query_batch_nnz = query_batcher.get_batch_csr_indptr_nnz( query_batch_indptr.data(), stream); - raft::mr::device::buffer query_batch_indices( - allocator, stream, n_query_batch_nnz); - raft::mr::device::buffer query_batch_data(allocator, stream, - n_query_batch_nnz); + rmm::device_uvector query_batch_indices(n_query_batch_nnz, + stream); + rmm::device_uvector query_batch_data(n_query_batch_nnz, stream); query_batcher.get_batch_csr_indices_data(query_batch_indices.data(), query_batch_data.data(), stream); // A 3-partition temporary merge space to scale the batching. 2 parts for subsequent // batches and 1 space for the results of the merge, which get copied back to the top - raft::mr::device::buffer merge_buffer_indices(allocator, - stream, 0); - raft::mr::device::buffer merge_buffer_dists(allocator, stream, - 0); + rmm::device_uvector merge_buffer_indices(0, stream); + rmm::device_uvector merge_buffer_dists(0, stream); value_t *dists_merge_buffer_ptr; value_idx *indices_merge_buffer_ptr; @@ -209,11 +205,10 @@ class sparse_knn_t { /** * Slice CSR to rows in batch */ - raft::mr::device::buffer idx_batch_indptr( - allocator, stream, idx_batcher.batch_rows() + 1); - raft::mr::device::buffer idx_batch_indices(allocator, stream, - 0); - raft::mr::device::buffer idx_batch_data(allocator, stream, 0); + rmm::device_uvector idx_batch_indptr( + idx_batcher.batch_rows() + 1, stream); + rmm::device_uvector idx_batch_indices(0, stream); + rmm::device_uvector idx_batch_data(0, stream); value_idx idx_batch_nnz = idx_batcher.get_batch_csr_indptr_nnz(idx_batch_indptr.data(), stream); @@ -229,8 +224,7 @@ class sparse_knn_t { */ size_t dense_size = idx_batcher.batch_rows() * query_batcher.batch_rows(); - raft::mr::device::buffer batch_dists(allocator, stream, - dense_size); + rmm::device_uvector batch_dists(dense_size, stream); CUDA_CHECK(cudaMemset(batch_dists.data(), 0, batch_dists.size() * sizeof(value_t))); @@ -241,17 +235,13 @@ class sparse_knn_t { query_batch_indptr.data(), query_batch_indices.data(), query_batch_data.data(), batch_dists.data()); - idx_batch_indptr.release(stream); - idx_batch_indices.release(stream); - idx_batch_data.release(stream); - // Build batch indices array - raft::mr::device::buffer batch_indices(allocator, stream, - batch_dists.size()); + rmm::device_uvector batch_indices(batch_dists.size(), + stream); // populate batch indices array value_idx batch_rows = query_batcher.batch_rows(), - batch_cols = idx_batcher.batch_rows(); + batch_cols = idx_batcher.batch_rows(); iota_fill(batch_indices.data(), batch_rows, batch_cols, stream); @@ -268,8 +258,6 @@ class sparse_knn_t { batch_indices.data(), dists_merge_buffer_ptr, indices_merge_buffer_ptr); - perform_postprocessing(dists_merge_buffer_ptr, batch_rows); - value_t *dists_merge_buffer_tmp_ptr = dists_merge_buffer_ptr; value_idx *indices_merge_buffer_tmp_ptr = indices_merge_buffer_ptr; @@ -307,23 +295,6 @@ class sparse_knn_t { } } - void perform_postprocessing(value_t *dists, size_t batch_rows) { - // Perform necessary post-processing - if (metric == raft::distance::DistanceType::L2Expanded && !expanded_form) { - /** - * post-processing - */ - value_t p = 0.5; // standard l2 - raft::linalg::unaryOp( - dists, dists, batch_rows * k, - [p] __device__(value_t input) { - int neg = input < 0 ? -1 : 1; - return powf(fabs(input), p) * neg; - }, - stream); - } - } - private: void merge_batches(csr_batcher_t &idx_batcher, csr_batcher_t &query_batcher, @@ -335,13 +306,12 @@ class sparse_knn_t { id_ranges.push_back(0); id_ranges.push_back(idx_batcher.batch_start()); - raft::mr::device::buffer trans(allocator, stream, - id_ranges.size()); + rmm::device_uvector trans(id_ranges.size(), stream); raft::update_device(trans.data(), id_ranges.data(), id_ranges.size(), stream); // combine merge buffers only if there's more than 1 partition to combine - raft::spatial::knn::detail::knn_merge_parts( + MLCommon::Selection::knn_merge_parts( merge_buffer_dists, merge_buffer_indices, out_dists, out_indices, query_batcher.batch_rows(), 2, k, stream, trans.data()); } @@ -352,7 +322,7 @@ class sparse_knn_t { value_t *out_dists, value_idx *out_indices) { // populate batch indices array value_idx batch_rows = query_batcher.batch_rows(), - batch_cols = idx_batcher.batch_rows(); + batch_cols = idx_batcher.batch_rows(); // build translation buffer to shift resulting indices by the batch std::vector id_ranges; @@ -382,6 +352,7 @@ class sparse_knn_t { /** * Compute distances */ + CUML_LOG_DEBUG("Computing pairwise distances for batch"); raft::sparse::distance::distances_config_t dist_config; dist_config.b_nrows = idx_batcher.batch_rows(); dist_config.b_ncols = n_idx_cols; @@ -403,6 +374,10 @@ class sparse_knn_t { dist_config.allocator = allocator; dist_config.stream = stream; + if (raft::sparse::distance::supportedDistance.find(metric) == + raft::sparse::distance::supportedDistance.end()) + THROW("DistanceType not supported: %d", metric); + raft::sparse::distance::pairwiseDistance(batch_dists, dist_config, metric, metricArg); } @@ -418,8 +393,6 @@ class sparse_knn_t { float metricArg; - bool expanded_form; - int n_idx_rows, n_idx_cols, n_query_rows, n_query_cols, k; cusparseHandle_t cusparseHandle; @@ -453,7 +426,6 @@ class sparse_knn_t { * @param[in] batch_size_query maximum number of rows to use from query matrix per batch * @param[in] metric distance metric/measure to use * @param[in] metricArg potential argument for metric (currently unused) - * @param[in] expanded_form whether or not Lp variants should be reduced by the pth-root */ template void brute_force_knn(const value_idx *idxIndptr, const value_idx *idxIndices, @@ -468,13 +440,13 @@ void brute_force_knn(const value_idx *idxIndptr, const value_idx *idxIndices, size_t batch_size_index = 2 << 14, // approx 1M size_t batch_size_query = 2 << 14, raft::distance::DistanceType metric = - raft::distance::DistanceType::L2Expanded, - float metricArg = 0, bool expanded_form = false) { + raft::distance::DistanceType::L2Expanded, + float metricArg = 0) { sparse_knn_t( idxIndptr, idxIndices, idxData, idxNNZ, n_idx_rows, n_idx_cols, queryIndptr, queryIndices, queryData, queryNNZ, n_query_rows, n_query_cols, output_indices, output_dists, k, cusparseHandle, allocator, stream, - batch_size_index, batch_size_query, metric, metricArg, expanded_form) + batch_size_index, batch_size_query, metric, metricArg) .run(); } From f6694d3a54643226f85f6f3f471c1c04e2f9d9fc Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 3 Mar 2021 16:57:29 -0500 Subject: [PATCH 2/6] Updating style --- .../raft/sparse/distance/l2_distance.cuh | 14 +++++++------- .../raft/sparse/distance/lp_distance.cuh | 12 ++++++------ cpp/include/raft/sparse/selection/knn.cuh | 19 +++++++++---------- 3 files changed, 22 insertions(+), 23 deletions(-) diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh b/cpp/include/raft/sparse/distance/l2_distance.cuh index 3947af1114..918d91ca22 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/l2_distance.cuh @@ -73,7 +73,7 @@ __global__ void compute_euclidean_warp_kernel( } template + typename expansion_f> void compute_euclidean(value_t *C, const value_t *Q_sq_norms, const value_t *R_sq_norms, value_idx n_rows, value_idx n_cols, cudaStream_t stream, @@ -84,7 +84,7 @@ void compute_euclidean(value_t *C, const value_t *Q_sq_norms, } template + typename expansion_f> void compute_l2(value_t *out, const value_idx *Q_coo_rows, const value_t *Q_data, value_idx Q_nnz, const value_idx *R_coo_rows, const value_t *R_data, @@ -172,11 +172,11 @@ class l2_sqrt_expanded_distances_t value_t p = 0.5; // standard l2 raft::linalg::unaryOp( out_dists, out_dists, this->config_->a_nrows * this->config_->b_nrows, - [p] __device__(value_t input) { - int neg = input < 0 ? -1 : 1; - return powf(fabs(input), p) * neg; - }, - this->config_->stream); + [p] __device__(value_t input) { + int neg = input < 0 ? -1 : 1; + return powf(fabs(input), p) * neg; + }, + this->config_->stream); } ~l2_sqrt_expanded_distances_t() = default; diff --git a/cpp/include/raft/sparse/distance/lp_distance.cuh b/cpp/include/raft/sparse/distance/lp_distance.cuh index 10daff1c83..5f397a4f19 100644 --- a/cpp/include/raft/sparse/distance/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/lp_distance.cuh @@ -41,7 +41,7 @@ namespace sparse { namespace distance { template + typename product_f, typename accum_f, typename write_f> void unexpanded_lp_distances( value_t *out_dists, const distances_config_t *config_, @@ -145,11 +145,11 @@ class l2_sqrt_unexpanded_distances_t value_t p = 0.5; // standard l2 raft::linalg::unaryOp( out_dists, out_dists, this->config_->a_nrows * this->config_->b_nrows, - [p] __device__(value_t input) { - int neg = input < 0 ? -1 : 1; - return powf(fabs(input), p) * neg; - }, - this->config_->stream); + [p] __device__(value_t input) { + int neg = input < 0 ? -1 : 1; + return powf(fabs(input), p) * neg; + }, + this->config_->stream); } }; diff --git a/cpp/include/raft/sparse/selection/knn.cuh b/cpp/include/raft/sparse/selection/knn.cuh index 49eb0e5be6..04a2b059fa 100644 --- a/cpp/include/raft/sparse/selection/knn.cuh +++ b/cpp/include/raft/sparse/selection/knn.cuh @@ -16,6 +16,8 @@ #pragma once +#include + #include #include #include @@ -34,13 +36,10 @@ #include #include -#include - #include - -#include - +#include #include +#include #include @@ -127,7 +126,7 @@ class sparse_knn_t { size_t batch_size_index_ = 2 << 14, // approx 1M size_t batch_size_query_ = 2 << 14, raft::distance::DistanceType metric_ = - raft::distance::DistanceType::L2Expanded, + raft::distance::DistanceType::L2Expanded, float metricArg_ = 0) : idxIndptr(idxIndptr_), idxIndices(idxIndices_), @@ -241,7 +240,7 @@ class sparse_knn_t { // populate batch indices array value_idx batch_rows = query_batcher.batch_rows(), - batch_cols = idx_batcher.batch_rows(); + batch_cols = idx_batcher.batch_rows(); iota_fill(batch_indices.data(), batch_rows, batch_cols, stream); @@ -311,7 +310,7 @@ class sparse_knn_t { stream); // combine merge buffers only if there's more than 1 partition to combine - MLCommon::Selection::knn_merge_parts( + raft::spatial::knn::detail::knn_merge_parts( merge_buffer_dists, merge_buffer_indices, out_dists, out_indices, query_batcher.batch_rows(), 2, k, stream, trans.data()); } @@ -322,7 +321,7 @@ class sparse_knn_t { value_t *out_dists, value_idx *out_indices) { // populate batch indices array value_idx batch_rows = query_batcher.batch_rows(), - batch_cols = idx_batcher.batch_rows(); + batch_cols = idx_batcher.batch_rows(); // build translation buffer to shift resulting indices by the batch std::vector id_ranges; @@ -440,7 +439,7 @@ void brute_force_knn(const value_idx *idxIndptr, const value_idx *idxIndices, size_t batch_size_index = 2 << 14, // approx 1M size_t batch_size_query = 2 << 14, raft::distance::DistanceType metric = - raft::distance::DistanceType::L2Expanded, + raft::distance::DistanceType::L2Expanded, float metricArg = 0) { sparse_knn_t( idxIndptr, idxIndices, idxData, idxNNZ, n_idx_rows, n_idx_cols, queryIndptr, From cc4c20367fdb6a872f3b49ecc63bb10472dd1a56 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 3 Mar 2021 17:07:43 -0500 Subject: [PATCH 3/6] Removing cuml debugs --- cpp/include/raft/sparse/distance/ip_distance.cuh | 6 ------ cpp/include/raft/sparse/distance/l2_distance.cuh | 8 -------- cpp/include/raft/sparse/distance/lp_distance.cuh | 3 --- cpp/include/raft/sparse/selection/knn.cuh | 1 - 4 files changed, 18 deletions(-) diff --git a/cpp/include/raft/sparse/distance/ip_distance.cuh b/cpp/include/raft/sparse/distance/ip_distance.cuh index a832c2b6a9..90717bfc5f 100644 --- a/cpp/include/raft/sparse/distance/ip_distance.cuh +++ b/cpp/include/raft/sparse/distance/ip_distance.cuh @@ -190,18 +190,12 @@ class ip_distances_gemm_t : public ip_trans_getters_t { value_t *csr_out_data) { value_idx m = config_->a_nrows, n = config_->b_nrows, k = config_->a_ncols; - int start = raft::curTimeMillis(); - - CUDA_CHECK(cudaStreamSynchronize(config_->stream)); - CUSPARSE_CHECK(raft::sparse::cusparsecsrgemm2( config_->handle, m, n, k, &alpha, matA, config_->a_nnz, config_->a_data, config_->a_indptr, config_->a_indices, matB, config_->b_nnz, csc_data.data(), csc_indptr.data(), csc_indices.data(), NULL, matD, 0, NULL, NULL, NULL, matC, csr_out_data, csr_out_indptr, csr_out_indices, info, workspace.data(), config_->stream)); - - CUDA_CHECK(cudaStreamSynchronize(config_->stream)); } void transpose_b() { diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh b/cpp/include/raft/sparse/distance/l2_distance.cuh index 918d91ca22..9fe2871592 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/l2_distance.cuh @@ -122,20 +122,17 @@ class l2_expanded_distances_t : public distances_t { ip_dists(config) {} void compute(value_t *out_dists) { - CUML_LOG_DEBUG("Computing inner products"); ip_dists.compute(out_dists); value_idx *b_indices = ip_dists.b_rows_coo(); value_t *b_data = ip_dists.b_data_coo(); - CUML_LOG_DEBUG("Computing COO row index array"); raft::mr::device::buffer search_coo_rows( config_->allocator, config_->stream, config_->a_nnz); raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows, search_coo_rows.data(), config_->a_nnz, config_->stream); - CUML_LOG_DEBUG("Computing L2"); compute_l2( out_dists, search_coo_rows.data(), config_->a_data, config_->a_nnz, b_indices, b_data, config_->b_nnz, config_->a_nrows, config_->b_nrows, @@ -167,7 +164,6 @@ class l2_sqrt_expanded_distances_t void compute(value_t *out_dists) override { l2_expanded_distances_t::compute(out_dists); - CUML_LOG_DEBUG("Computing Sqrt"); // Sqrt Post-processing value_t p = 0.5; // standard l2 raft::linalg::unaryOp( @@ -196,20 +192,17 @@ class cosine_expanded_distances_t : public distances_t { ip_dists(config) {} void compute(value_t *out_dists) { - CUML_LOG_DEBUG("Computing inner products"); ip_dists.compute(out_dists); value_idx *b_indices = ip_dists.b_rows_coo(); value_t *b_data = ip_dists.b_data_coo(); - CUML_LOG_DEBUG("Computing COO row index array"); raft::mr::device::buffer search_coo_rows( config_->allocator, config_->stream, config_->a_nnz); raft::sparse::convert::csr_to_coo(config_->a_indptr, config_->a_nrows, search_coo_rows.data(), config_->a_nnz, config_->stream); - CUML_LOG_DEBUG("Computing L2"); compute_l2( out_dists, search_coo_rows.data(), config_->a_data, config_->a_nnz, b_indices, b_data, config_->b_nnz, config_->a_nrows, config_->b_nrows, @@ -249,7 +242,6 @@ class hellinger_expanded_distances_t : public distances_t { ip_dists(config) {} void compute(value_t *out_dists) { - CUML_LOG_DEBUG("Computing Hellinger Distance"); // First sqrt A and B raft::linalg::unaryOp( diff --git a/cpp/include/raft/sparse/distance/lp_distance.cuh b/cpp/include/raft/sparse/distance/lp_distance.cuh index 5f397a4f19..732978c639 100644 --- a/cpp/include/raft/sparse/distance/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/lp_distance.cuh @@ -104,8 +104,6 @@ class l1_unexpanded_distances_t : public distances_t { : config_(&config) {} void compute(value_t *out_dists) { - CUML_LOG_DEBUG("Running l1 dists"); - unexpanded_lp_distances(out_dists, config_, AbsDiff(), Sum(), AtomicAdd()); } @@ -140,7 +138,6 @@ class l2_sqrt_unexpanded_distances_t void compute(value_t *out_dists) { l2_unexpanded_distances_t::compute(out_dists); - CUML_LOG_DEBUG("Computing Sqrt"); // Sqrt Post-processing value_t p = 0.5; // standard l2 raft::linalg::unaryOp( diff --git a/cpp/include/raft/sparse/selection/knn.cuh b/cpp/include/raft/sparse/selection/knn.cuh index 04a2b059fa..b309840d80 100644 --- a/cpp/include/raft/sparse/selection/knn.cuh +++ b/cpp/include/raft/sparse/selection/knn.cuh @@ -351,7 +351,6 @@ class sparse_knn_t { /** * Compute distances */ - CUML_LOG_DEBUG("Computing pairwise distances for batch"); raft::sparse::distance::distances_config_t dist_config; dist_config.b_nrows = idx_batcher.batch_rows(); dist_config.b_ncols = n_idx_cols; From ba57d5b1be2f5e046015fe7799f8c9eacefe2ec1 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 3 Mar 2021 17:08:12 -0500 Subject: [PATCH 4/6] Fixing style --- cpp/include/raft/sparse/distance/l2_distance.cuh | 1 - 1 file changed, 1 deletion(-) diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh b/cpp/include/raft/sparse/distance/l2_distance.cuh index 9fe2871592..eb49d1089c 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/l2_distance.cuh @@ -242,7 +242,6 @@ class hellinger_expanded_distances_t : public distances_t { ip_dists(config) {} void compute(value_t *out_dists) { - // First sqrt A and B raft::linalg::unaryOp( config_->a_data, config_->a_data, config_->a_nnz, From 420c467e211e8a109f002f89a680131a125e7f78 Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 4 Mar 2021 11:12:05 -0500 Subject: [PATCH 5/6] updating gtests --- .../raft/sparse/distance/l2_distance.cuh | 2 +- cpp/test/sparse/dist_coo_spmv.cu | 16 ++-- cpp/test/sparse/dist_csr_spmv.cu | 12 ++- cpp/test/sparse/distance.cu | 89 ++++++++++--------- cpp/test/sparse/knn.cu | 26 ++---- 5 files changed, 73 insertions(+), 72 deletions(-) diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh b/cpp/include/raft/sparse/distance/l2_distance.cuh index eb49d1089c..5cf290faa2 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/l2_distance.cuh @@ -16,7 +16,7 @@ #pragma once -#include +#include #include #include diff --git a/cpp/test/sparse/dist_coo_spmv.cu b/cpp/test/sparse/dist_coo_spmv.cu index a841da661d..6e3f3b5038 100644 --- a/cpp/test/sparse/dist_coo_spmv.cu +++ b/cpp/test/sparse/dist_coo_spmv.cu @@ -155,6 +155,7 @@ class SparseDistanceCOOSPMVTest CUDA_CHECK(cudaStreamCreate(&stream)); CUSPARSE_CHECK(cusparseCreate(&cusparseHandle)); + CUSPARSE_CHECK(cusparseSetStream(cusparseHandle, stream)); make_data(); @@ -225,7 +226,8 @@ const std::vector> inputs_i32_f = { {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}, {5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0}, - raft::distance::DistanceType::InnerProduct}, + raft::distance::DistanceType::InnerProduct, + 0.0}, {2, {0, 2, 4, 6, 8}, {0, 1, 0, 1, 0, 1, 0, 1}, // indices @@ -249,7 +251,8 @@ const std::vector> inputs_i32_f = { 1832.0, 0.0, }, - raft::distance::DistanceType::L2Unexpanded}, + raft::distance::DistanceType::L2Unexpanded, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -362,7 +365,8 @@ const std::vector> inputs_i32_f = { 6.903282911791188, 7.0, 0.0}, - raft::distance::DistanceType::Canberra}, + raft::distance::DistanceType::Canberra, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -589,7 +593,8 @@ const std::vector> inputs_i32_f = { 0.5079750812968089, 0.8429599432532096, 0.0}, - raft::distance::DistanceType::Linf}, + raft::distance::DistanceType::Linf, + 0.0}, {4, {0, 1, 1, 2, 4}, @@ -614,7 +619,8 @@ const std::vector> inputs_i32_f = { 0.84454, 0.0, }, - raft::distance::DistanceType::L1} + raft::distance::DistanceType::L1, + 0.0} }; diff --git a/cpp/test/sparse/dist_csr_spmv.cu b/cpp/test/sparse/dist_csr_spmv.cu index 2405909c40..c32748a04e 100644 --- a/cpp/test/sparse/dist_csr_spmv.cu +++ b/cpp/test/sparse/dist_csr_spmv.cu @@ -229,7 +229,8 @@ const std::vector> inputs_i32_f = { 1832.0, 0.0, }, - raft::distance::DistanceType::L2Unexpanded}, + raft::distance::DistanceType::L2Unexpanded, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -342,7 +343,8 @@ const std::vector> inputs_i32_f = { 6.903282911791188, 7.0, 0.0}, - raft::distance::DistanceType::Canberra}, + raft::distance::DistanceType::Canberra, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -569,7 +571,8 @@ const std::vector> inputs_i32_f = { 0.5079750812968089, 0.8429599432532096, 0.0}, - raft::distance::DistanceType::Linf}, + raft::distance::DistanceType::Linf, + 0.0}, {4, {0, 1, 1, 2, 4}, @@ -594,7 +597,8 @@ const std::vector> inputs_i32_f = { 0.84454, 0.0, }, - raft::distance::DistanceType::L1} + raft::distance::DistanceType::L1, + 0.0} }; diff --git a/cpp/test/sparse/distance.cu b/cpp/test/sparse/distance.cu index 53e8838b65..b103486b96 100644 --- a/cpp/test/sparse/distance.cu +++ b/cpp/test/sparse/distance.cu @@ -88,6 +88,7 @@ class SparseDistanceTest CUDA_CHECK(cudaStreamCreate(&stream)); CUSPARSE_CHECK(cusparseCreate(&cusparseHandle)); + CUSPARSE_CHECK(cusparseSetStream(cusparseHandle, stream)); make_data(); @@ -127,15 +128,9 @@ class SparseDistanceTest } void compare() { - // skip Hellinger test due to sporadic CI issue - // https://github.com/rapidsai/cuml/issues/3477 - if (params.metric == raft::distance::DistanceType::HellingerExpanded) { - GTEST_SKIP(); - } else { - ASSERT_TRUE(devArrMatch(out_dists_ref, out_dists, - params.out_dists_ref_h.size(), - CompareApprox(1e-3))); - } + ASSERT_TRUE(devArrMatch(out_dists_ref, out_dists, + params.out_dists_ref_h.size(), + CompareApprox(1e-3))); } protected: @@ -176,14 +171,16 @@ const std::vector> inputs_i32_f = { 1832.0, 0.0, }, - raft::distance::DistanceType::L2Expanded}, + raft::distance::DistanceType::L2Expanded, + 0.0}, {2, {0, 2, 4, 6, 8}, {0, 1, 0, 1, 0, 1, 0, 1}, {1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f, 1.0f, 2.0f}, {5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0, 5.0}, - raft::distance::DistanceType::InnerProduct}, + raft::distance::DistanceType::InnerProduct, + 0.0}, {2, {0, 2, 4, 6, 8}, {0, 1, 0, 1, 0, 1, 0, 1}, // indices @@ -207,7 +204,8 @@ const std::vector> inputs_i32_f = { 1832.0, 0.0, }, - raft::distance::DistanceType::L2Unexpanded}, + raft::distance::DistanceType::L2Unexpanded, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -237,7 +235,8 @@ const std::vector> inputs_i32_f = { 0.67676228, 0.24558392, 0.76064776, 0.51360432, 0., 1., 0.76978799, 0.78021386, 1., 0.84923694, 0.73155632, 0.99166225, 0.61547536, 0.68185144, 1., 0.}, - raft::distance::DistanceType::CosineExpanded}, + raft::distance::DistanceType::CosineExpanded, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -347,7 +346,8 @@ const std::vector> inputs_i32_f = { 0.75, 1.0, 0.0}, - raft::distance::DistanceType::JaccardExpanded}, + raft::distance::DistanceType::JaccardExpanded, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -460,7 +460,8 @@ const std::vector> inputs_i32_f = { 6.903282911791188, 7.0, 0.0}, - raft::distance::DistanceType::Canberra}, + raft::distance::DistanceType::Canberra, + 0.0}, {10, {0, 5, 11, 15, 20, 27, 32, 36, 43, 47, 50}, @@ -687,33 +688,10 @@ const std::vector> inputs_i32_f = { 0.5079750812968089, 0.8429599432532096, 0.0}, - raft::distance::DistanceType::Linf}, + raft::distance::DistanceType::Linf, + 0.0}, - {4, - {0, 1, 1, 2, 4}, - {3, 2, 0, 1}, // indices - {0.99296, 0.42180, 0.11687, 0.305869}, - { - // dense output - 0.0, - 0.99296, - 1.41476, - 1.415707, - 0.99296, - 0.0, - 0.42180, - 0.42274, - 1.41476, - 0.42180, - 0.0, - 0.84454, - 1.41570, - 0.42274, - 0.84454, - 0.0, - }, - raft::distance::DistanceType::L1}, - {10, + {15, {0, 5, 8, 9, 15, 20, 26, 31, 34, 38, 45}, {0, 1, 5, 6, 9, 1, 4, 14, 7, 3, 4, 7, 9, 11, 14, 0, 3, 7, 8, 12, 0, 2, 5, 7, 8, 14, 4, 9, 10, 11, @@ -752,7 +730,34 @@ const std::vector> inputs_i32_f = { 1.00000000e+00, 8.05419635e-01, 9.53789212e-01, 8.07933016e-01, 7.40428532e-01, 7.95485011e-01, 8.51370877e-01, 1.49011612e-08}, // Dataset is L1 normalized into pdfs - raft::distance::DistanceType::HellingerExpanded}}; + raft::distance::DistanceType::HellingerExpanded, + 0.0}, + + {4, + {0, 1, 1, 2, 4}, + {3, 2, 0, 1}, // indices + {0.99296, 0.42180, 0.11687, 0.305869}, + { + // dense output + 0.0, + 0.99296, + 1.41476, + 1.415707, + 0.99296, + 0.0, + 0.42180, + 0.42274, + 1.41476, + 0.42180, + 0.0, + 0.84454, + 1.41570, + 0.42274, + 0.84454, + 0.0, + }, + raft::distance::DistanceType::L1, + 0.0}}; typedef SparseDistanceTest SparseDistanceTestF; TEST_P(SparseDistanceTestF, Result) { compare(); } diff --git a/cpp/test/sparse/knn.cu b/cpp/test/sparse/knn.cu index 0f773b9fee..4759eebe4b 100644 --- a/cpp/test/sparse/knn.cu +++ b/cpp/test/sparse/knn.cu @@ -17,6 +17,8 @@ #include #include +#include +#include #include #include "../test_utils.h" @@ -49,7 +51,7 @@ struct SparseKNNInputs { int batch_size_query = 2; raft::distance::DistanceType metric = - raft::distance::DistanceType::L2Expanded; + raft::distance::DistanceType::L2SqrtExpanded; }; template @@ -67,14 +69,10 @@ class SparseKNNTest std::vector indices_h = params.indices_h; std::vector data_h = params.data_h; - printf("Allocating input\n"); - allocate(indptr, indptr_h.size()); allocate(indices, indices_h.size()); allocate(data, data_h.size()); - printf("Updating device\n"); - update_device(indptr, indptr_h.data(), indptr_h.size(), stream); update_device(indices, indices_h.data(), indices_h.size(), stream); update_device(data, data_h.data(), data_h.size(), stream); @@ -82,23 +80,16 @@ class SparseKNNTest std::vector out_dists_ref_h = params.out_dists_ref_h; std::vector out_indices_ref_h = params.out_indices_ref_h; - printf("Allocating ref output\n"); allocate(out_indices_ref, out_indices_ref_h.size()); allocate(out_dists_ref, out_dists_ref_h.size()); - printf("Updating device\n"); - update_device(out_indices_ref, out_indices_ref_h.data(), out_indices_ref_h.size(), stream); update_device(out_dists_ref, out_dists_ref_h.data(), out_dists_ref_h.size(), stream); - printf("Allocating final output\n"); - allocate(out_dists, n_rows * k); allocate(out_indices, n_rows * k); - - printf("Done.\n"); } void SetUp() override { @@ -106,7 +97,6 @@ class SparseKNNTest ::testing::TestWithParam>::GetParam(); std::shared_ptr alloc( new raft::mr::device::default_allocator); - CUDA_CHECK(cudaStreamCreate(&stream)); CUSPARSE_CHECK(cusparseCreate(&cusparseHandle)); @@ -115,12 +105,8 @@ class SparseKNNTest nnz = params.indices_h.size(); k = params.k; - printf("Making data\n"); - make_data(); - printf("About to run kselect\n"); - raft::sparse::selection::brute_force_knn( indptr, indices, data, nnz, n_rows, params.n_cols, indptr, indices, data, nnz, n_rows, params.n_cols, out_indices, out_dists, k, cusparseHandle, @@ -128,11 +114,11 @@ class SparseKNNTest params.metric); CUDA_CHECK(cudaStreamSynchronize(stream)); - - printf("Executed k-select"); } void TearDown() override { + CUDA_CHECK(cudaStreamSynchronize(stream)); + CUDA_CHECK(cudaFree(indptr)); CUDA_CHECK(cudaFree(indices)); CUDA_CHECK(cudaFree(data)); @@ -181,7 +167,7 @@ const std::vector> inputs_i32_f = { 2, 2, 2, - raft::distance::DistanceType::L2Expanded}}; + raft::distance::DistanceType::L2SqrtExpanded}}; typedef SparseKNNTest SparseKNNTestF; TEST_P(SparseKNNTestF, Result) { compare(); } INSTANTIATE_TEST_CASE_P(SparseKNNTest, SparseKNNTestF, From 729b121cc614904a0366e743f050c8351e094eeb Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Thu, 4 Mar 2021 15:00:54 -0500 Subject: [PATCH 6/6] Using sqrt for l2 computations --- cpp/include/raft/sparse/distance/l2_distance.cuh | 5 ++--- cpp/include/raft/sparse/distance/lp_distance.cuh | 5 ++--- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/cpp/include/raft/sparse/distance/l2_distance.cuh b/cpp/include/raft/sparse/distance/l2_distance.cuh index 5cf290faa2..98187576fa 100644 --- a/cpp/include/raft/sparse/distance/l2_distance.cuh +++ b/cpp/include/raft/sparse/distance/l2_distance.cuh @@ -165,12 +165,11 @@ class l2_sqrt_expanded_distances_t void compute(value_t *out_dists) override { l2_expanded_distances_t::compute(out_dists); // Sqrt Post-processing - value_t p = 0.5; // standard l2 raft::linalg::unaryOp( out_dists, out_dists, this->config_->a_nrows * this->config_->b_nrows, - [p] __device__(value_t input) { + [] __device__(value_t input) { int neg = input < 0 ? -1 : 1; - return powf(fabs(input), p) * neg; + return sqrt(abs(input) * neg); }, this->config_->stream); } diff --git a/cpp/include/raft/sparse/distance/lp_distance.cuh b/cpp/include/raft/sparse/distance/lp_distance.cuh index 732978c639..e524d87b7c 100644 --- a/cpp/include/raft/sparse/distance/lp_distance.cuh +++ b/cpp/include/raft/sparse/distance/lp_distance.cuh @@ -139,12 +139,11 @@ class l2_sqrt_unexpanded_distances_t void compute(value_t *out_dists) { l2_unexpanded_distances_t::compute(out_dists); // Sqrt Post-processing - value_t p = 0.5; // standard l2 raft::linalg::unaryOp( out_dists, out_dists, this->config_->a_nrows * this->config_->b_nrows, - [p] __device__(value_t input) { + [] __device__(value_t input) { int neg = input < 0 ? -1 : 1; - return powf(fabs(input), p) * neg; + return sqrt(abs(input) * neg); }, this->config_->stream); }