Skip to content

Commit

Permalink
Fix UMAP and simplicial set functions metric (#5490)
Browse files Browse the repository at this point in the history
Answers #5422

Authors:
  - Victor Lafargue (https://github.com/viclafargue)
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Simon Adorf (https://github.com/csadorf)

Approvers:
  - Simon Adorf (https://github.com/csadorf)
  - William Hicks (https://github.com/wphicks)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #5490
  • Loading branch information
viclafargue authored Aug 3, 2023
1 parent 6c8c5ef commit 6bf61ca
Show file tree
Hide file tree
Showing 5 changed files with 170 additions and 77 deletions.
3 changes: 3 additions & 0 deletions cpp/src/umap/knn_graph/algo.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,9 @@ inline void launcher(const raft::handle_t& handle,
out.knn_indices,
out.knn_dists,
n_neighbors,
true,
true,
static_cast<std::vector<int64_t>*>(nullptr),
params->metric,
params->p);
}
Expand Down
49 changes: 40 additions & 9 deletions python/cuml/manifold/simpl_set.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ from cuml.internals.safe_imports import gpu_only_import
cp = gpu_only_import('cupy')

from cuml.manifold.umap_utils cimport *
from cuml.manifold.umap_utils import GraphHolder, find_ab_params
from cuml.manifold.umap_utils import GraphHolder, find_ab_params, \
metric_parsing

from cuml.internals.input_utils import input_to_cuml_array
from cuml.internals.array import CumlArray
Expand Down Expand Up @@ -82,10 +83,17 @@ def fuzzy_simplicial_set(X,
structure to the detriment of the larger picture.
random_state: numpy RandomState or equivalent
A state capable of being used as a numpy random state.
metric: string or function (optional, default 'euclidean')
unused
metric_kwds: dict (optional, default {})
unused
metric: string (default='euclidean').
Distance metric to use. Supported distances are ['l1', 'cityblock',
'taxicab', 'manhattan', 'euclidean', 'l2', 'sqeuclidean', 'canberra',
'minkowski', 'chebyshev', 'linf', 'cosine', 'correlation', 'hellinger',
'hamming', 'jaccard']
Metrics that take arguments (such as minkowski) can have arguments
passed via the metric_kwds dictionary.
Note: The 'jaccard' distance metric is only supported for sparse
inputs.
metric_kwds: dict (optional, default=None)
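Keyword arguments for the metric, e.g. {"p": 3} for 'minkowski'.
Only the "p" key is read; it defaults to 2.0 when absent or when
metric_kwds is None.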
knn_indices: array of shape (n_samples, n_neighbors) (optional)
If the k-nearest neighbors of each point has already been calculated
you can pass them in here to save computation time. This should be
Expand Down Expand Up @@ -138,6 +146,14 @@ def fuzzy_simplicial_set(X,
umap_params.deterministic = <bool> deterministic
umap_params.set_op_mix_ratio = <float> set_op_mix_ratio
umap_params.local_connectivity = <float> local_connectivity
try:
umap_params.metric = metric_parsing[metric.lower()]
except KeyError:
raise ValueError(f"Invalid value for metric: {metric}")
if metric_kwds is None:
umap_params.p = <float> 2.0
else:
umap_params.p = <float> metric_kwds.get("p", 2.0)
umap_params.verbosity = <int> verbose

X_m, _, _, _ = \
Expand Down Expand Up @@ -245,10 +261,17 @@ def simplicial_set_embedding(
* A numpy array of initial embedding positions.
random_state: numpy RandomState or equivalent
A state capable of being used as a numpy random state.
metric: string or callable
unused
metric_kwds: dict
unused
metric: string (default='euclidean').
Distance metric to use. Supported distances are ['l1', 'cityblock',
'taxicab', 'manhattan', 'euclidean', 'l2', 'sqeuclidean', 'canberra',
'minkowski', 'chebyshev', 'linf', 'cosine', 'correlation', 'hellinger',
'hamming', 'jaccard']
Metrics that take arguments (such as minkowski) can have arguments
passed via the metric_kwds dictionary.
Note: The 'jaccard' distance metric is only supported for sparse
inputs.
metric_kwds: dict (optional, default=None)
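Keyword arguments for the metric, e.g. {"p": 3} for 'minkowski'.
Only the "p" key is read; it defaults to 2.0 when absent or when
metric_kwds is None.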
output_metric: function
Function returning the distance between two points in embedding space
and the gradient of the distance wrt the first argument.
Expand Down Expand Up @@ -306,6 +329,14 @@ def simplicial_set_embedding(
umap_params.init = <int> 0
umap_params.random_state = <int> random_state
umap_params.deterministic = <bool> deterministic
try:
umap_params.metric = metric_parsing[metric.lower()]
except KeyError:
raise ValueError(f"Invalid value for metric: {metric}")
if metric_kwds is None:
umap_params.p = <float> 2.0
else:
umap_params.p = <float> metric_kwds.get("p", 2.0)
if output_metric == 'euclidean':
umap_params.target_metric = MetricType.EUCLIDEAN
else: # output_metric == 'categorical'
Expand Down
53 changes: 20 additions & 33 deletions python/cuml/manifold/umap.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ cupyx = gpu_only_import('cupyx')
cuda = gpu_only_import('numba.cuda')

from cuml.manifold.umap_utils cimport *
from cuml.manifold.umap_utils import GraphHolder, find_ab_params
from cuml.manifold.umap_utils import GraphHolder, find_ab_params, \
metric_parsing, DENSE_SUPPORTED_METRICS, SPARSE_SUPPORTED_METRICS

from cuml.common.sparsefuncs import extract_knn_infos
from cuml.internals.safe_imports import gpu_only_import_from
Expand All @@ -47,7 +48,6 @@ from cuml.internals.array import CumlArray
from cuml.internals.array_sparse import SparseCumlArray
from cuml.internals.mixins import CMajorInputTagMixin
from cuml.common.sparse_utils import is_sparse
from cuml.metrics.distance_type cimport DistanceType

from cuml.manifold.simpl_set import fuzzy_simplicial_set # no-cython-lint
from cuml.manifold.simpl_set import simplicial_set_embedding # no-cython-lint
Expand Down Expand Up @@ -152,13 +152,17 @@ class UMAP(UniversalBase,
n_components: int (optional, default 2)
The dimension of the space to embed into. This defaults to 2 to
provide easy visualization, but can reasonably be set to any
metric : string (default='euclidean').
metric: string (default='euclidean').
Distance metric to use. Supported distances are ['l1', 'cityblock',
'taxicab', 'manhattan', 'euclidean', 'l2', 'sqeuclidean', 'canberra',
'minkowski', 'chebyshev', 'linf', 'cosine', 'correlation', 'hellinger',
'hamming', 'jaccard']
Metrics that take arguments (such as minkowski) can have arguments
passed via the metric_kwds dictionary.
Note: The 'jaccard' distance metric is only supported for sparse
inputs.
metric_kwds: dict (optional, default=None)
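Keyword arguments for the metric, e.g. {"p": 3} for 'minkowski'.
The "p" key defaults to 2.0 when absent or when metric_kwds is None.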
n_epochs: int (optional, default None)
The number of training epochs to be used in optimizing the
low dimensional embedding. Larger values result in more accurate
Expand Down Expand Up @@ -419,7 +423,7 @@ class UMAP(UniversalBase,
raise ValueError("min_dist should be <= spread")

@staticmethod
def _build_umap_params(cls):
def _build_umap_params(cls, sparse):
cdef UMAPParams* umap_params = new UMAPParams()
umap_params.n_neighbors = <int> cls.n_neighbors
umap_params.n_components = <int> cls.n_components
Expand Down Expand Up @@ -448,37 +452,20 @@ class UMAP(UniversalBase,
umap_params.random_state = <uint64_t> cls.random_state
umap_params.deterministic = <bool> cls.deterministic

# metric
metric_parsing = {
"l2": DistanceType.L2SqrtUnexpanded,
"euclidean": DistanceType.L2SqrtUnexpanded,
"sqeuclidean": DistanceType.L2Unexpanded,
"cityblock": DistanceType.L1,
"l1": DistanceType.L1,
"manhattan": DistanceType.L1,
"taxicab": DistanceType.L1,
"minkowski": DistanceType.LpUnexpanded,
"chebyshev": DistanceType.Linf,
"linf": DistanceType.Linf,
"cosine": DistanceType.CosineExpanded,
"correlation": DistanceType.CorrelationExpanded,
"hellinger": DistanceType.HellingerExpanded,
"hamming": DistanceType.HammingUnexpanded,
"jaccard": DistanceType.JaccardExpanded,
"canberra": DistanceType.Canberra
}

if cls.metric.lower() in metric_parsing:
try:
umap_params.metric = metric_parsing[cls.metric.lower()]
else:
raise ValueError("Invalid value for metric: {}"
.format(cls.metric))

if sparse:
if umap_params.metric not in SPARSE_SUPPORTED_METRICS:
raise NotImplementedError(f"Metric '{cls.metric}' not supported for sparse inputs.")
elif umap_params.metric not in DENSE_SUPPORTED_METRICS:
raise NotImplementedError(f"Metric '{cls.metric}' not supported for dense inputs.")

except KeyError:
raise ValueError(f"Invalid value for metric: {cls.metric}")
if cls.metric_kwds is None:
umap_params.p = <float> 2.0
else:
umap_params.p = <float>cls.metric_kwds.get('p')

umap_params.p = <float> cls.metric_kwds.get("p", 2.0)
cdef uintptr_t callback_ptr = 0
if cls.callback:
callback_ptr = cls.callback.get_native_callback()
Expand Down Expand Up @@ -576,7 +563,7 @@ class UMAP(UniversalBase,
cdef uintptr_t embed_raw = self.embedding_.ptr

cdef UMAPParams* umap_params = \
<UMAPParams*> <size_t> UMAP._build_umap_params(self)
<UMAPParams*> <size_t> UMAP._build_umap_params(self, self.sparse_fit)

cdef uintptr_t y_raw = 0

Expand Down Expand Up @@ -742,7 +729,7 @@ class UMAP(UniversalBase,
cdef uintptr_t embed_ptr = self.embedding_.ptr

cdef UMAPParams* umap_params = \
<UMAPParams*> <size_t> UMAP._build_umap_params(self)
<UMAPParams*> <size_t> UMAP._build_umap_params(self, self.sparse_fit)

if self.sparse_fit:
transform_sparse(handle_[0],
Expand Down
51 changes: 51 additions & 0 deletions python/cuml/manifold/umap_utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from rmm._lib.memory_resource cimport get_current_device_resource
from pylibraft.common.handle cimport handle_t
from cuml.manifold.umap_utils cimport *
from cuml.metrics.distance_type cimport DistanceType
from libcpp.utility cimport move
from cuml.internals.safe_imports import cpu_only_import
np = cpu_only_import('numpy')
Expand Down Expand Up @@ -130,3 +131,53 @@ def find_ab_params(spread, min_dist):
yv[xv >= min_dist] = np.exp(-(xv[xv >= min_dist] - min_dist) / spread)
params, _ = curve_fit(curve, xv, yv)
return params[0], params[1]


# Lookup table from user-facing metric names to DistanceType values.
# Keys are lowercase: callers index this table with ``metric.lower()`` and
# turn a KeyError into ``ValueError("Invalid value for metric: ...")``.
# Several names are aliases that resolve to the same DistanceType.
metric_parsing = {
    # Euclidean (square-rooted L2) and its alias
    "l2": DistanceType.L2SqrtUnexpanded,
    "euclidean": DistanceType.L2SqrtUnexpanded,
    # Squared Euclidean
    "sqeuclidean": DistanceType.L2Unexpanded,
    # L1 and its aliases
    "cityblock": DistanceType.L1,
    "l1": DistanceType.L1,
    "manhattan": DistanceType.L1,
    "taxicab": DistanceType.L1,
    # Minkowski; the exponent is supplied separately through ``p``
    "minkowski": DistanceType.LpUnexpanded,
    # L-infinity and its alias
    "chebyshev": DistanceType.Linf,
    "linf": DistanceType.Linf,
    # Remaining metrics, one name each
    "cosine": DistanceType.CosineExpanded,
    "correlation": DistanceType.CorrelationExpanded,
    "hellinger": DistanceType.HellingerExpanded,
    "hamming": DistanceType.HammingUnexpanded,
    "jaccard": DistanceType.JaccardExpanded,
    "canberra": DistanceType.Canberra,
}


# DistanceType values accepted when the input is dense. UMAP checks the
# parsed metric against this list and raises NotImplementedError
# ("... not supported for dense inputs.") when it is absent.
DENSE_SUPPORTED_METRICS = [
    DistanceType.Canberra,
    DistanceType.CorrelationExpanded,
    DistanceType.CosineExpanded,
    DistanceType.HammingUnexpanded,
    DistanceType.HellingerExpanded,
    # DistanceType.JaccardExpanded is deliberately left out: the jaccard
    # metric is only supported for sparse inputs.
    DistanceType.L1,
    DistanceType.L2SqrtUnexpanded,
    DistanceType.L2Unexpanded,
    DistanceType.Linf,
    DistanceType.LpUnexpanded,
]


# DistanceType values accepted when the input is sparse. UMAP checks the
# parsed metric against this list and raises NotImplementedError
# ("... not supported for sparse inputs.") when it is absent.
SPARSE_SUPPORTED_METRICS = [
    DistanceType.Canberra,
    DistanceType.CorrelationExpanded,
    DistanceType.CosineExpanded,
    DistanceType.HammingUnexpanded,
    DistanceType.HellingerExpanded,
    # Jaccard is listed here but not in the dense list.
    DistanceType.JaccardExpanded,
    DistanceType.L1,
    DistanceType.L2SqrtUnexpanded,
    DistanceType.L2Unexpanded,
    DistanceType.Linf,
    DistanceType.LpUnexpanded,
]
Loading

0 comments on commit 6bf61ca

Please sign in to comment.