-
Notifications
You must be signed in to change notification settings - Fork 540
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Replace ML::MetricType with raft::distance::DistanceType #3389
Replace ML::MetricType with raft::distance::DistanceType #3389
Conversation
6529592
to
c96a8c4
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, just have a few suggestions.
cpp/include/cuml/neighbors/knn.hpp
Outdated
MetricType metric = MetricType::METRIC_L2, | ||
float metric_arg = 2.0f, bool expanded = false); | ||
raft::distance::DistanceType metric = | ||
raft::distance::DistanceType::L2Unexpanded, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we may want L2SqrtUnexpanded
here to have the euclidean distance as default. At least if the results is seen by the end-user and not only used internally. Normally, METRIC_L2
in FAISS provides the euclidean distance before root-squaring. Then post-processing should apply the root-square. @cjnolet probably knows better about this though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would leave these in expanded form, actually. It's the most used metric and the difference in performance is pretty huge
ML::MetricType metric = ML::MetricType::METRIC_L2, | ||
float metricArg = 0, bool expanded_form = false); | ||
raft::distance::DistanceType metric = | ||
raft::distance::DistanceType::L2Unexpanded, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same, L2SqrtUnexpanded
might be needed here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same here, I would leave this in expanded form.
cpp/src/knn/knn_api.cpp
Outdated
@@ -71,7 +71,8 @@ cumlError_t knn_search(const cumlHandle_t handle, float **input, int *sizes, | |||
try { | |||
ML::brute_force_knn(*handle_ptr, input_vec, sizes_vec, D, search_items, n, | |||
res_I, res_D, k, rowMajorIndex, rowMajorQuery, | |||
(ML::MetricType)metric_type, metric_arg, expanded); | |||
(raft::distance::DistanceType)metric_type, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Do we need to keep making this conversion explicit?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since we're using the same enum type everywhere now, I think this conversion can be removed.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This conversion from int to DistanceType
is needed to not modify the C API of knn_search
. I changed it to a static_cast
.
@@ -89,7 +89,7 @@ void get_distances(const raft::handle_t &handle, | |||
k_graph.knn_indices, k_graph.knn_dists, k_graph.n_neighbors, | |||
handle.get_cusparse_handle(), handle.get_device_allocator(), stream, | |||
ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE, | |||
ML::MetricType::METRIC_L2); | |||
raft::distance::DistanceType::L2Expanded); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we need L2SqrtUnexpanded
here, unless this distance value doesn't reach user's eye (only used by TSNE internally).
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This doesn't reach the users so I think it's okay not to use the sqrt here. I think the expanded form is also good to use here for speed.
@@ -91,7 +92,7 @@ void launcher(const raft::handle_t &handle, | |||
inputsB.n, inputsB.d, out.knn_indices, out.knn_dists, n_neighbors, | |||
handle.get_cusparse_handle(), d_alloc, stream, | |||
ML::Sparse::DEFAULT_BATCH_SIZE, ML::Sparse::DEFAULT_BATCH_SIZE, | |||
ML::MetricType::METRIC_L2); | |||
raft::distance::DistanceType::L2Expanded); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same, L2SqrtUnexpanded
might be needed here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would leave all of these in expanded form. The unexpanded is more stable under some conditions but in general it's a better (and faster) starting point.
if (metric == raft::distance::DistanceType::L2SqrtExpanded || | ||
metric == raft::distance::DistanceType::L2SqrtUnexpanded || | ||
metric == raft::distance::DistanceType::LpUnexpanded) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Might be wrong, but I think that in this case, only unexpanded forms (that need post-procesing) : L2SqrtUnexpanded
and LpUnexpanded
should have post-processing. @cjnolet probably knows more about this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
FAISS only supports the expanded form but I believe we're converting both the Expanded
and Unexpanded
L2 forms into faiss::METRIC_L2
so we'll need to sqrt both of them.
ML::MetricType metric_ = ML::MetricType::METRIC_L2, | ||
float metricArg_ = 0, bool expanded_form_ = false) | ||
raft::distance::DistanceType metric_ = | ||
raft::distance::DistanceType::L2Unexpanded, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same, L2SqrtUnexpanded
might be needed here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I would revert this to expanded form as well.
ML::MetricType metric = ML::MetricType::METRIC_L2, | ||
float metricArg = 0, bool expanded_form = false) { | ||
raft::distance::DistanceType metric = | ||
raft::distance::DistanceType::L2Unexpanded, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same, L2SqrtUnexpanded
might be needed here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
And here.
if metric == "euclidean" or metric == "l2": | ||
m = MetricType.METRIC_L2 | ||
m = DistanceType.L2SqrtExpanded |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same, L2SqrtUnexpanded
might be needed here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The Expanded version is preferred for speed
elif metric == "cityblock" or metric == "l1"\ | ||
or metric == "manhattan" or metric == 'taxicab': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can maybe be replaced by elif metric in [..., ...]:
.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
@@ -57,6 +57,8 @@ cumlError_t knn_search(const cumlHandle_t handle, float **input, int *sizes, | |||
cumlError_t status; | |||
raft::handle_t *handle_ptr; | |||
std::tie(handle_ptr, status) = ML::handleMap.lookupHandlePointer(handle); | |||
raft::distance::DistanceType metric_distance_type = |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like this. The intent becomes more clear.
@@ -434,8 +378,12 @@ class sparse_knn_t { | |||
dist_config.allocator = allocator; | |||
dist_config.stream = stream; | |||
|
|||
raft::sparse::distance::pairwiseDistance(batch_dists, dist_config, | |||
get_pw_metric(), metricArg); | |||
if (raft::sparse::distance::supportedDistance.find(metric) == |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I like the introduction of the explicit set for supported distances. For a follow-on / future PR, we might want to consider just using a hash map to map the distance enum type to its “distances_t” implementation, which will allow us to get rid of the switch statement altogether.
@gpucibot merge |
Adding a reminder / note here that the dense & sparse knn primitives will need to be updated in RAFT before #3476 goes in. |
rerun tests |
@lowener, according to the logs, it looks like there was a gtest failure in
|
@gpucibot merge |
rerun tests |
Codecov Report
@@ Coverage Diff @@
## branch-0.19 #3389 +/- ##
===============================================
+ Coverage 71.77% 80.80% +9.02%
===============================================
Files 212 227 +15
Lines 17075 17735 +660
===============================================
+ Hits 12256 14331 +2075
+ Misses 4819 3404 -1415
Flags with carried forward coverage won't be shown. Click here to find out more.
Continue to review full report at Codecov.
|
rerun tests |
Closes #3319.
This PR will replace the distance type from ML::MetricType to raft::distance::DistanceType.
Since Raft DistanceType makes the distinction between the expanded and non-expanded distances in the name, I changed the C++ API to remove the boolean parameter
expanded
which becomes useless.