Skip to content

Commit

Permalink
Replace ML::MetricType with raft::distance::DistanceType (#3389)
Browse files Browse the repository at this point in the history
Closes #3319.
This PR replaces the distance type `ML::MetricType` with `raft::distance::DistanceType`.

Because RAFT's `DistanceType` encodes the distinction between expanded and non-expanded distances in each enumerator's name, the boolean parameter `expanded` is now redundant, so I removed it from the C++ API.

Authors:
  - Micka (@lowener)

Approvers:
  - Corey J. Nolet (@cjnolet)

URL: #3389
  • Loading branch information
lowener authored Mar 2, 2021
1 parent 6dddae4 commit 9fa6e17
Show file tree
Hide file tree
Showing 20 changed files with 247 additions and 203 deletions.
30 changes: 6 additions & 24 deletions cpp/include/cuml/neighbors/knn.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,11 @@

#include <faiss/gpu/GpuIndex.h>
#include <faiss/gpu/StandardGpuResources.h>
#include <raft/linalg/distance_type.h>
#include <cuml/common/logger.hpp>
#include <cuml/cuml.hpp>

namespace ML {

/**
 * Distance metric selector used by the k-NN entry points in this header
 * (e.g. the `MetricType metric` parameter of brute_force_knn).
 *
 * The enumerators carry explicit numeric values in three groups (0.., 20..,
 * 100..). build_faiss_metric() casts any value it does not handle explicitly
 * straight to faiss::MetricType.
 * NOTE(review): that cast presumably relies on these values lining up with
 * the FAISS enum — confirm before renumbering.
 */
enum MetricType {
METRIC_INNER_PRODUCT = 0,
METRIC_L2,
METRIC_L1,
METRIC_Linf,
METRIC_Lp,

// Second group: numbering restarts at 20.
METRIC_Canberra = 20,
METRIC_BrayCurtis,
METRIC_JensenShannon,

// Third group: numbering restarts at 100. build_faiss_metric() maps Cosine
// and Correlation to faiss METRIC_INNER_PRODUCT rather than casting them.
METRIC_Cosine = 100,
METRIC_Correlation,
METRIC_Jaccard,
METRIC_Hellinger
};

struct knnIndex {
faiss::gpu::StandardGpuResources *gpu_res;
faiss::gpu::GpuIndex *index;
Expand Down Expand Up @@ -102,20 +85,19 @@ struct IVFSQParam : IVFParam {
* default
* @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This
* is ignored if the metric_type is not Minkowski.
* @param[in] expanded should lp-based distances be returned in their expanded
* form (e.g., without raising to the 1/p power).
*/
void brute_force_knn(raft::handle_t &handle, std::vector<float *> &input,
std::vector<int> &sizes, int D, float *search_items, int n,
int64_t *res_I, float *res_D, int k,
bool rowMajorIndex = false, bool rowMajorQuery = false,
MetricType metric = MetricType::METRIC_L2,
float metric_arg = 2.0f, bool expanded = false);
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Expanded,
float metric_arg = 2.0f);

void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index,
ML::knnIndexParam *params, int D,
ML::MetricType metric, float metricArg,
float *index_items, int n);
raft::distance::DistanceType metric,
float metricArg, float *index_items, int n);

void approx_knn_search(ML::knnIndex *index, int n, const float *x, int k,
float *distances, int64_t *labels);
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cuml/neighbors/knn_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ extern "C" {
* @param[in] rowMajorIndex is the index array in row major layout?
* @param[in] rowMajorQuery is the query array in row major layout?
* @param[in] metric_type the type of distance metric to use. This corresponds
* to the value in the ML::MetricType enum. Default is
* Euclidean (L2).
* to the value in the raft::distance::DistanceType enum.
* Default is Euclidean (L2).
* @param[in] metric_arg the value of `p` for Minkowski (l-p) distances. This
* is ignored if the metric_type is not Minkowski.
* @param[in] expanded should lp-based distances be returned in their expanded
Expand Down
6 changes: 4 additions & 2 deletions cpp/include/cuml/neighbors/knn_sparse.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <cusparse_v2.h>

#include <raft/linalg/distance_type.h>
#include <cuml/neighbors/knn.hpp>

namespace ML {
Expand All @@ -36,7 +37,8 @@ void brute_force_knn(raft::handle_t &handle, const int *idx_indptr,
float *output_dists, int k,
size_t batch_size_index = DEFAULT_BATCH_SIZE,
size_t batch_size_query = DEFAULT_BATCH_SIZE,
ML::MetricType metric = ML::MetricType::METRIC_L2,
float metricArg = 0, bool expanded_form = false);
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Expanded,
float metricArg = 0);
}; // end namespace Sparse
}; // end namespace ML
10 changes: 5 additions & 5 deletions cpp/src/knn/knn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ namespace ML {
void brute_force_knn(raft::handle_t &handle, std::vector<float *> &input,
std::vector<int> &sizes, int D, float *search_items, int n,
int64_t *res_I, float *res_D, int k, bool rowMajorIndex,
bool rowMajorQuery, MetricType metric, float metric_arg,
bool expanded) {
bool rowMajorQuery, raft::distance::DistanceType metric,
float metric_arg) {
ASSERT(input.size() == sizes.size(),
"input and sizes vectors must be the same size");

Expand All @@ -44,13 +44,13 @@ void brute_force_knn(raft::handle_t &handle, std::vector<float *> &input,
input, sizes, D, search_items, n, res_I, res_D, k,
handle.get_device_allocator(), handle.get_stream(), int_streams.data(),
handle.get_num_internal_streams(), rowMajorIndex, rowMajorQuery, nullptr,
metric, metric_arg, expanded);
metric, metric_arg);
}

