From d5bd840ff327b27f09febee96b827efea5de3ccc Mon Sep 17 00:00:00 2001 From: "Corey J. Nolet" Date: Wed, 23 Aug 2023 17:50:59 -0400 Subject: [PATCH] Using expanded distance computations in `pylibraft` (#1759) We are noticing a perf hit of nearly 2x with the unexpanded distance computations vs the expanded distances. Authors: - Corey J. Nolet (https://github.com/cjnolet) Approvers: - Ben Frederickson (https://github.com/benfred) URL: https://github.com/rapidsai/raft/pull/1759 --- cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh | 5 +++-- python/pylibraft/pylibraft/distance/pairwise_distance.pyx | 6 +++--- python/pylibraft/pylibraft/test/test_brute_force.py | 2 +- python/pylibraft/pylibraft/test/test_distance.py | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh index 95577fd311..5e93d9e33b 100644 --- a/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh +++ b/cpp/include/raft/distance/detail/distance_ops/l2_exp.cuh @@ -16,6 +16,7 @@ #pragma once +#include #include // DI namespace raft::distance::detail::ops { @@ -33,7 +34,7 @@ struct l2_exp_cutlass_op { // outVal could be negative due to numerical instability, especially when // calculating self distance. // clamp to 0 to avoid potential NaN in sqrt - outVal = outVal * (outVal > DataT(0.0)); + outVal = outVal * (raft::abs(outVal) >= DataT(0.0001)); return sqrt ? raft::sqrt(outVal) : outVal; } @@ -88,7 +89,7 @@ struct l2_exp_distance_op { DataT val = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j]; // val could be negative due to numerical instability, especially when // calculating self distance. 
Clamp to 0 to avoid potential NaN in sqrt - acc[i][j] = val * (val > DataT(0.0)); + acc[i][j] = val * (raft::abs(val) >= DataT(0.0001)); } } if (sqrt) { diff --git a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx index 3037b9a725..20dadf0275 100644 --- a/python/pylibraft/pylibraft/distance/pairwise_distance.pyx +++ b/python/pylibraft/pylibraft/distance/pairwise_distance.pyx @@ -60,9 +60,9 @@ cdef extern from "raft_runtime/distance/pairwise_distance.hpp" \ float metric_arg) except + DISTANCE_TYPES = { - "l2": DistanceType.L2SqrtUnexpanded, - "sqeuclidean": DistanceType.L2Unexpanded, - "euclidean": DistanceType.L2SqrtUnexpanded, + "l2": DistanceType.L2SqrtExpanded, + "sqeuclidean": DistanceType.L2Expanded, + "euclidean": DistanceType.L2SqrtExpanded, "l1": DistanceType.L1, "cityblock": DistanceType.L1, "inner_product": DistanceType.InnerProduct, diff --git a/python/pylibraft/pylibraft/test/test_brute_force.py b/python/pylibraft/pylibraft/test/test_brute_force.py index 2e118d210d..42095c3b9f 100644 --- a/python/pylibraft/pylibraft/test/test_brute_force.py +++ b/python/pylibraft/pylibraft/test/test_brute_force.py @@ -89,7 +89,7 @@ def test_knn(n_index_rows, n_query_rows, n_cols, k, inplace, metric, dtype): cpu_ordered = pw_dists[i, expected_indices] np.testing.assert_allclose( - cpu_ordered[:k], gpu_dists, atol=1e-4, rtol=1e-4 + cpu_ordered[:k], gpu_dists, atol=1e-3, rtol=1e-3 ) diff --git a/python/pylibraft/pylibraft/test/test_distance.py b/python/pylibraft/pylibraft/test/test_distance.py index 2c0a842fe5..f9d3890ff7 100644 --- a/python/pylibraft/pylibraft/test/test_distance.py +++ b/python/pylibraft/pylibraft/test/test_distance.py @@ -81,4 +81,4 @@ def test_distance(n_rows, n_cols, inplace, metric, order, dtype): actual[actual <= 1e-5] = 0.0 - assert np.allclose(expected, actual, rtol=1e-4) + assert np.allclose(expected, actual, atol=1e-3, rtol=1e-3)