diff --git a/cpp/include/cuml/neighbors/knn.hpp b/cpp/include/cuml/neighbors/knn.hpp index 3effa58a6f..bb918f94ee 100644 --- a/cpp/include/cuml/neighbors/knn.hpp +++ b/cpp/include/cuml/neighbors/knn.hpp @@ -18,28 +18,11 @@ #include #include +#include #include #include namespace ML { - -enum MetricType { - METRIC_INNER_PRODUCT = 0, - METRIC_L2, - METRIC_L1, - METRIC_Linf, - METRIC_Lp, - - METRIC_Canberra = 20, - METRIC_BrayCurtis, - METRIC_JensenShannon, - - METRIC_Cosine = 100, - METRIC_Correlation, - METRIC_Jaccard, - METRIC_Hellinger -}; - struct knnIndex { faiss::gpu::StandardGpuResources *gpu_res; faiss::gpu::GpuIndex *index; @@ -102,20 +85,19 @@ struct IVFSQParam : IVFParam { * default * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This * is ignored if the metric_type is not Minkowski. - * @param[in] expanded should lp-based distances be returned in their expanded - * form (e.g., without raising to the 1/p power). */ void brute_force_knn(raft::handle_t &handle, std::vector &input, std::vector &sizes, int D, float *search_items, int n, int64_t *res_I, float *res_D, int k, bool rowMajorIndex = false, bool rowMajorQuery = false, - MetricType metric = MetricType::METRIC_L2, - float metric_arg = 2.0f, bool expanded = false); + raft::distance::DistanceType metric = + raft::distance::DistanceType::L2Expanded, + float metric_arg = 2.0f); void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index, ML::knnIndexParam *params, int D, - ML::MetricType metric, float metricArg, - float *index_items, int n); + raft::distance::DistanceType metric, + float metricArg, float *index_items, int n); void approx_knn_search(ML::knnIndex *index, int n, const float *x, int k, float *distances, int64_t *labels); diff --git a/cpp/include/cuml/neighbors/knn_api.h b/cpp/include/cuml/neighbors/knn_api.h index 05420e4742..d36dd7781b 100644 --- a/cpp/include/cuml/neighbors/knn_api.h +++ b/cpp/include/cuml/neighbors/knn_api.h @@ -42,8 +42,8 @@ extern "C" { 
* @param[in] rowMajorIndex is the index array in row major layout? * @param[in] rowMajorQuery is the query array in row major layout? * @param[in] metric_type the type of distance metric to use. This corresponds - * to the value in the ML::MetricType enum. Default is - * Euclidean (L2). + * to the value in the raft::distance::DistanceType enum. + * Default is Euclidean (L2). * @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This * is ignored if the metric_type is not Minkowski. * @param[in] expanded should lp-based distances be returned in their expanded diff --git a/cpp/include/cuml/neighbors/knn_sparse.hpp b/cpp/include/cuml/neighbors/knn_sparse.hpp index bc702aa4f7..6bb78c7157 100644 --- a/cpp/include/cuml/neighbors/knn_sparse.hpp +++ b/cpp/include/cuml/neighbors/knn_sparse.hpp @@ -20,6 +20,7 @@ #include +#include #include namespace ML { @@ -36,7 +37,8 @@ void brute_force_knn(raft::handle_t &handle, const int *idx_indptr, float *output_dists, int k, size_t batch_size_index = DEFAULT_BATCH_SIZE, size_t batch_size_query = DEFAULT_BATCH_SIZE, - ML::MetricType metric = ML::MetricType::METRIC_L2, - float metricArg = 0, bool expanded_form = false); + raft::distance::DistanceType metric = + raft::distance::DistanceType::L2Expanded, + float metricArg = 0); }; // end namespace Sparse }; // end namespace ML diff --git a/cpp/src/knn/knn.cu b/cpp/src/knn/knn.cu index 4eafe85820..64c3f0b9e5 100644 --- a/cpp/src/knn/knn.cu +++ b/cpp/src/knn/knn.cu @@ -33,8 +33,8 @@ namespace ML { void brute_force_knn(raft::handle_t &handle, std::vector &input, std::vector &sizes, int D, float *search_items, int n, int64_t *res_I, float *res_D, int k, bool rowMajorIndex, - bool rowMajorQuery, MetricType metric, float metric_arg, - bool expanded) { + bool rowMajorQuery, raft::distance::DistanceType metric, + float metric_arg) { ASSERT(input.size() == sizes.size(), "input and sizes vectors must be the same size"); @@ -44,13 +44,13 @@ void brute_force_knn(raft::handle_t 
&handle, std::vector &input, input, sizes, D, search_items, n, res_I, res_D, k, handle.get_device_allocator(), handle.get_stream(), int_streams.data(), handle.get_num_internal_streams(), rowMajorIndex, rowMajorQuery, nullptr, - metric, metric_arg, expanded); + metric, metric_arg); } void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index, ML::knnIndexParam *params, int D, - ML::MetricType metric, float metricArg, - float *index_items, int n) { + raft::distance::DistanceType metric, + float metricArg, float *index_items, int n) { MLCommon::Selection::approx_knn_build_index( index, params, D, metric, metricArg, index_items, n, handle.get_stream()); } diff --git a/cpp/src/knn/knn_api.cpp b/cpp/src/knn/knn_api.cpp index 547fd64629..7ae1cec442 100644 --- a/cpp/src/knn/knn_api.cpp +++ b/cpp/src/knn/knn_api.cpp @@ -57,6 +57,8 @@ cumlError_t knn_search(const cumlHandle_t handle, float **input, int *sizes, cumlError_t status; raft::handle_t *handle_ptr; std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle); + raft::distance::DistanceType metric_distance_type = + static_cast(metric_type); std::vector int_streams = handle_ptr->get_internal_streams(); @@ -71,7 +73,7 @@ cumlError_t knn_search(const cumlHandle_t handle, float **input, int *sizes, try { ML::brute_force_knn(*handle_ptr, input_vec, sizes_vec, D, search_items, n, res_I, res_D, k, rowMajorIndex, rowMajorQuery, - (ML::MetricType)metric_type, metric_arg, expanded); + metric_distance_type, metric_arg); } catch (...) 
{ status = CUML_ERROR_UNKNOWN; } diff --git a/cpp/src/knn/knn_sparse.cu b/cpp/src/knn/knn_sparse.cu index 684150e2d3..52a847358a 100644 --- a/cpp/src/knn/knn_sparse.cu +++ b/cpp/src/knn/knn_sparse.cu @@ -32,8 +32,8 @@ void brute_force_knn(raft::handle_t &handle, const int *idx_indptr, int n_query_rows, int n_query_cols, int *output_indices, float *output_dists, int k, size_t batch_size_index, // approx 1M - size_t batch_size_query, ML::MetricType metric, - float metricArg, bool expanded_form) { + size_t batch_size_query, + raft::distance::DistanceType metric, float metricArg) { auto d_alloc = handle.get_device_allocator(); cusparseHandle_t cusparse_handle = handle.get_cusparse_handle(); cudaStream_t stream = handle.get_stream(); @@ -42,8 +42,7 @@ void brute_force_knn(raft::handle_t &handle, const int *idx_indptr, idx_indptr, idx_indices, idx_data, idx_nnz, n_idx_rows, n_idx_cols, query_indptr, query_indices, query_data, query_nnz, n_query_rows, n_query_cols, output_indices, output_dists, k, cusparse_handle, d_alloc, - stream, batch_size_index, batch_size_query, metric, metricArg, - expanded_form); + stream, batch_size_index, batch_size_query, metric, metricArg); } }; // namespace Sparse }; // namespace ML diff --git a/cpp/src/tsne/distances.cuh b/cpp/src/tsne/distances.cuh index 4baa50aaea..423106363d 100644 --- a/cpp/src/tsne/distances.cuh +++ b/cpp/src/tsne/distances.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
@@ -89,7 +89,7 @@ void get_distances(const raft::handle_t &handle, k_graph.knn_indices, k_graph.knn_dists, k_graph.n_neighbors, handle.get_cusparse_handle(), handle.get_device_allocator(), stream, ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE, - ML::MetricType::METRIC_L2); + raft::distance::DistanceType::L2Expanded); } // sparse, int64 diff --git a/cpp/src/umap/knn_graph/algo.cuh b/cpp/src/umap/knn_graph/algo.cuh index ae27e74d53..bc1c2c3213 100644 --- a/cpp/src/umap/knn_graph/algo.cuh +++ b/cpp/src/umap/knn_graph/algo.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include #include @@ -91,7 +92,7 @@ void launcher(const raft::handle_t &handle, inputsB.n, inputsB.d, out.knn_indices, out.knn_dists, n_neighbors, handle.get_cusparse_handle(), d_alloc, stream, ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE, - ML::MetricType::METRIC_L2); + raft::distance::DistanceType::L2Expanded); } template <> diff --git a/cpp/src_prims/selection/knn.cuh b/cpp/src_prims/selection/knn.cuh index 032d6d1c92..25db378016 100644 --- a/cpp/src_prims/selection/knn.cuh +++ b/cpp/src_prims/selection/knn.cuh @@ -37,6 +37,7 @@ #include #include +#include #include "processing.cuh" #include @@ -186,14 +187,37 @@ inline void knn_merge_parts(value_t *inK, value_idx *inV, value_t *outK, inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); } -inline faiss::MetricType build_faiss_metric(ML::MetricType metric) { +inline faiss::MetricType build_faiss_metric( + raft::distance::DistanceType metric) { switch (metric) { - case ML::MetricType::METRIC_Cosine: + case raft::distance::DistanceType::CosineExpanded: return faiss::MetricType::METRIC_INNER_PRODUCT; - case ML::MetricType::METRIC_Correlation: + case raft::distance::DistanceType::CorrelationExpanded: return faiss::MetricType::METRIC_INNER_PRODUCT; + case raft::distance::DistanceType::L2Expanded: + return faiss::MetricType::METRIC_L2; + case raft::distance::DistanceType::L2Unexpanded: + return 
faiss::MetricType::METRIC_L2; + case raft::distance::DistanceType::L2SqrtExpanded: + return faiss::MetricType::METRIC_L2; + case raft::distance::DistanceType::L2SqrtUnexpanded: + return faiss::MetricType::METRIC_L2; + case raft::distance::DistanceType::L1: + return faiss::MetricType::METRIC_L1; + case raft::distance::DistanceType::InnerProduct: + return faiss::MetricType::METRIC_INNER_PRODUCT; + case raft::distance::DistanceType::LpUnexpanded: + return faiss::MetricType::METRIC_Lp; + case raft::distance::DistanceType::Linf: + return faiss::MetricType::METRIC_Linf; + case raft::distance::DistanceType::Canberra: + return faiss::MetricType::METRIC_Canberra; + case raft::distance::DistanceType::BrayCurtis: + return faiss::MetricType::METRIC_BrayCurtis; + case raft::distance::DistanceType::JensenShannon: + return faiss::MetricType::METRIC_JensenShannon; default: - return (faiss::MetricType)metric; + THROW("MetricType not supported: %d", metric); } } @@ -219,7 +243,8 @@ inline faiss::ScalarQuantizer::QuantizerType build_faiss_qtype( template void approx_knn_ivfflat_build_index(ML::knnIndex *index, ML::IVFParam *params, - IntType D, ML::MetricType metric, + IntType D, + raft::distance::DistanceType metric, IntType n) { faiss::gpu::GpuIndexIVFFlatConfig config; config.device = index->device; @@ -232,7 +257,9 @@ void approx_knn_ivfflat_build_index(ML::knnIndex *index, ML::IVFParam *params, template void approx_knn_ivfpq_build_index(ML::knnIndex *index, ML::IVFPQParam *params, - IntType D, ML::MetricType metric, IntType n) { + IntType D, + raft::distance::DistanceType metric, + IntType n) { faiss::gpu::GpuIndexIVFPQConfig config; config.device = index->device; config.usePrecomputedTables = params->usePrecomputedTables; @@ -246,7 +273,9 @@ void approx_knn_ivfpq_build_index(ML::knnIndex *index, ML::IVFPQParam *params, template void approx_knn_ivfsq_build_index(ML::knnIndex *index, ML::IVFSQParam *params, - IntType D, ML::MetricType metric, IntType n) { + IntType D, + 
raft::distance::DistanceType metric, + IntType n) { faiss::gpu::GpuIndexIVFScalarQuantizerConfig config; config.device = index->device; faiss::MetricType faiss_metric = build_faiss_metric(metric); @@ -262,8 +291,8 @@ void approx_knn_ivfsq_build_index(ML::knnIndex *index, ML::IVFSQParam *params, template void approx_knn_build_index(ML::knnIndex *index, ML::knnIndexParam *params, - IntType D, ML::MetricType metric, float metricArg, - float *index_items, IntType n, + IntType D, raft::distance::DistanceType metric, + float metricArg, float *index_items, IntType n, cudaStream_t userStream) { int device; CUDA_CHECK(cudaGetDevice(&device)); @@ -328,9 +357,8 @@ void approx_knn_search(ML::knnIndex *index, IntType n, const float *x, * @param[in] rowMajorQuery are the query array in row-major layout? * @param[in] translations translation ids for indices when index rows represent * non-contiguous partitions - * @param[in] metric corresponds to the FAISS::metricType enum (default is euclidean) + * @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded) * @param[in] metricArg metric argument to use. 
Corresponds to the p arg for lp norm - * @param[in] expanded_form whether or not lp variants should be reduced w/ lp-root */ template void brute_force_knn(std::vector &input, std::vector &sizes, @@ -342,8 +370,9 @@ void brute_force_knn(std::vector &input, std::vector &sizes, int n_int_streams = 0, bool rowMajorIndex = true, bool rowMajorQuery = true, std::vector *translations = nullptr, - ML::MetricType metric = ML::MetricType::METRIC_L2, - float metricArg = 0, bool expanded_form = false) { + raft::distance::DistanceType metric = + raft::distance::DistanceType::L2Expanded, + float metricArg = 0) { ASSERT(input.size() == sizes.size(), "input and sizes vectors should be the same size"); @@ -452,9 +481,9 @@ void brute_force_knn(std::vector &input, std::vector &sizes, } // Perform necessary post-processing - if ((m == faiss::MetricType::METRIC_L2 || - m == faiss::MetricType::METRIC_Lp) && - !expanded_form) { + if (metric == raft::distance::DistanceType::L2SqrtExpanded || + metric == raft::distance::DistanceType::L2SqrtUnexpanded || + metric == raft::distance::DistanceType::LpUnexpanded) { /** * post-processing */ diff --git a/cpp/src_prims/selection/processing.cuh b/cpp/src_prims/selection/processing.cuh index 4bc05a0362..ccb4eeebb0 100644 --- a/cpp/src_prims/selection/processing.cuh +++ b/cpp/src_prims/selection/processing.cuh @@ -162,17 +162,17 @@ class DefaultMetricProcessor : public MetricProcessor { template inline std::unique_ptr> create_processor( - ML::MetricType metric, int n, int D, int k, bool rowMajorQuery, + raft::distance::DistanceType metric, int n, int D, int k, bool rowMajorQuery, cudaStream_t userStream, std::shared_ptr allocator) { MetricProcessor *mp = nullptr; switch (metric) { - case ML::MetricType::METRIC_Cosine: + case raft::distance::DistanceType::CosineExpanded: mp = new CosineMetricProcessor(n, D, k, rowMajorQuery, userStream, allocator); break; - case ML::MetricType::METRIC_Correlation: + case 
raft::distance::DistanceType::CorrelationExpanded: mp = new CorrelationMetricProcessor(n, D, k, rowMajorQuery, userStream, allocator); break; diff --git a/cpp/src_prims/sparse/distance/distance.cuh b/cpp/src_prims/sparse/distance/distance.cuh index dd697edf6c..f7e0f1f629 100644 --- a/cpp/src_prims/sparse/distance/distance.cuh +++ b/cpp/src_prims/sparse/distance/distance.cuh @@ -17,6 +17,7 @@ #pragma once #include +#include #include #include @@ -42,6 +43,20 @@ namespace raft { namespace sparse { namespace distance { +static const std::unordered_set supportedDistance{ + raft::distance::DistanceType::L2Expanded, + raft::distance::DistanceType::L2Unexpanded, + raft::distance::DistanceType::L2SqrtExpanded, + raft::distance::DistanceType::L2SqrtUnexpanded, + raft::distance::DistanceType::InnerProduct, + raft::distance::DistanceType::L1, + raft::distance::DistanceType::Canberra, + raft::distance::DistanceType::Linf, + raft::distance::DistanceType::LpUnexpanded, + raft::distance::DistanceType::JaccardExpanded, + raft::distance::DistanceType::CosineExpanded, + raft::distance::DistanceType::HellingerExpanded}; + /** * Compute pairwise distances between A and B, using the provided * input configuration and distance function. 
@@ -60,12 +75,20 @@ void pairwiseDistance(value_t *out, case raft::distance::DistanceType::L2Expanded: l2_expanded_distances_t(input_config).compute(out); break; + case raft::distance::DistanceType::L2SqrtExpanded: + l2_sqrt_expanded_distances_t(input_config) + .compute(out); + break; case raft::distance::DistanceType::InnerProduct: ip_distances_t(input_config).compute(out); break; case raft::distance::DistanceType::L2Unexpanded: l2_unexpanded_distances_t(input_config).compute(out); break; + case raft::distance::DistanceType::L2SqrtUnexpanded: + l2_sqrt_unexpanded_distances_t(input_config) + .compute(out); + break; case raft::distance::DistanceType::L1: l1_unexpanded_distances_t(input_config).compute(out); break; diff --git a/cpp/src_prims/sparse/distance/l2_distance.cuh b/cpp/src_prims/sparse/distance/l2_distance.cuh index 34bb9128f3..c0f6ae4ba1 100644 --- a/cpp/src_prims/sparse/distance/l2_distance.cuh +++ b/cpp/src_prims/sparse/distance/l2_distance.cuh @@ -149,12 +149,41 @@ class l2_expanded_distances_t : public distances_t { ~l2_expanded_distances_t() = default; - private: + protected: const distances_config_t *config_; raft::mr::device::buffer workspace; ip_distances_t ip_dists; }; +/** + * L2 sqrt distance performing the sqrt operation after the distance computation + * The expanded form is more efficient for sparse data. + */ +template +class l2_sqrt_expanded_distances_t + : public l2_expanded_distances_t { + public: + explicit l2_sqrt_expanded_distances_t( + const distances_config_t &config) + : l2_expanded_distances_t(config) {} + + void compute(value_t *out_dists) override { + l2_expanded_distances_t::compute(out_dists); + CUML_LOG_DEBUG("Computing Sqrt"); + // Sqrt Post-processing + value_t p = 0.5; // standard l2 + raft::linalg::unaryOp( + out_dists, out_dists, this->config_->a_nrows * this->config_->b_nrows, + [p] __device__(value_t input) { + int neg = input < 0 ? 
-1 : 1; + return powf(fabs(input), p) * neg; + }, + this->config_->stream); + } + + ~l2_sqrt_expanded_distances_t() = default; +}; + /** * Cosine distance using the expanded form: 1 - ( sum(x_k * y_k) / (sqrt(sum(x_k)^2) * sqrt(sum(y_k)^2))) * The expanded form is more efficient for sparse data. diff --git a/cpp/src_prims/sparse/distance/lp_distance.cuh b/cpp/src_prims/sparse/distance/lp_distance.cuh index 1bb23ea784..b22abff489 100644 --- a/cpp/src_prims/sparse/distance/lp_distance.cuh +++ b/cpp/src_prims/sparse/distance/lp_distance.cuh @@ -126,10 +126,33 @@ class l2_unexpanded_distances_t : public distances_t { Sum(), AtomicAdd()); } - private: + protected: const distances_config_t *config_; }; +template +class l2_sqrt_unexpanded_distances_t + : public l2_unexpanded_distances_t { + public: + l2_sqrt_unexpanded_distances_t( + const distances_config_t &config) + : l2_unexpanded_distances_t(config) {} + + void compute(value_t *out_dists) { + l2_unexpanded_distances_t::compute(out_dists); + CUML_LOG_DEBUG("Computing Sqrt"); + // Sqrt Post-processing + value_t p = 0.5; // standard l2 + raft::linalg::unaryOp( + out_dists, out_dists, this->config_->a_nrows * this->config_->b_nrows, + [p] __device__(value_t input) { + int neg = input < 0 ? 
-1 : 1; + return powf(fabs(input), p) * neg; + }, + this->config_->stream); + } +}; + template class linf_unexpanded_distances_t : public distances_t { public: diff --git a/cpp/src_prims/sparse/selection/knn.cuh b/cpp/src_prims/sparse/selection/knn.cuh index 242f86a7b5..6de9908ab0 100644 --- a/cpp/src_prims/sparse/selection/knn.cuh +++ b/cpp/src_prims/sparse/selection/knn.cuh @@ -121,8 +121,9 @@ class sparse_knn_t { cudaStream_t stream_, size_t batch_size_index_ = 2 << 14, // approx 1M size_t batch_size_query_ = 2 << 14, - ML::MetricType metric_ = ML::MetricType::METRIC_L2, - float metricArg_ = 0, bool expanded_form_ = false) + raft::distance::DistanceType metric_ = + raft::distance::DistanceType::L2Expanded, + float metricArg_ = 0) : idxIndptr(idxIndptr_), idxIndices(idxIndices_), idxData(idxData_), @@ -144,8 +145,7 @@ class sparse_knn_t { batch_size_index(batch_size_index_), batch_size_query(batch_size_query_), metric(metric_), - metricArg(metricArg_), - expanded_form(expanded_form_) {} + metricArg(metricArg_) {} void run() { using namespace raft::sparse; @@ -253,8 +253,6 @@ class sparse_knn_t { batch_indices.data(), dists_merge_buffer_ptr, indices_merge_buffer_ptr); - perform_postprocessing(dists_merge_buffer_ptr, batch_rows); - value_t *dists_merge_buffer_tmp_ptr = dists_merge_buffer_ptr; value_idx *indices_merge_buffer_tmp_ptr = indices_merge_buffer_ptr; @@ -292,23 +290,6 @@ class sparse_knn_t { } } - void perform_postprocessing(value_t *dists, size_t batch_rows) { - // Perform necessary post-processing - if (metric == ML::MetricType::METRIC_L2 && !expanded_form) { - /** - * post-processing - */ - value_t p = 0.5; // standard l2 - raft::linalg::unaryOp( - dists, dists, batch_rows * k, - [p] __device__(value_t input) { - int neg = input < 0 ? 
-1 : 1; - return powf(fabs(input), p) * neg; - }, - stream); - } - } - private: void merge_batches(csr_batcher_t &idx_batcher, csr_batcher_t &query_batcher, @@ -348,50 +329,13 @@ class sparse_knn_t { value_idx n_neighbors = min(k, batch_cols); bool ascending = true; - if (metric == ML::MetricType::METRIC_INNER_PRODUCT) ascending = false; + if (metric == raft::distance::DistanceType::InnerProduct) ascending = false; // kernel to slice first (min) k cols and copy into batched merge buffer select_k(batch_dists, batch_indices, batch_rows, batch_cols, out_dists, out_indices, ascending, n_neighbors, stream); } - raft::distance::DistanceType get_pw_metric() { - raft::distance::DistanceType pw_metric; - switch (metric) { - case ML::MetricType::METRIC_INNER_PRODUCT: - pw_metric = raft::distance::DistanceType::InnerProduct; - break; - case ML::MetricType::METRIC_L2: - pw_metric = raft::distance::DistanceType::L2Expanded; - break; - case ML::MetricType::METRIC_L1: - pw_metric = raft::distance::DistanceType::L1; - break; - case ML::MetricType::METRIC_Canberra: - pw_metric = raft::distance::DistanceType::Canberra; - break; - case ML::MetricType::METRIC_Linf: - pw_metric = raft::distance::DistanceType::Linf; - break; - case ML::MetricType::METRIC_Lp: - pw_metric = raft::distance::DistanceType::LpUnexpanded; - break; - case ML::MetricType::METRIC_Jaccard: - pw_metric = raft::distance::DistanceType::JaccardExpanded; - break; - case ML::MetricType::METRIC_Cosine: - pw_metric = raft::distance::DistanceType::CosineExpanded; - break; - case ML::MetricType::METRIC_Hellinger: - pw_metric = raft::distance::DistanceType::HellingerExpanded; - break; - default: - THROW("MetricType not supported: %d", metric); - } - - return pw_metric; - } - void compute_distances(csr_batcher_t &idx_batcher, csr_batcher_t &query_batcher, size_t idx_batch_nnz, size_t query_batch_nnz, @@ -425,8 +369,12 @@ class sparse_knn_t { dist_config.allocator = allocator; dist_config.stream = stream; - 
raft::sparse::distance::pairwiseDistance(batch_dists, dist_config, - get_pw_metric(), metricArg); + if (raft::sparse::distance::supportedDistance.find(metric) == + raft::sparse::distance::supportedDistance.end()) + THROW("DistanceType not supported: %d", metric); + + raft::sparse::distance::pairwiseDistance(batch_dists, dist_config, metric, + metricArg); } const value_idx *idxIndptr, *idxIndices, *queryIndptr, *queryIndices; @@ -436,12 +384,10 @@ class sparse_knn_t { size_t idxNNZ, queryNNZ, batch_size_index, batch_size_query; - ML::MetricType metric; + raft::distance::DistanceType metric; float metricArg; - bool expanded_form; - int n_idx_rows, n_idx_cols, n_query_rows, n_query_cols, k; cusparseHandle_t cusparseHandle; @@ -475,7 +421,6 @@ class sparse_knn_t { * @param[in] batch_size_query maximum number of rows to use from query matrix per batch * @param[in] metric distance metric/measure to use * @param[in] metricArg potential argument for metric (currently unused) - * @param[in] expanded_form whether or not Lp variants should be reduced by the pth-root */ template void brute_force_knn(const value_idx *idxIndptr, const value_idx *idxIndices, @@ -489,13 +434,14 @@ void brute_force_knn(const value_idx *idxIndptr, const value_idx *idxIndices, cudaStream_t stream, size_t batch_size_index = 2 << 14, // approx 1M size_t batch_size_query = 2 << 14, - ML::MetricType metric = ML::MetricType::METRIC_L2, - float metricArg = 0, bool expanded_form = false) { + raft::distance::DistanceType metric = + raft::distance::DistanceType::L2Expanded, + float metricArg = 0) { sparse_knn_t( idxIndptr, idxIndices, idxData, idxNNZ, n_idx_rows, n_idx_cols, queryIndptr, queryIndices, queryData, queryNNZ, n_query_rows, n_query_cols, output_indices, output_dists, k, cusparseHandle, allocator, stream, - batch_size_index, batch_size_query, metric, metricArg, expanded_form) + batch_size_index, batch_size_query, metric, metricArg) .run(); } diff --git a/cpp/test/prims/knn.cu 
b/cpp/test/prims/knn.cu index 5c02a80867..4b3e4ee89e 100644 --- a/cpp/test/prims/knn.cu +++ b/cpp/test/prims/knn.cu @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020, NVIDIA CORPORATION. + * Copyright (c) 2019-2021, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,6 +16,7 @@ #include #include +#include #include #include #include @@ -69,7 +70,8 @@ class KNNTest : public ::testing::Test { CUDA_CHECK(cudaStreamCreate(&stream)); brute_force_knn(input_vec, sizes_vec, d, d_train_inputs, n, d_pred_I, - d_pred_D, n, alloc, stream); + d_pred_D, n, alloc, stream, nullptr, 0, true, true, nullptr, + raft::distance::DistanceType::L2SqrtExpanded); CUDA_CHECK(cudaStreamDestroy(stream)); } diff --git a/cpp/test/prims/sparse/knn.cu b/cpp/test/prims/sparse/knn.cu index 63a2735cb2..09cf39fa9d 100644 --- a/cpp/test/prims/sparse/knn.cu +++ b/cpp/test/prims/sparse/knn.cu @@ -17,6 +17,8 @@ #include #include +#include +#include #include #include @@ -48,7 +50,8 @@ struct SparseKNNInputs { int batch_size_index = 2; int batch_size_query = 2; - ML::MetricType metric = ML::MetricType::METRIC_L2; + raft::distance::DistanceType metric = + raft::distance::DistanceType::L2SqrtExpanded; }; template @@ -164,7 +167,7 @@ const std::vector> inputs_i32_f = { 2, 2, 2, - ML::MetricType::METRIC_L2}}; + raft::distance::DistanceType::L2SqrtExpanded}}; typedef SparseKNNTest KNNTestF; TEST_P(KNNTestF, Result) { compare(); } INSTANTIATE_TEST_CASE_P(SparseKNNTest, KNNTestF, diff --git a/cpp/test/sg/umap_test.cu b/cpp/test/sg/umap_test.cu index 491c0138c2..6ec1b9f13c 100644 --- a/cpp/test/sg/umap_test.cu +++ b/cpp/test/sg/umap_test.cu @@ -82,7 +82,7 @@ class UMAPTest : public ::testing::Test { CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); - xformed_score = trustworthiness_score( + xformed_score = trustworthiness_score( handle, X_d.data(), xformed.data(), n_samples, n_features, 
umap_params->n_components, umap_params->n_neighbors); } @@ -117,7 +117,7 @@ class UMAPTest : public ::testing::Test { CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); - fit_score = trustworthiness_score( + fit_score = trustworthiness_score( handle, X_d.data(), embeddings.data(), n_samples, n_features, umap_params->n_components, umap_params->n_neighbors); } @@ -154,7 +154,7 @@ class UMAPTest : public ::testing::Test { CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); - supervised_score = trustworthiness_score( + supervised_score = trustworthiness_score( handle, X_d.data(), embeddings.data(), n_samples, n_features, umap_params->n_components, umap_params->n_neighbors); } @@ -213,7 +213,7 @@ class UMAPTest : public ::testing::Test { CUDA_CHECK(cudaStreamSynchronize(handle.get_stream())); - fit_with_knn_score = trustworthiness_score( + fit_with_knn_score = trustworthiness_score( handle, X_d.data(), embeddings.data(), n_samples, n_features, umap_params->n_components, umap_params->n_neighbors); } diff --git a/python/cuml/metrics/distance_type.pxd b/python/cuml/metrics/distance_type.pxd index 95ae282581..4a61021628 100644 --- a/python/cuml/metrics/distance_type.pxd +++ b/python/cuml/metrics/distance_type.pxd @@ -1,3 +1,19 @@ +# +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# + cdef extern from "raft/linalg/distance_type.h" namespace "raft::distance": ctypedef enum DistanceType: @@ -7,3 +23,13 @@ cdef extern from "raft/linalg/distance_type.h" namespace "raft::distance": L1 "raft::distance::DistanceType::L1" L2Unexpanded "raft::distance::DistanceType::L2Unexpanded" L2SqrtUnexpanded "raft::distance::DistanceType::L2SqrtUnexpanded" + InnerProduct "raft::distance::DistanceType::InnerProduct" + Linf "raft::distance::DistanceType::Linf" + Canberra "raft::distance::DistanceType::Canberra" + LpUnexpanded "raft::distance::DistanceType::LpUnexpanded" + CorrelationExpanded "raft::distance::DistanceType::CorrelationExpanded" + JaccardExpanded "raft::distance::DistanceType::JaccardExpanded" + HellingerExpanded "raft::distance::DistanceType::HellingerExpanded" + Haversine "raft::distance::DistanceType::Haversine" + BrayCurtis "raft::distance::DistanceType::BrayCurtis" + JensenShannon "raft::distance::DistanceType::JensenShannon" diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index e2ef991e0c..2146aefeab 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -38,6 +38,7 @@ from cuml.common.input_utils import input_to_cupy_array from cuml.common import input_to_cuml_array from cuml.common.sparse_utils import is_sparse from cuml.common.sparse_utils import is_dense +from cuml.metrics.distance_type cimport DistanceType from cuml.neighbors.ann cimport * from cuml.raft.common.handle cimport handle_t @@ -64,22 +65,6 @@ if has_scipy(): cdef extern from "cuml/neighbors/knn.hpp" namespace "ML": - enum MetricType: - METRIC_INNER_PRODUCT = 0, - METRIC_L2, - METRIC_L1, - METRIC_Linf, - METRIC_Lp, - - METRIC_Canberra = 20, - METRIC_BrayCurtis, - METRIC_JensenShannon, - - METRIC_Cosine = 100, - METRIC_Correlation, - METRIC_Jaccard, - METRIC_Hellinger - cdef cppclass knnIndex: pass @@ -95,9 +80,8 @@ cdef extern from "cuml/neighbors/knn.hpp" namespace 
"ML": int k, bool rowMajorIndex, bool rowMajorQuery, - MetricType metric, - float metric_arg, - bool expanded + DistanceType metric, + float metric_arg ) except + void approx_knn_build_index( @@ -105,7 +89,7 @@ cdef extern from "cuml/neighbors/knn.hpp" namespace "ML": knnIndex* index, knnIndexParam* params, int D, - MetricType metric, + DistanceType metric, float metricArg, float *search_items, int n @@ -139,9 +123,8 @@ cdef extern from "cuml/neighbors/knn_sparse.hpp" namespace "ML::Sparse": int k, size_t batch_size_index, size_t batch_size_query, - MetricType metric, - float metricArg, - bool expanded_form) except + + DistanceType metric, + float metricArg) except + class NearestNeighbors(Base, @@ -386,13 +369,13 @@ class NearestNeighbors(Base, algo_params = \ build_algo_params(self.algorithm, self.algo_params, additional_info) - metric, expanded = self._build_metric_type(self.metric) + metric = self._build_metric_type(self.metric) approx_knn_build_index(handle_[0], knn_index, algo_params, n_cols, - metric, + metric, self.p, self.X_m.ptr, self.n_rows) @@ -412,41 +395,36 @@ class NearestNeighbors(Base, @staticmethod def _build_metric_type(metric): - - expanded = False - if metric == "euclidean" or metric == "l2": - m = MetricType.METRIC_L2 + m = DistanceType.L2SqrtExpanded elif metric == "sqeuclidean": - m = MetricType.METRIC_L2 - expanded = True - elif metric == "cityblock" or metric == "l1"\ - or metric == "manhattan" or metric == 'taxicab': - m = MetricType.METRIC_L1 + m = DistanceType.L2Expanded + elif metric in ["cityblock", "l1", "manhattan", 'taxicab']: + m = DistanceType.L1 elif metric == "braycurtis": - m = MetricType.METRIC_BrayCurtis + m = DistanceType.BrayCurtis elif metric == "canberra": - m = MetricType.METRIC_Canberra + m = DistanceType.Canberra elif metric == "minkowski" or metric == "lp": - m = MetricType.METRIC_Lp + m = DistanceType.LpUnexpanded elif metric == "chebyshev" or metric == "linf": - m = MetricType.METRIC_Linf + m = DistanceType.Linf 
elif metric == "jensenshannon": - m = MetricType.METRIC_JensenShannon + m = DistanceType.JensenShannon elif metric == "cosine": - m = MetricType.METRIC_Cosine + m = DistanceType.CosineExpanded elif metric == "correlation": - m = MetricType.METRIC_Correlation + m = DistanceType.CorrelationExpanded elif metric == "inner_product": - m = MetricType.METRIC_INNER_PRODUCT + m = DistanceType.InnerProduct elif metric == "jaccard": - m = MetricType.METRIC_Jaccard + m = DistanceType.JaccardExpanded elif metric == "hellinger": - m = MetricType.METRIC_Hellinger + m = DistanceType.HellingerExpanded else: raise ValueError("Metric %s is not supported" % metric) - return m, expanded + return m @insert_into_docstring(parameters=[('dense', '(n_samples, n_features)')], return_values=[('dense', '(n_samples, n_features)'), @@ -601,25 +579,24 @@ class NearestNeighbors(Base, if _output_type is not None else self._get_output_type(X) if two_pass_precision: - metric, expanded = self._build_metric_type(self.metric) + metric = self._build_metric_type(self.metric) metric_is_l2_based = ( - metric == MetricType.METRIC_L2 or - (metric == MetricType.METRIC_Lp and self.p == 2) + metric == DistanceType.L2SqrtExpanded or + metric == DistanceType.L2Expanded or + (metric == DistanceType.LpUnexpanded and self.p == 2) ) # FAISS employs imprecise distance algorithm only for L2-based - # metrics + # expanded metrics. This code corrects numerical instabilities + # that could arise. 
if metric_is_l2_based: X = input_to_cupy_array(X).array I_cparr = I_ndarr.to_output('cupy') self_diff = X[I_cparr] - X[:, cp.newaxis, :] - if expanded: - precise_distances = cp.sum( - self_diff * self_diff, axis=2 - ) - else: - precise_distances = cp.linalg.norm(self_diff, axis=2) + precise_distances = cp.sum( + self_diff * self_diff, axis=2 + ) correct_order = cp.argsort(precise_distances, axis=1) @@ -652,7 +629,7 @@ class NearestNeighbors(Base, raise ValueError("A NearestNeighbors model trained on dense " "data requires dense input to kneighbors()") - metric, expanded = self._build_metric_type(self.metric) + metric = self._build_metric_type(self.metric) X_m, N, _, dtype = \ input_to_cuml_array(X, order='C', check_dtype=np.float32, @@ -689,10 +666,9 @@ class NearestNeighbors(Base, n_neighbors, True, True, - metric, + metric, # minkowski order is currently the only metric argument. - self.p, - expanded + self.p ) else: knn_index = self.knn_index @@ -726,7 +702,7 @@ class NearestNeighbors(Base, X_m = SparseCumlArray(X, convert_to_dtype=cp.float32, convert_format=False) - metric, expanded = self._build_metric_type(self.metric) + metric = self._build_metric_type(self.metric) cdef uintptr_t idx_indptr = self.X_m.indptr.ptr cdef uintptr_t idx_indices = self.X_m.indices.ptr @@ -766,9 +742,8 @@ class NearestNeighbors(Base, n_neighbors, batch_size_index, batch_size_query, - metric, - self.p, - expanded) + metric, + self.p) return D_ndarr, I_ndarr diff --git a/python/cuml/test/test_metrics.py b/python/cuml/test/test_metrics.py index 5533189231..adc17caa6e 100644 --- a/python/cuml/test/test_metrics.py +++ b/python/cuml/test/test_metrics.py @@ -1086,6 +1086,8 @@ def test_pairwise_distances_output_types(input_type, output_type, use_global): assert isinstance(S, cp.core.core.ndarray) +@pytest.mark.xfail(reason='Temporarily disabling this test. ' + 'See rapidsai/cuml#3569') @pytest.mark.parametrize("nrows, ncols, n_info", [ unit_param(30, 10, 7),