Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#4783 Added nan_euclidean distance metric to pairwise_distances #4797

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions python/cuml/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from cuml.metrics.cluster.entropy import cython_entropy as entropy
from cuml.metrics.pairwise_distances import pairwise_distances
from cuml.metrics.pairwise_distances import sparse_pairwise_distances
from cuml.metrics.pairwise_distances import nan_euclidean_distances
from cuml.metrics.pairwise_distances import PAIRWISE_DISTANCE_METRICS
from cuml.metrics.pairwise_distances import PAIRWISE_DISTANCE_SPARSE_METRICS
from cuml.metrics.pairwise_kernels import pairwise_kernels
Expand All @@ -59,6 +60,7 @@
"mutual_info_score",
"confusion_matrix",
"entropy",
"nan_euclidean_distances",
"pairwise_distances",
"sparse_pairwise_distances",
"pairwise_kernels",
Expand Down
132 changes: 128 additions & 4 deletions python/cuml/metrics/pairwise_distances.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ from raft.common.handle cimport handle_t
from raft.common.handle import Handle
import cupy as cp
import numpy as np
import pandas as pd
import cudf
import scipy
import cupyx
import cuml.internals
Expand All @@ -34,6 +36,7 @@ from cuml.common.sparse_utils import is_sparse
from cuml.common.array_sparse import SparseCumlArray
from cuml.metrics.cluster.utils import prepare_cluster_metric_inputs
from cuml.metrics.distance_type cimport DistanceType
from cuml.thirdparty_adapters import _get_mask

