Replace ML::MetricType with raft::distance::DistanceType #3389
Changes from 31 commits
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -20,6 +20,7 @@ | |
|
||
#include <cusparse_v2.h> | ||
|
||
#include <raft/linalg/distance_type.h> | ||
#include <cuml/neighbors/knn.hpp> | ||
|
||
namespace ML { | ||
|
@@ -36,7 +37,8 @@ void brute_force_knn(raft::handle_t &handle, const int *idx_indptr, | |
float *output_dists, int k, | ||
size_t batch_size_index = DEFAULT_BATCH_SIZE, | ||
size_t batch_size_query = DEFAULT_BATCH_SIZE, | ||
ML::MetricType metric = ML::MetricType::METRIC_L2, | ||
float metricArg = 0, bool expanded_form = false); | ||
raft::distance::DistanceType metric = | ||
raft::distance::DistanceType::L2Unexpanded, | ||
Review comment: Same,
Review comment: Same here, I would leave this in expanded form. |
||
float metricArg = 0); | ||
}; // end namespace Sparse | ||
}; // end namespace ML |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -71,7 +71,8 @@ cumlError_t knn_search(const cumlHandle_t handle, float **input, int *sizes, | |
try { | ||
ML::brute_force_knn(*handle_ptr, input_vec, sizes_vec, D, search_items, n, | ||
res_I, res_D, k, rowMajorIndex, rowMajorQuery, | ||
(ML::MetricType)metric_type, metric_arg, expanded); | ||
(raft::distance::DistanceType)metric_type, | ||
Review comment: Do we need to keep making this conversion explicit?
Review comment: Since we're using the same enum type everywhere now, I think this conversion can be removed.
Review comment: This conversion from int to |
||
metric_arg); | ||
} catch (...) { | ||
status = CUML_ERROR_UNKNOWN; | ||
} | ||
|
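The thread above asks whether the explicit conversion can be dropped. C++ never converts an integer to an enumeration type implicitly, so if `metric_type` arrives through the C API as a plain integer (an assumption here), some conversion has to remain at the boundary. A minimal sketch of a checked conversion follows; `DemoDistanceType` and `to_demo_distance_type` are hypothetical names, not part of cuML or RAFT.

```cpp
// Hypothetical stand-in for the metric enum; names and values are illustrative only.
enum class DemoDistanceType : int { L2Expanded = 0, L2SqrtExpanded = 1, L1 = 2 };

// Converts an integer received through a C API into the enum.
// Returns false (and leaves *out untouched) when the value is not a known metric.
bool to_demo_distance_type(int raw, DemoDistanceType* out) {
  switch (raw) {
    case 0:
      *out = DemoDistanceType::L2Expanded;
      return true;
    case 1:
      *out = DemoDistanceType::L2SqrtExpanded;
      return true;
    case 2:
      *out = DemoDistanceType::L1;
      return true;
    default:
      return false;  // unknown value: let the caller report an error code
  }
}
```

Compared with a bare C-style cast, this rejects out-of-range integers instead of passing them silently to the C++ layer.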
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,5 @@ | ||
/* | ||
* Copyright (c) 2019-2020, NVIDIA CORPORATION. | ||
* Copyright (c) 2019-2021, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
|
@@ -89,7 +89,7 @@ void get_distances(const raft::handle_t &handle, | |
k_graph.knn_indices, k_graph.knn_dists, k_graph.n_neighbors, | ||
handle.get_cusparse_handle(), handle.get_device_allocator(), stream, | ||
ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE, | ||
ML::MetricType::METRIC_L2); | ||
raft::distance::DistanceType::L2Expanded); | ||
Review comment: I think we need
Review comment: This doesn't reach the users so I think it's okay not to use the sqrt here. I think the expanded form is also good to use here for speed. |
||
} | ||
|
||
// sparse, int64 | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,7 @@ | |
#pragma once | ||
|
||
#include <cuml/manifold/umapparams.h> | ||
#include <raft/linalg/distance_type.h> | ||
#include <cuml/manifold/common.hpp> | ||
#include <cuml/neighbors/knn_sparse.hpp> | ||
#include <iostream> | ||
|
@@ -91,7 +92,7 @@ void launcher(const raft::handle_t &handle, | |
inputsB.n, inputsB.d, out.knn_indices, out.knn_dists, n_neighbors, | ||
handle.get_cusparse_handle(), d_alloc, stream, | ||
ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE, | ||
ML::MetricType::METRIC_L2); | ||
raft::distance::DistanceType::L2Expanded); | ||
Review comment: Same,
Review comment: I would leave all of these in expanded form. The unexpanded form is more stable under some conditions, but in general the expanded form is a better (and faster) starting point. |
||
} | ||
|
||
template <> | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -37,6 +37,7 @@ | |
#include <thrust/device_vector.h> | ||
#include <thrust/iterator/transform_iterator.h> | ||
|
||
#include <raft/linalg/distance_type.h> | ||
#include "processing.cuh" | ||
|
||
#include <cuml/common/cuml_allocator.hpp> | ||
|
@@ -186,14 +187,37 @@ inline void knn_merge_parts(value_t *inK, value_idx *inV, value_t *outK, | |
inK, inV, outK, outV, n_samples, n_parts, k, stream, translations); | ||
} | ||
|
||
inline faiss::MetricType build_faiss_metric(ML::MetricType metric) { | ||
inline faiss::MetricType build_faiss_metric( | ||
raft::distance::DistanceType metric) { | ||
switch (metric) { | ||
case ML::MetricType::METRIC_Cosine: | ||
case raft::distance::DistanceType::CosineExpanded: | ||
return faiss::MetricType::METRIC_INNER_PRODUCT; | ||
case ML::MetricType::METRIC_Correlation: | ||
case raft::distance::DistanceType::CorrelationExpanded: | ||
return faiss::MetricType::METRIC_INNER_PRODUCT; | ||
case raft::distance::DistanceType::L2Expanded: | ||
return faiss::MetricType::METRIC_L2; | ||
case raft::distance::DistanceType::L2Unexpanded: | ||
return faiss::MetricType::METRIC_L2; | ||
case raft::distance::DistanceType::L2SqrtExpanded: | ||
return faiss::MetricType::METRIC_L2; | ||
case raft::distance::DistanceType::L2SqrtUnexpanded: | ||
return faiss::MetricType::METRIC_L2; | ||
case raft::distance::DistanceType::L1: | ||
return faiss::MetricType::METRIC_L1; | ||
case raft::distance::DistanceType::InnerProduct: | ||
return faiss::MetricType::METRIC_INNER_PRODUCT; | ||
case raft::distance::DistanceType::LpUnexpanded: | ||
return faiss::MetricType::METRIC_Lp; | ||
case raft::distance::DistanceType::Linf: | ||
return faiss::MetricType::METRIC_Linf; | ||
case raft::distance::DistanceType::Canberra: | ||
return faiss::MetricType::METRIC_Canberra; | ||
case raft::distance::DistanceType::BrayCurtis: | ||
return faiss::MetricType::METRIC_BrayCurtis; | ||
case raft::distance::DistanceType::JensenShannon: | ||
return faiss::MetricType::METRIC_JensenShannon; | ||
default: | ||
return (faiss::MetricType)metric; | ||
THROW("MetricType not supported: %d", metric); | ||
} | ||
Comment on lines +190 to 221
Review comment: Thank you for this. It's a lot better and avoids future issues (in case enums are modified). However, for now, no post-processing is applied to distances produced by approximate nearest neighbors. This means that only the raw FAISS metrics should be available. It is actually an issue that I should probably work on next. |
||
} | ||
|
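The change above replaces a numeric cast between two enums with an explicit `switch`. The sketch below illustrates why that is safer when the enums do not share numeric values, and how several source metrics can collapse onto one backend metric. `SourceMetric`, `BackendMetric`, and `to_backend_metric` are hypothetical names used only for illustration; they are not the RAFT or FAISS definitions.

```cpp
#include <stdexcept>

// Hypothetical enums: the numeric values intentionally do not line up,
// so casting one to the other would select the wrong metric.
enum class SourceMetric : int { L2Expanded = 0, L2SqrtExpanded = 1, L1 = 2, InnerProduct = 3 };
enum class BackendMetric : int { InnerProduct = 0, L2 = 1, L1 = 2 };

// Maps each supported source metric to a backend metric by name.
// Throws for any value that has no explicit mapping.
BackendMetric to_backend_metric(SourceMetric metric) {
  switch (metric) {
    case SourceMetric::L2Expanded:
    case SourceMetric::L2SqrtExpanded:
      // Both L2 variants use the same backend metric; whether a square
      // root is applied afterwards is a separate post-processing decision.
      return BackendMetric::L2;
    case SourceMetric::L1:
      return BackendMetric::L1;
    case SourceMetric::InnerProduct:
      return BackendMetric::InnerProduct;
  }
  throw std::invalid_argument("metric not supported");
}
```

Because the mapping is many-to-one, the backend metric alone no longer says whether a root should be taken, which is the concern raised in the review comment about approximate nearest neighbors.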
||
|
@@ -219,7 +243,8 @@ inline faiss::ScalarQuantizer::QuantizerType build_faiss_qtype( | |
|
||
template <typename IntType = int> | ||
void approx_knn_ivfflat_build_index(ML::knnIndex *index, ML::IVFParam *params, | ||
IntType D, ML::MetricType metric, | ||
IntType D, | ||
raft::distance::DistanceType metric, | ||
IntType n) { | ||
faiss::gpu::GpuIndexIVFFlatConfig config; | ||
config.device = index->device; | ||
|
@@ -232,7 +257,9 @@ void approx_knn_ivfflat_build_index(ML::knnIndex *index, ML::IVFParam *params, | |
|
||
template <typename IntType = int> | ||
void approx_knn_ivfpq_build_index(ML::knnIndex *index, ML::IVFPQParam *params, | ||
IntType D, ML::MetricType metric, IntType n) { | ||
IntType D, | ||
raft::distance::DistanceType metric, | ||
IntType n) { | ||
faiss::gpu::GpuIndexIVFPQConfig config; | ||
config.device = index->device; | ||
config.usePrecomputedTables = params->usePrecomputedTables; | ||
|
@@ -246,7 +273,9 @@ void approx_knn_ivfpq_build_index(ML::knnIndex *index, ML::IVFPQParam *params, | |
|
||
template <typename IntType = int> | ||
void approx_knn_ivfsq_build_index(ML::knnIndex *index, ML::IVFSQParam *params, | ||
IntType D, ML::MetricType metric, IntType n) { | ||
IntType D, | ||
raft::distance::DistanceType metric, | ||
IntType n) { | ||
faiss::gpu::GpuIndexIVFScalarQuantizerConfig config; | ||
config.device = index->device; | ||
faiss::MetricType faiss_metric = build_faiss_metric(metric); | ||
|
@@ -262,8 +291,8 @@ void approx_knn_ivfsq_build_index(ML::knnIndex *index, ML::IVFSQParam *params, | |
|
||
template <typename IntType = int> | ||
void approx_knn_build_index(ML::knnIndex *index, ML::knnIndexParam *params, | ||
IntType D, ML::MetricType metric, float metricArg, | ||
float *index_items, IntType n, | ||
IntType D, raft::distance::DistanceType metric, | ||
float metricArg, float *index_items, IntType n, | ||
cudaStream_t userStream) { | ||
int device; | ||
CUDA_CHECK(cudaGetDevice(&device)); | ||
|
@@ -330,7 +359,6 @@ void approx_knn_search(ML::knnIndex *index, IntType n, const float *x, | |
* non-contiguous partitions | ||
* @param[in] metric corresponds to the FAISS::metricType enum (default is euclidean) | ||
* @param[in] metricArg metric argument to use. Corresponds to the p arg for lp norm | ||
* @param[in] expanded_form whether or not lp variants should be reduced w/ lp-root | ||
*/ | ||
template <typename IntType = int> | ||
void brute_force_knn(std::vector<float *> &input, std::vector<int> &sizes, | ||
|
@@ -342,8 +370,9 @@ void brute_force_knn(std::vector<float *> &input, std::vector<int> &sizes, | |
int n_int_streams = 0, bool rowMajorIndex = true, | ||
bool rowMajorQuery = true, | ||
std::vector<int64_t> *translations = nullptr, | ||
ML::MetricType metric = ML::MetricType::METRIC_L2, | ||
float metricArg = 0, bool expanded_form = false) { | ||
raft::distance::DistanceType metric = | ||
raft::distance::DistanceType::L2Expanded, | ||
Review comment: Same, |
||
float metricArg = 0) { | ||
ASSERT(input.size() == sizes.size(), | ||
"input and sizes vectors should be the same size"); | ||
|
||
|
@@ -452,9 +481,9 @@ void brute_force_knn(std::vector<float *> &input, std::vector<int> &sizes, | |
} | ||
|
||
// Perform necessary post-processing | ||
if ((m == faiss::MetricType::METRIC_L2 || | ||
m == faiss::MetricType::METRIC_Lp) && | ||
!expanded_form) { | ||
if (metric == raft::distance::DistanceType::L2SqrtExpanded || | ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded || | ||
metric == raft::distance::DistanceType::LpUnexpanded) { | ||
Comment on lines +484 to +486
Review comment: Might be wrong, but I think that in this case, only unexpanded forms (that need post-processing) :
Review comment: FAISS only supports the expanded form but I believe we're converting both the |
||
/** | ||
* post-processing | ||
*/ | ||
|
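The post-processing body is not shown in this hunk. As a rough sketch of what such a step could look like, the code below assumes the backend returns un-rooted values (squared L2, or the sum of |x - y|^p for Lp) and applies the root only for the metrics that call for it. `DemoMetric` and `postprocess_distances` are hypothetical names, not the code in this PR.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Hypothetical metric enum for illustration only.
enum class DemoMetric { L2Expanded, L2SqrtExpanded, L2SqrtUnexpanded, LpUnexpanded };

// Assumes `dists` holds un-rooted values: squared L2 for the L2 variants,
// or the sum of |x_i - y_i|^p for the Lp variant. `p` is only used for Lp.
void postprocess_distances(std::vector<float>& dists, DemoMetric metric, float p) {
  switch (metric) {
    case DemoMetric::L2SqrtExpanded:
    case DemoMetric::L2SqrtUnexpanded:
      for (float& d : dists) d = std::sqrt(d);  // squared L2 -> Euclidean
      break;
    case DemoMetric::LpUnexpanded:
      assert(p > 0.0f);
      for (float& d : dists) d = std::pow(d, 1.0f / p);  // apply the p-th root
      break;
    default:
      break;  // non-Sqrt L2 is returned as the squared distance
  }
}
```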
Review comment: I think we may want L2SqrtUnexpanded here to have the Euclidean distance as the default, at least if the result is seen by the end user and not only used internally. Normally, METRIC_L2 in FAISS provides the squared Euclidean distance, that is, the value before the square root is taken. Post-processing should then apply the square root. @cjnolet probably knows better about this though.
Review comment: I would leave these in expanded form, actually. It's the most used metric and the difference in performance is pretty huge.
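For readers unfamiliar with the terminology in this thread, here is a minimal sketch of the two ways to compute a squared L2 distance. The unexpanded form sums the squared differences directly; the expanded form uses the identity ||x - y||^2 = ||x||^2 + ||y||^2 - 2<x, y>. The function names are hypothetical and the code is plain CPU C++, not the RAFT kernels.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Unexpanded form: sum of (x_i - y_i)^2.
double l2_unexpanded(const std::vector<double>& x, const std::vector<double>& y) {
  assert(x.size() == y.size());
  double acc = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double d = x[i] - y[i];
    acc += d * d;
  }
  return acc;
}

// Expanded form: ||x||^2 + ||y||^2 - 2 * <x, y>.
double l2_expanded(const std::vector<double>& x, const std::vector<double>& y) {
  assert(x.size() == y.size());
  double xx = 0.0, yy = 0.0, xy = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    xx += x[i] * x[i];
    yy += y[i] * y[i];
    xy += x[i] * y[i];
  }
  return xx + yy - 2.0 * xy;
}

// The "Sqrt" variants take the square root of either squared form.
double l2_sqrt(double squared) { return std::sqrt(squared); }
```

For x = (1, 2, 3) and y = (4, 6, 3), both forms give 25 and the square root gives 5. In the expanded form the dot-product term can be computed for many pairs at once as a matrix product, which is the usual reason it is faster; the subtraction can lose precision when the two points are very close, which is one condition under which the unexpanded form is more stable.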