Skip to content

Commit

Permalink
Using expanded distance computations in pylibraft (rapidsai#1759)
Browse files Browse the repository at this point in the history
We are noticing a perf hit of nearly 2x with the unexpanded distance computations vs the expanded distances.

Authors:
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Ben Frederickson (https://github.com/benfred)

URL: rapidsai#1759
  • Loading branch information
cjnolet authored Aug 23, 2023
1 parent 9019054 commit d5bd840
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
5 changes: 3 additions & 2 deletions cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <raft/core/math.hpp>
#include <raft/util/cuda_dev_essentials.cuh> // DI

namespace raft::distance::detail::ops {
Expand All @@ -33,7 +34,7 @@ struct l2_exp_cutlass_op {
// outVal could be negative due to numerical instability, especially when
// calculating self distance.
// clamp to 0 to avoid potential NaN in sqrt
outVal = outVal * (outVal > DataT(0.0));
outVal = outVal * (raft::abs(outVal) >= DataT(0.0001));
return sqrt ? raft::sqrt(outVal) : outVal;
}

Expand Down Expand Up @@ -88,7 +89,7 @@ struct l2_exp_distance_op {
DataT val = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j];
// val could be negative due to numerical instability, especially when
// calculating self distance. Clamp to 0 to avoid potential NaN in sqrt
acc[i][j] = val * (val > DataT(0.0));
acc[i][j] = val * (raft::abs(val) >= DataT(0.0001));
}
}
if (sqrt) {
Expand Down
6 changes: 3 additions & 3 deletions python/pylibraft/pylibraft/distance/pairwise_distance.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,9 @@ cdef extern from "raft_runtime/distance/pairwise_distance.hpp" \
float metric_arg) except +

DISTANCE_TYPES = {
"l2": DistanceType.L2SqrtUnexpanded,
"sqeuclidean": DistanceType.L2Unexpanded,
"euclidean": DistanceType.L2SqrtUnexpanded,
"l2": DistanceType.L2SqrtExpanded,
"sqeuclidean": DistanceType.L2Expanded,
"euclidean": DistanceType.L2SqrtExpanded,
"l1": DistanceType.L1,
"cityblock": DistanceType.L1,
"inner_product": DistanceType.InnerProduct,
Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/test/test_brute_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def test_knn(n_index_rows, n_query_rows, n_cols, k, inplace, metric, dtype):

cpu_ordered = pw_dists[i, expected_indices]
np.testing.assert_allclose(
cpu_ordered[:k], gpu_dists, atol=1e-4, rtol=1e-4
cpu_ordered[:k], gpu_dists, atol=1e-3, rtol=1e-3
)


Expand Down
2 changes: 1 addition & 1 deletion python/pylibraft/pylibraft/test/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ def test_distance(n_rows, n_cols, inplace, metric, order, dtype):

actual[actual <= 1e-5] = 0.0

assert np.allclose(expected, actual, rtol=1e-4)
assert np.allclose(expected, actual, atol=1e-3, rtol=1e-3)

0 comments on commit d5bd840

Please sign in to comment.