void approx_knn_build_index(raft::handle_t &handle, ML::knnIndex *index,
ML::knnIndexParam *params, int D,
ML::MetricType metric, float metricArg,
float *index_items, int n) {
raft::distance::DistanceType metric,
float metricArg, float *index_items, int n) {
MLCommon::Selection::approx_knn_build_index(
index, params, D, metric, metricArg, index_items, n, handle.get_stream());
}
Expand Down
4 changes: 3 additions & 1 deletion cpp/src/knn/knn_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ cumlError_t knn_search(const cumlHandle_t handle, float **input, int *sizes,
cumlError_t status;
raft::handle_t *handle_ptr;
std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle);
raft::distance::DistanceType metric_distance_type =
static_cast<raft::distance::DistanceType>(metric_type);

std::vector<cudaStream_t> int_streams = handle_ptr->get_internal_streams();

Expand All @@ -71,7 +73,7 @@ cumlError_t knn_search(const cumlHandle_t handle, float **input, int *sizes,
try {
ML::brute_force_knn(*handle_ptr, input_vec, sizes_vec, D, search_items, n,
res_I, res_D, k, rowMajorIndex, rowMajorQuery,
(ML::MetricType)metric_type, metric_arg, expanded);
metric_distance_type, metric_arg);
} catch (...) {
status = CUML_ERROR_UNKNOWN;
}
Expand Down
7 changes: 3 additions & 4 deletions cpp/src/knn/knn_sparse.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ void brute_force_knn(raft::handle_t &handle, const int *idx_indptr,
int n_query_rows, int n_query_cols, int *output_indices,
float *output_dists, int k,
size_t batch_size_index, // approx 1M
size_t batch_size_query, ML::MetricType metric,
float metricArg, bool expanded_form) {
size_t batch_size_query,
raft::distance::DistanceType metric, float metricArg) {
auto d_alloc = handle.get_device_allocator();
cusparseHandle_t cusparse_handle = handle.get_cusparse_handle();
cudaStream_t stream = handle.get_stream();
Expand All @@ -42,8 +42,7 @@ void brute_force_knn(raft::handle_t &handle, const int *idx_indptr,
idx_indptr, idx_indices, idx_data, idx_nnz, n_idx_rows, n_idx_cols,
query_indptr, query_indices, query_data, query_nnz, n_query_rows,
n_query_cols, output_indices, output_dists, k, cusparse_handle, d_alloc,
stream, batch_size_index, batch_size_query, metric, metricArg,
expanded_form);
stream, batch_size_index, batch_size_query, metric, metricArg);
}
}; // namespace Sparse
}; // namespace ML
4 changes: 2 additions & 2 deletions cpp/src/tsne/distances.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2020, NVIDIA CORPORATION.
* Copyright (c) 2019-2021, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -89,7 +89,7 @@ void get_distances(const raft::handle_t &handle,
k_graph.knn_indices, k_graph.knn_dists, k_graph.n_neighbors,
handle.get_cusparse_handle(), handle.get_device_allocator(), stream,
ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE,
ML::MetricType::METRIC_L2);
raft::distance::DistanceType::L2Expanded);
}

// sparse, int64
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/umap/knn_graph/algo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <cuml/manifold/umapparams.h>
#include <raft/linalg/distance_type.h>
#include <cuml/manifold/common.hpp>
#include <cuml/neighbors/knn_sparse.hpp>
#include <iostream>
Expand Down Expand Up @@ -91,7 +92,7 @@ void launcher(const raft::handle_t &handle,
inputsB.n, inputsB.d, out.knn_indices, out.knn_dists, n_neighbors,
handle.get_cusparse_handle(), d_alloc, stream,
ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE,
ML::MetricType::METRIC_L2);
raft::distance::DistanceType::L2Expanded);
}

