Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some fixes to pairwise distances for cupy integration #643

Merged
merged 11 commits into from
May 11, 2022
32 changes: 24 additions & 8 deletions python/pylibraft/pylibraft/distance/pairwise_distance.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -54,12 +54,14 @@ cdef extern from "raft_distance/pairwise_distance.hpp" \

DISTANCE_TYPES = {
"l2": DistanceType.L2SqrtUnexpanded,
"sqeuclidean": DistanceType.L2Unexpanded,
"euclidean": DistanceType.L2SqrtUnexpanded,
"l1": DistanceType.L1,
"cityblock": DistanceType.L1,
"inner_product": DistanceType.InnerProduct,
"chebyshev": DistanceType.Linf,
"canberra": DistanceType.Canberra,
"cosine": DistanceType.CosineExpanded,
"lp": DistanceType.LpUnexpanded,
"correlation": DistanceType.CorrelationExpanded,
"jaccard": DistanceType.JaccardExpanded,
Expand All @@ -68,21 +70,26 @@ DISTANCE_TYPES = {
"jensenshannon": DistanceType.JensenShannon,
"hamming": DistanceType.HammingUnexpanded,
"kl_divergence": DistanceType.KLDivergence,
"minkowski": DistanceType.LpUnexpanded,
"russellrao": DistanceType.RusselRaoExpanded,
"dice": DistanceType.DiceExpanded
}

SUPPORTED_DISTANCES = list(DISTANCE_TYPES.keys())
SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2", "inner_product",
"chebyshev", "minkowski", "canberra", "kl_divergence",
"correlation", "russellrao", "hellinger", "lp",
"hamming", "jensenshannon", "cosine", "sqeuclidean"]
cjnolet marked this conversation as resolved.
Show resolved Hide resolved


def distance(X, Y, dists, metric="euclidean"):
def distance(X, Y, dists, metric="euclidean", p=2.0):
cjnolet marked this conversation as resolved.
Show resolved Hide resolved
"""
Compute pairwise distances between X and Y

Valid values for metric:
["euclidean", "l2", "l1", "cityblock", "inner_product",
"chebyshev", "canberra", "lp", "hellinger", "jensenshannon",
"kl_divergence", "russellrao"]
"kl_divergence", "russellrao", "minkowski", "correlation",
"cosine"]

Parameters
----------
Expand Down Expand Up @@ -113,13 +120,16 @@ def distance(X, Y, dists, metric="euclidean"):
pairwise_distance(in1, in2, output, metric="euclidean")
"""

# TODO: Validate inputs, shapes, etc...
x_cai = X.__cuda_array_interface__
y_cai = Y.__cuda_array_interface__
dists_cai = dists.__cuda_array_interface__

m = x_cai["shape"][0]
n = y_cai["shape"][0]

if x_cai["shape"][1] != y_cai["shape"][1]:
raise ValueError("Inputs must have same number of columns.")
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

k = x_cai["shape"][1]

x_ptr = <uintptr_t>x_cai["data"][0]
Expand All @@ -132,6 +142,12 @@ def distance(X, Y, dists, metric="euclidean"):
y_dt = np.dtype(y_cai["typestr"])
d_dt = np.dtype(dists_cai["typestr"])

x_c_contiguous = "strides" not in x_cai or x_cai["strides"] is None
y_c_contiguous = "strides" not in y_cai or y_cai["strides"] is None

if x_c_contiguous != y_c_contiguous:
raise ValueError("Inputs must have matching strides")
cjnolet marked this conversation as resolved.
Show resolved Hide resolved

if metric not in SUPPORTED_DISTANCES:
raise ValueError("metric %s is not supported" % metric)

Expand All @@ -149,8 +165,8 @@ def distance(X, Y, dists, metric="euclidean"):
<int>n,
<int>k,
<DistanceType>distance_type,
<bool>True,
<float>0.0)
<bool>x_c_contiguous,
<float>p)
elif x_dt == np.float64:
pairwise_distance(deref(h),
<double*> x_ptr,
Expand All @@ -160,7 +176,7 @@ def distance(X, Y, dists, metric="euclidean"):
<int>n,
<int>k,
<DistanceType>distance_type,
<bool>True,
<float>0.0)
<bool>x_c_contiguous,
<float>p)
else:
raise ValueError("dtype %s not supported" % x_dt)
3 changes: 2 additions & 1 deletion python/pylibraft/pylibraft/test/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ def copy_to_host(self):
@pytest.mark.parametrize("n_cols", [100])
@pytest.mark.parametrize("metric", ["euclidean", "cityblock", "chebyshev",
"canberra", "correlation", "hamming",
"jensenshannon", "russellrao"])
"jensenshannon", "russellrao", "cosine",
"sqeuclidean"])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_distance(n_rows, n_cols, metric, dtype):
input1 = np.random.random_sample((n_rows, n_cols)).astype(dtype)
Expand Down