cdef extern from "cuml/metrics/metrics.hpp" namespace "ML::Metrics":
void pairwise_distance(const handle_t &handle, const double *x,
Expand Down Expand Up @@ -78,7 +81,8 @@ PAIRWISE_DISTANCE_METRICS = {
"jensenshannon": DistanceType.JensenShannon,
"hamming": DistanceType.HammingUnexpanded,
"kldivergence": DistanceType.KLDivergence,
"russellrao": DistanceType.RusselRaoExpanded
"russellrao": DistanceType.RusselRaoExpanded,
"nan_euclidean": DistanceType.L2Expanded
}

PAIRWISE_DISTANCE_SPARSE_METRICS = {
Expand Down Expand Up @@ -117,9 +121,6 @@ def _determine_metric(metric_str, is_sparse=False):
if metric_str == 'haversine':
raise ValueError(" The metric: '{}', is not supported at this time."
.format(metric_str))
elif metric_str == 'nan_euclidean':
raise ValueError(" The metric: '{}', is not supported at this time."
.format(metric_str))

if not(is_sparse) and (metric_str not in PAIRWISE_DISTANCE_METRICS):
if metric_str in PAIRWISE_DISTANCE_SPARSE_METRICS:
Expand All @@ -136,6 +137,126 @@ def _determine_metric(metric_str, is_sparse=False):
return PAIRWISE_DISTANCE_METRICS[metric_str]


def _fill_dataframe_nulls(data, fill_value):
    """Return `data` with DataFrame nulls replaced by `fill_value`.

    Only cuDF / pandas DataFrames that actually contain nulls are touched,
    and they are never modified in place: a filled copy is returned so the
    caller's frame is left intact. Any other input is returned unchanged.
    """
    if isinstance(data, (cudf.DataFrame, pd.DataFrame)):
        if data.isnull().any().any():
            return data.fillna(fill_value)
    return data


def nan_euclidean_distances(
    X, Y=None, *, squared=False, missing_values=cp.nan
):
    """Calculate the euclidean distances in the presence of missing values.

    Compute the euclidean distance between each pair of samples in X and Y,
    where Y=X is assumed if Y=None. When calculating the distance between a
    pair of samples, this formulation ignores feature coordinates with a
    missing value in either sample and scales up the weight of the remaining
    coordinates:

        dist(x,y) = sqrt(weight * sq. distance from present coordinates)
        where,
        weight = Total # of coordinates / # of present coordinates

    For example, the distance between ``[3, na, na, 6]`` and ``[1, na, 4, 5]``
    is:

        .. math::
            \\sqrt{\\frac{4}{2}((3-1)^2 + (6-5)^2)}

    If all the coordinates are missing or if there are no common present
    coordinates then NaN is returned for that pair.

    The inputs are never modified: the computation works on private device
    copies of X and Y.

    Parameters
    ----------
    X : Dense matrix of shape (n_samples_X, n_features)
        Acceptable formats: cuDF DataFrame, Pandas DataFrame, NumPy ndarray,
        cuda array interface compliant array like CuPy.

    Y : Dense matrix of shape (n_samples_Y, n_features), default=None
        Acceptable formats: cuDF DataFrame, Pandas DataFrame, NumPy ndarray,
        cuda array interface compliant array like CuPy.

    squared : bool, default=False
        Return squared Euclidean distances.

    missing_values : np.nan or int, default=np.nan
        Representation of missing value. Null entries of DataFrame inputs
        are treated as missing as well.

    Returns
    -------
    distances : ndarray of shape (n_samples_X, n_samples_Y)
        Returns the distances between the row vectors of `X`
        and the row vectors of `Y`.
    """
    # Decide up front whether this is a self-distance computation. This must
    # happen before any conversion: the converted arrays are fresh objects,
    # so identity checks on them can no longer tell the two cases apart.
    same_input = Y is None or Y is X

    # DataFrame nulls are replaced by the missing-value marker (not by 0) so
    # that they are picked up by the missing mask below instead of silently
    # contributing as real zero coordinates.
    # NOTE(review): relies on cuDF `fillna(nan)` producing NaN values rather
    # than keeping nulls -- confirm against the cuDF version in use.
    X = _fill_dataframe_nulls(X, missing_values)

    X_m, _, _, dtype_x = input_to_cuml_array(
        X, order="K", check_dtype=[np.float32, np.float64])
    x_order = X_m.order

    # Take an explicit copy: the converted array may alias caller-owned
    # device memory and the missing entries are overwritten with zeros below.
    X_m = cp.array(X_m, copy=True)
    missing_X = _get_mask(X_m, missing_values)
    X_m[missing_X] = 0

    if same_input:
        Y_m, missing_Y = X_m, missing_X
    else:
        Y = _fill_dataframe_nulls(Y, missing_values)
        Y_m, _, _, _ = input_to_cuml_array(
            Y, order=x_order, convert_to_dtype=dtype_x,
            check_dtype=[dtype_x])
        Y_m = cp.array(Y_m, copy=True)
        missing_Y = _get_mask(Y_m, missing_values)
        Y_m[missing_Y] = 0

    # Squared euclidean distances with the missing coordinates set to zero.
    distances = cp.asarray(pairwise_distances(
        X_m, None if same_input else Y_m, metric="sqeuclidean"))

    # Adjust distances for missing values: remove the contribution of every
    # coordinate that is present in one sample but missing in the other.
    XX = X_m * X_m
    YY = Y_m * Y_m
    distances -= cp.dot(XX, missing_Y.T)
    distances -= cp.dot(missing_X, YY.T)

    cp.clip(distances, 0, None, out=distances)

    if same_input:
        # Ensure that distances between vectors and themselves are set to 0.0.
        # This may not be the case due to floating point rounding errors.
        cp.fill_diagonal(distances, 0.0)

    # Number of coordinates present in both samples of each pair.
    present_X = 1 - missing_X
    present_Y = present_X if same_input else 1 - missing_Y
    present_count = cp.dot(present_X, present_Y.T)
    distances[present_count == 0] = cp.nan

    # avoid divide by zero
    cp.maximum(1, present_count, out=present_count)

    # Re-weight by (total # of coordinates / # of present coordinates).
    distances /= present_count
    distances *= X_m.shape[1]

    if not squared:
        cp.sqrt(distances, out=distances)

    return distances


@cuml.internals.api_return_array(get_output_type=True)
def pairwise_distances(X, Y=None, metric="euclidean", handle=None,
convert_dtype=True, metric_arg=2, **kwds):
Expand Down Expand Up @@ -225,6 +346,9 @@ def pairwise_distances(X, Y=None, metric="euclidean", handle=None,
handle = Handle() if handle is None else handle
cdef handle_t *handle_ = <handle_t*> <size_t> handle.getHandle()

if metric in ['nan_euclidean']:
return nan_euclidean_distances(X, Y, **kwds)

if metric in ['russellrao'] and not np.all(X.data == 1.):
warnings.warn("X was converted to boolean for metric {}"
.format(metric))
Expand Down
9 changes: 5 additions & 4 deletions python/cuml/tests/test_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def test_sklearn_search():
def test_accuracy(nrows, ncols, n_info, datatype):

use_handle = True
train_rows = np.int32(nrows*0.8)
train_rows = np.int32(nrows * 0.8)
X, y = make_classification(n_samples=nrows, n_features=ncols,
n_clusters_per_class=1, n_informative=n_info,
random_state=123, n_classes=5)
Expand Down Expand Up @@ -251,7 +251,7 @@ def test_rand_index_score(name, nrows):
def test_silhouette_score_batched(metric, chunk_divider, labeled_clusters):
X, labels = labeled_clusters
cuml_score = cu_silhouette_score(X, labels, metric=metric,
chunksize=int(X.shape[0]/chunk_divider))
chunksize=int(X.shape[0] / chunk_divider))
sk_score = sk_silhouette_score(X, labels, metric=metric)
assert_almost_equal(cuml_score, sk_score, decimal=2)

Expand Down Expand Up @@ -988,6 +988,7 @@ def test_pairwise_distances_sklearn_comparison(metric: str, matrix_size):
# For fp64, compare at 10 decimals, (5 places less than the ~15 max)
compare_precision = 10

print(X.shape, Y.shape, metric)
# Compare to sklearn, fp64
S = pairwise_distances(X, Y, metric=metric)

Expand Down Expand Up @@ -1076,7 +1077,7 @@ def test_pairwise_distances_one_dimension_order(metric: str):
cp.testing.assert_array_almost_equal(S, S2, decimal=compare_precision)


@pytest.mark.parametrize("metric", ["haversine", "nan_euclidean"])
@pytest.mark.parametrize("metric", ["haversine"])
def test_pairwise_distances_unsuppored_metrics(metric):
rng = np.random.RandomState(3)

Expand Down Expand Up @@ -1397,7 +1398,7 @@ def test_sparse_pairwise_distances_output_types(input_type, output_type):
@pytest.mark.parametrize("input_type", ["cudf", "cupy"])
@pytest.mark.parametrize("n_classes", [2, 5])
def test_hinge_loss(nrows, ncols, n_info, input_type, n_classes):
train_rows = np.int32(nrows*0.8)
train_rows = np.int32(nrows * 0.8)
X, y = make_classification(n_samples=nrows, n_features=ncols,
n_clusters_per_class=1, n_informative=n_info,
random_state=123, n_classes=n_classes)
Expand Down