template <>
Expand Down
61 changes: 45 additions & 16 deletions cpp/src_prims/selection/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <thrust/device_vector.h>
#include <thrust/iterator/transform_iterator.h>

#include <raft/linalg/distance_type.h>
#include "processing.cuh"

#include <cuml/common/cuml_allocator.hpp>
Expand Down Expand Up @@ -186,14 +187,37 @@ inline void knn_merge_parts(value_t *inK, value_idx *inV, value_t *outK,
inK, inV, outK, outV, n_samples, n_parts, k, stream, translations);
}

inline faiss::MetricType build_faiss_metric(ML::MetricType metric) {
inline faiss::MetricType build_faiss_metric(
raft::distance::DistanceType metric) {
switch (metric) {
case ML::MetricType::METRIC_Cosine:
case raft::distance::DistanceType::CosineExpanded:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case ML::MetricType::METRIC_Correlation:
case raft::distance::DistanceType::CorrelationExpanded:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::L2Expanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2Unexpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2SqrtExpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L2SqrtUnexpanded:
return faiss::MetricType::METRIC_L2;
case raft::distance::DistanceType::L1:
return faiss::MetricType::METRIC_L1;
case raft::distance::DistanceType::InnerProduct:
return faiss::MetricType::METRIC_INNER_PRODUCT;
case raft::distance::DistanceType::LpUnexpanded:
return faiss::MetricType::METRIC_Lp;
case raft::distance::DistanceType::Linf:
return faiss::MetricType::METRIC_Linf;
case raft::distance::DistanceType::Canberra:
return faiss::MetricType::METRIC_Canberra;
case raft::distance::DistanceType::BrayCurtis:
return faiss::MetricType::METRIC_BrayCurtis;
case raft::distance::DistanceType::JensenShannon:
return faiss::MetricType::METRIC_JensenShannon;
default:
return (faiss::MetricType)metric;
THROW("MetricType not supported: %d", metric);
}
}

Expand All @@ -219,7 +243,8 @@ inline faiss::ScalarQuantizer::QuantizerType build_faiss_qtype(

template <typename IntType = int>
void approx_knn_ivfflat_build_index(ML::knnIndex *index, ML::IVFParam *params,
IntType D, ML::MetricType metric,
IntType D,
raft::distance::DistanceType metric,
IntType n) {
faiss::gpu::GpuIndexIVFFlatConfig config;
config.device = index->device;
Expand All @@ -232,7 +257,9 @@ void approx_knn_ivfflat_build_index(ML::knnIndex *index, ML::IVFParam *params,

template <typename IntType = int>
void approx_knn_ivfpq_build_index(ML::knnIndex *index, ML::IVFPQParam *params,
IntType D, ML::MetricType metric, IntType n) {
IntType D,
raft::distance::DistanceType metric,
IntType n) {
faiss::gpu::GpuIndexIVFPQConfig config;
config.device = index->device;
config.usePrecomputedTables = params->usePrecomputedTables;
Expand All @@ -246,7 +273,9 @@ void approx_knn_ivfpq_build_index(ML::knnIndex *index, ML::IVFPQParam *params,

template <typename IntType = int>
void approx_knn_ivfsq_build_index(ML::knnIndex *index, ML::IVFSQParam *params,
IntType D, ML::MetricType metric, IntType n) {
IntType D,
raft::distance::DistanceType metric,
IntType n) {
faiss::gpu::GpuIndexIVFScalarQuantizerConfig config;
config.device = index->device;
faiss::MetricType faiss_metric = build_faiss_metric(metric);
Expand All @@ -262,8 +291,8 @@ void approx_knn_ivfsq_build_index(ML::knnIndex *index, ML::IVFSQParam *params,

template <typename IntType = int>
void approx_knn_build_index(ML::knnIndex *index, ML::knnIndexParam *params,
IntType D, ML::MetricType metric, float metricArg,
float *index_items, IntType n,
IntType D, raft::distance::DistanceType metric,
float metricArg, float *index_items, IntType n,
cudaStream_t userStream) {
int device;
CUDA_CHECK(cudaGetDevice(&device));
Expand Down Expand Up @@ -328,9 +357,8 @@ void approx_knn_search(ML::knnIndex *index, IntType n, const float *x,
* @param[in] rowMajorQuery are the query array in row-major layout?
* @param[in] translations translation ids for indices when index rows represent
* non-contiguous partitions
* @param[in] metric corresponds to the FAISS::metricType enum (default is euclidean)
* @param[in] metric corresponds to the raft::distance::DistanceType enum (default is L2Expanded)
* @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm
* @param[in] expanded_form whether or not lp variants should be reduced w/ lp-root
*/
template <typename IntType = int>
void brute_force_knn(std::vector<float *> &input, std::vector<int> &sizes,
Expand All @@ -342,8 +370,9 @@ void brute_force_knn(std::vector<float *> &input, std::vector<int> &sizes,
int n_int_streams = 0, bool rowMajorIndex = true,
bool rowMajorQuery = true,
std::vector<int64_t> *translations = nullptr,
ML::MetricType metric = ML::MetricType::METRIC_L2,
float metricArg = 0, bool expanded_form = false) {
raft::distance::DistanceType metric =
raft::distance::DistanceType::L2Expanded,
float metricArg = 0) {
ASSERT(input.size() == sizes.size(),
"input and sizes vectors should be the same size");

Expand Down Expand Up @@ -452,9 +481,9 @@ void brute_force_knn(std::vector<float *> &input, std::vector<int> &sizes,
}

// Perform necessary post-processing
if ((m == faiss::MetricType::METRIC_L2 ||
m == faiss::MetricType::METRIC_Lp) &&
!expanded_form) {
if (metric == raft::distance::DistanceType::L2SqrtExpanded ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
metric == raft::distance::DistanceType::LpUnexpanded) {
/**
* post-processing
*/
Expand Down
6 changes: 3 additions & 3 deletions cpp/src_prims/selection/processing.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -162,17 +162,17 @@ class DefaultMetricProcessor : public MetricProcessor<math_t> {

template <typename math_t>
inline std::unique_ptr<MetricProcessor<math_t>> create_processor(
ML::MetricType metric, int n, int D, int k, bool rowMajorQuery,
raft::distance::DistanceType metric, int n, int D, int k, bool rowMajorQuery,
cudaStream_t userStream, std::shared_ptr<deviceAllocator> allocator) {
MetricProcessor<math_t> *mp = nullptr;

switch (metric) {
case ML::MetricType::METRIC_Cosine:
case raft::distance::DistanceType::CosineExpanded:
mp = new CosineMetricProcessor<math_t>(n, D, k, rowMajorQuery, userStream,
allocator);
break;

case ML::MetricType::METRIC_Correlation:
case raft::distance::DistanceType::CorrelationExpanded:
mp = new CorrelationMetricProcessor<math_t>(n, D, k, rowMajorQuery,
userStream, allocator);
break;
Expand Down
23 changes: 23 additions & 0 deletions cpp/src_prims/sparse/distance/distance.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <raft/cudart_utils.h>
#include <unordered_set>

#include <raft/linalg/distance_type.h>
#include <raft/sparse/cusparse_wrappers.h>
Expand All @@ -42,6 +43,20 @@ namespace raft {
namespace sparse {
namespace distance {

/**
 * Set of raft::distance::DistanceType values grouped under the name
 * `supportedDistance`.
 *
 * NOTE(review): presumably the metrics accepted by the sparse distance code
 * in this namespace; the lookup site is not visible in this hunk — confirm
 * against callers.
 *
 * NOTE(review): declared `static const` at namespace scope in a header, so
 * every translation unit that includes this file gets its own copy of the
 * set.
 */
static const std::unordered_set<raft::distance::DistanceType> supportedDistance{
raft::distance::DistanceType::L2Expanded,
raft::distance::DistanceType::L2Unexpanded,
raft::distance::DistanceType::L2SqrtExpanded,
raft::distance::DistanceType::L2SqrtUnexpanded,
raft::distance::DistanceType::InnerProduct,
raft::distance::DistanceType::L1,
raft::distance::DistanceType::Canberra,
raft::distance::DistanceType::Linf,
raft::distance::DistanceType::LpUnexpanded,
raft::distance::DistanceType::JaccardExpanded,
raft::distance::DistanceType::CosineExpanded,
raft::distance::DistanceType::HellingerExpanded};

/**
* Compute pairwise distances between A and B, using the provided
* input configuration and distance function.
Expand All @@ -60,12 +75,20 @@ void pairwiseDistance(value_t *out,
case raft::distance::DistanceType::L2Expanded:
l2_expanded_distances_t<value_idx, value_t>(input_config).compute(out);
break;
case raft::distance::DistanceType::L2SqrtExpanded:
l2_sqrt_expanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::InnerProduct:
ip_distances_t<value_idx, value_t>(input_config).compute(out);
break;
case raft::distance::DistanceType::L2Unexpanded:
l2_unexpanded_distances_t<value_idx, value_t>(input_config).compute(out);
break;
case raft::distance::DistanceType::L2SqrtUnexpanded:
l2_sqrt_unexpanded_distances_t<value_idx, value_t>(input_config)
.compute(out);
break;
case raft::distance::DistanceType::L1:
l1_unexpanded_distances_t<value_idx, value_t>(input_config).compute(out);
break;
Expand Down
Loading

0 comments on commit 9fa6e17

Please sign in to comment.