Skip to content

Commit

Permalink
Some fixes to pairwise distances for cupy integration (#643)
Browse files Browse the repository at this point in the history
Authors:
  - Corey J. Nolet (https://github.com/cjnolet)
  - Vinay Deshpande (https://github.com/vinaydes)

Approvers:
  - Divye Gala (https://github.com/divyegala)
  - Dante Gama Dessavre (https://github.com/dantegd)

URL: #643
  • Loading branch information
cjnolet authored May 11, 2022
1 parent 270bf95 commit d151ed8
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 18 deletions.
48 changes: 37 additions & 11 deletions python/pylibraft/pylibraft/distance/pairwise_distance.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,13 @@ from libcpp cimport bool
from .distance_type cimport DistanceType
from pylibraft.common.handle cimport handle_t


def is_c_cont(cai, dt):
    """
    Report whether a ``__cuda_array_interface__`` dict describes a
    row-major (C-ordered) array.

    Parameters
    ----------
    cai : dict
        CUDA array interface dictionary. Only its optional "strides"
        entry is inspected.
    dt : object exposing ``itemsize`` (e.g. a ``numpy.dtype``)
        Element type of the array; ``dt.itemsize`` is the stride expected
        along the innermost axis of a row-major array.

    Returns
    -------
    bool
        True when "strides" is absent, None or empty, or when the stride
        of the last axis equals ``dt.itemsize``; False otherwise.
    """
    strides = cai.get("strides")

    # A missing/None "strides" entry is treated as C-ordered; an empty
    # tuple (0-d array) has no axis to inspect and is treated the same way.
    if not strides:
        return True

    # Inspect the innermost axis via [-1] instead of a hard-coded [1] so
    # that 1-D (and higher-rank) arrays no longer raise IndexError. For the
    # 2-D case this is the same element the original code compared.
    return bool(strides[-1] == dt.itemsize)


cdef extern from "raft_distance/pairwise_distance.hpp" \
namespace "raft::distance::runtime":

Expand Down Expand Up @@ -54,12 +61,14 @@ cdef extern from "raft_distance/pairwise_distance.hpp" \

DISTANCE_TYPES = {
"l2": DistanceType.L2SqrtUnexpanded,
"sqeuclidean": DistanceType.L2Unexpanded,
"euclidean": DistanceType.L2SqrtUnexpanded,
"l1": DistanceType.L1,
"cityblock": DistanceType.L1,
"inner_product": DistanceType.InnerProduct,
"chebyshev": DistanceType.Linf,
"canberra": DistanceType.Canberra,
"cosine": DistanceType.CosineExpanded,
"lp": DistanceType.LpUnexpanded,
"correlation": DistanceType.CorrelationExpanded,
"jaccard": DistanceType.JaccardExpanded,
Expand All @@ -68,21 +77,26 @@ DISTANCE_TYPES = {
"jensenshannon": DistanceType.JensenShannon,
"hamming": DistanceType.HammingUnexpanded,
"kl_divergence": DistanceType.KLDivergence,
"minkowski": DistanceType.LpUnexpanded,
"russellrao": DistanceType.RusselRaoExpanded,
"dice": DistanceType.DiceExpanded
}

SUPPORTED_DISTANCES = list(DISTANCE_TYPES.keys())
SUPPORTED_DISTANCES = ["euclidean", "l1", "cityblock", "l2", "inner_product",
"chebyshev", "minkowski", "canberra", "kl_divergence",
"correlation", "russellrao", "hellinger", "lp",
"hamming", "jensenshannon", "cosine", "sqeuclidean"]


def distance(X, Y, dists, metric="euclidean"):
def distance(X, Y, dists, metric="euclidean", p=2.0):
"""
Compute pairwise distances between X and Y
Valid values for metric:
["euclidean", "l2", "l1", "cityblock", "inner_product",
"chebyshev", "canberra", "lp", "hellinger", "jensenshannon",
"kl_divergence", "russellrao"]
"kl_divergence", "russellrao", "minkowski", "correlation",
"cosine"]
Parameters
----------
Expand All @@ -91,6 +105,7 @@ def distance(X, Y, dists, metric="euclidean"):
Y : CUDA array interface compliant matrix shape (n, k)
dists : Writable CUDA array interface matrix shape (m, n)
metric : string denoting the metric type (default="euclidean")
p : metric parameter (currently used only for "minkowski" and "lp")
Examples
--------
Expand All @@ -113,14 +128,19 @@ def distance(X, Y, dists, metric="euclidean"):
pairwise_distance(in1, in2, output, metric="euclidean")
"""

# TODO: Validate inputs, shapes, etc...
x_cai = X.__cuda_array_interface__
y_cai = Y.__cuda_array_interface__
dists_cai = dists.__cuda_array_interface__

m = x_cai["shape"][0]
n = y_cai["shape"][0]
k = x_cai["shape"][1]

x_k = x_cai["shape"][1]
y_k = y_cai["shape"][1]

if x_k != y_k:
raise ValueError("Inputs must have same number of columns. "
"a=%s, b=%s" % (x_k, y_k))

x_ptr = <uintptr_t>x_cai["data"][0]
y_ptr = <uintptr_t>y_cai["data"][0]
Expand All @@ -132,6 +152,12 @@ def distance(X, Y, dists, metric="euclidean"):
y_dt = np.dtype(y_cai["typestr"])
d_dt = np.dtype(dists_cai["typestr"])

x_c_contiguous = is_c_cont(x_cai, x_dt)
y_c_contiguous = is_c_cont(y_cai, y_dt)

if x_c_contiguous != y_c_contiguous:
raise ValueError("Inputs must have matching strides")

if metric not in SUPPORTED_DISTANCES:
raise ValueError("metric %s is not supported" % metric)

Expand All @@ -147,20 +173,20 @@ def distance(X, Y, dists, metric="euclidean"):
<float*> d_ptr,
<int>m,
<int>n,
<int>k,
<int>x_k,
<DistanceType>distance_type,
<bool>True,
<float>0.0)
<bool>x_c_contiguous,
<float>p)
elif x_dt == np.float64:
pairwise_distance(deref(h),
<double*> x_ptr,
<double*> y_ptr,
<double*> d_ptr,
<int>m,
<int>n,
<int>k,
<int>x_k,
<DistanceType>distance_type,
<bool>True,
<float>0.0)
<bool>x_c_contiguous,
<float>p)
else:
raise ValueError("dtype %s not supported" % x_dt)
17 changes: 10 additions & 7 deletions python/pylibraft/pylibraft/test/test_distance.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@

class TestDeviceBuffer:

def __init__(self, ndarray):
def __init__(self, ndarray, order):
self.ndarray_ = ndarray
self.device_buffer_ = \
rmm.DeviceBuffer.to_device(ndarray.ravel(order="C").tobytes())
rmm.DeviceBuffer.to_device(ndarray.ravel(order=order).tobytes())

@property
def __cuda_array_interface__(self):
Expand All @@ -49,10 +49,13 @@ def copy_to_host(self):
@pytest.mark.parametrize("n_cols", [100])
@pytest.mark.parametrize("metric", ["euclidean", "cityblock", "chebyshev",
"canberra", "correlation", "hamming",
"jensenshannon", "russellrao"])
"jensenshannon", "russellrao", "cosine",
"sqeuclidean"])
@pytest.mark.parametrize("order", ["F", "C"])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_distance(n_rows, n_cols, metric, dtype):
input1 = np.random.random_sample((n_rows, n_cols)).astype(dtype)
def test_distance(n_rows, n_cols, metric, order, dtype):
input1 = np.random.random_sample((n_rows, n_cols))
input1 = np.asarray(input1, order=order).astype(dtype)

# RussellRao expects boolean arrays
if metric == "russellrao":
Expand All @@ -70,8 +73,8 @@ def test_distance(n_rows, n_cols, metric, dtype):

expected[expected <= 1e-5] = 0.0

input1_device = TestDeviceBuffer(input1)
output_device = TestDeviceBuffer(output)
input1_device = TestDeviceBuffer(input1, order)
output_device = TestDeviceBuffer(output, order)

pairwise_distance(input1_device, input1_device, output_device, metric)
actual = output_device.copy_to_host()
Expand Down

0 comments on commit d151ed8

Please sign in to comment.