Skip to content

Commit

Permalink
Kernel density estimation (#4545)
Browse files Browse the repository at this point in the history
Uses a brute-force approach, in contrast to sklearn's KD-tree/ball-tree implementations.

Todo:
- [x] Implement sample method
- [x] Sample weights
- [x] Evaluate which metrics are missing
- [x] Tests for sample
- [x] Docstrings

Authors:
  - Rory Mitchell (https://github.com/RAMitchell)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Micka (https://github.com/lowener)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #4545
  • Loading branch information
RAMitchell authored Mar 8, 2022
1 parent 7c3da85 commit b430b2e
Show file tree
Hide file tree
Showing 10 changed files with 674 additions and 36 deletions.
17 changes: 17 additions & 0 deletions docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,10 @@ Metrics (regression, classification, and distance)
.. automodule:: cuml.metrics.pairwise_distances
:members:

.. automodule:: cuml.metrics.pairwise_kernels
:members:


Metrics (clustering and manifold learning)
------------------------------------------
.. automodule:: cuml.metrics.trustworthiness
Expand Down Expand Up @@ -335,6 +339,13 @@ Nearest Neighbors Regression
:members:
:noindex:

Kernel Ridge Regression
-----------------------

.. autoclass:: cuml.KernelRidge
:members:


Clustering
==========

Expand Down Expand Up @@ -429,6 +440,12 @@ Nearest Neighbors Regression
.. autoclass:: cuml.neighbors.KNeighborsRegressor
:members:

Kernel Density Estimation
-------------------------

.. autoclass:: cuml.neighbors.KernelDensity
:members:

Time Series
============

Expand Down
1 change: 1 addition & 0 deletions python/cuml/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from cuml.naive_bayes.naive_bayes import MultinomialNB

from cuml.neighbors.nearest_neighbors import NearestNeighbors
from cuml.neighbors.kernel_density import KernelDensity
from cuml.neighbors.kneighbors_classifier import KNeighborsClassifier
from cuml.neighbors.kneighbors_regressor import KNeighborsRegressor

Expand Down
33 changes: 18 additions & 15 deletions python/cuml/kernel_ridge/kernel_ridge.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ class KernelRidge(Base, RegressorMixin):
in `cuml.metrics.PAIRWISE_KERNEL_FUNCTIONS` or "precomputed".
If `kernel` is "precomputed", X is assumed to be a kernel matrix.
`kernel` may be a callable numba device function. If so, it is called on
each pair of instances (rows) and the resulting value recorded. The
callable should take two rows from X as input and return the
corresponding kernel value as a single number.
each pair of instances (rows) and the resulting value recorded.
gamma : float, default=None
Gamma parameter for the RBF, laplacian, polynomial, exponential chi2
and sigmoid kernels. Interpretation of the default value is left to
Expand Down Expand Up @@ -149,8 +147,10 @@ class KernelRidge(Base, RegressorMixin):
verbose : int or boolean, default=False
Sets logging level. It must be one of `cuml.common.logger.level_*`.
See :ref:`verbosity-levels` for more info.
Attributes
----------
dual_coef_ : ndarray of shape (n_samples,) or (n_samples, n_targets)
Representation of weight vector(s) in kernel space
X_fit_ : ndarray of shape (n_samples, n_features)
Expand Down Expand Up @@ -271,18 +271,21 @@ class KernelRidge(Base, RegressorMixin):
return self

def predict(self, X):
"""Predict using the kernel ridge model.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Samples. If kernel == "precomputed" this is instead a
precomputed kernel matrix, shape = [n_samples,
n_samples_fitted], where n_samples_fitted is the number of
samples used in the fitting for this estimator.
Returns
-------
C : array of shape (n_samples,) or (n_samples, n_targets)
Returns predicted values.
"""
Predict using the kernel ridge model.
Parameters
----------
X : array-like of shape (n_samples, n_features)
Samples. If kernel == "precomputed" this is instead a
precomputed kernel matrix, shape = [n_samples,
n_samples_fitted], where n_samples_fitted is the number of
samples used in the fitting for this estimator.
Returns
-------
C : array of shape (n_samples,) or (n_samples, n_targets)
Returns predicted values.
"""
X_m, _, _, _ = input_to_cuml_array(
X, check_dtype=[np.float32, np.float64])
Expand Down
12 changes: 9 additions & 3 deletions python/cuml/metrics/pairwise_kernels.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np
import cuml.internals
from cuml.metrics import pairwise_distances
from cuml.common.input_utils import input_to_cupy_array


def linear_kernel(X, Y):
Expand Down Expand Up @@ -191,7 +192,8 @@ def evaluate_pairwise_kernels(X, Y, K):
@cuml.internals.api_return_array(get_output_type=True)
def pairwise_kernels(X, Y=None, metric="linear", *,
filter_params=False, convert_dtype=True, **kwds):
"""Compute the kernel between arrays X and optional array Y.
"""
Compute the kernel between arrays X and optional array Y.
This method takes either a vector array or a kernel matrix, and returns
a kernel matrix. If the input is a vector array, the kernels are
computed. If the input is a kernel matrix, it is returned instead.
Expand All @@ -203,6 +205,7 @@ def pairwise_kernels(X, Y=None, metric="linear", *,
Valid values for metric are:
['additive_chi2', 'chi2', 'linear', 'poly', 'polynomial', 'rbf',
'laplacian', 'sigmoid', 'cosine']
Parameters
----------
X : Dense matrix (device or host) of shape (n_samples_X, n_samples_X) or \
Expand Down Expand Up @@ -233,6 +236,7 @@ def pairwise_kernels(X, Y=None, metric="linear", *,
will increase memory used for the method.
**kwds : optional keyword parameters
Any further parameters are passed directly to the kernel function.
Returns
-------
K : ndarray of shape (n_samples_X, n_samples_X) or \
Expand All @@ -241,6 +245,7 @@ def pairwise_kernels(X, Y=None, metric="linear", *,
ith and jth vectors of the given matrix X, if Y is None.
If Y is not None, then K_{i, j} is the kernel between the ith array
from X and the jth array from Y.
Notes
-----
If metric is 'precomputed', Y is ignored and X is returned.
Expand Down Expand Up @@ -272,11 +277,11 @@ def custom_rbf_kernel(x, y, gamma=None):
pairwise_kernels(X, Y, metric=custom_rbf_kernel)
"""
X = cp.asarray(X)
X = input_to_cupy_array(X).array
if Y is None:
Y = X
else:
Y = cp.asarray(Y)
Y = input_to_cupy_array(Y).array
if X.shape[1] != Y.shape[1]:
raise ValueError("X and Y have different dimensions.")

Expand All @@ -292,4 +297,5 @@ def custom_rbf_kernel(x, y, gamma=None):
else:
kwds = _filter_params(
metric, filter_params, **kwds)

return custom_kernel(X, Y, metric, **kwds)
4 changes: 3 additions & 1 deletion python/cuml/neighbors/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
#
# Copyright (c) 2019-2021, NVIDIA CORPORATION.
# Copyright (c) 2019-2022, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -20,6 +20,8 @@
from cuml.neighbors.nearest_neighbors import kneighbors_graph
from cuml.neighbors.kneighbors_classifier import KNeighborsClassifier
from cuml.neighbors.kneighbors_regressor import KNeighborsRegressor
from cuml.neighbors.kernel_density import (
KernelDensity, VALID_KERNELS, logsumexp_kernel)

VALID_METRICS = {
"brute": set([
Expand Down
Loading

0 comments on commit b430b2e

Please sign in to comment.