diff --git a/docs/source/api.rst b/docs/source/api.rst
index c541c67ed6..2a4f566a52 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -165,6 +165,10 @@ Metrics (regression, classification, and distance)
 .. automodule:: cuml.metrics.pairwise_distances
     :members:
 
+.. automodule:: cuml.metrics.pairwise_kernels
+    :members:
+
+
 Metrics (clustering and manifold learning)
 ------------------------------------------
 .. automodule:: cuml.metrics.trustworthiness
@@ -335,6 +339,13 @@ Nearest Neighbors Regression
     :members:
     :noindex:
 
+Kernel Ridge Regression
+-----------------------
+
+.. autoclass:: cuml.KernelRidge
+    :members:
+
+
 Clustering
 ==========
 
@@ -429,6 +440,12 @@ Nearest Neighbors Regression
 .. autoclass:: cuml.neighbors.KNeighborsRegressor
     :members:
 
+Kernel Density Estimation
+-------------------------
+
+.. autoclass:: cuml.neighbors.KernelDensity
+    :members:
+
 Time Series
 ============
 
diff --git a/python/cuml/__init__.py b/python/cuml/__init__.py
index 1a3fd4cc54..7dea157c2b 100644
--- a/python/cuml/__init__.py
+++ b/python/cuml/__init__.py
@@ -65,6 +65,7 @@ from cuml.naive_bayes.naive_bayes import MultinomialNB
 
 from cuml.neighbors.nearest_neighbors import NearestNeighbors
+from cuml.neighbors.kernel_density import KernelDensity
 from cuml.neighbors.kneighbors_classifier import KNeighborsClassifier
 from cuml.neighbors.kneighbors_regressor import KNeighborsRegressor
 
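A minimal usage sketch for the `cuml.KernelRidge` estimator whose docstrings are touched in the next file. The synthetic data, the `rbf` kernel, and the parameter values are illustrative choices only:

.. code-block:: python

    import cupy as cp
    from cuml import KernelRidge

    X = cp.random.rand(100, 5)
    y = cp.random.rand(100)

    model = KernelRidge(kernel="rbf", gamma=0.1, alpha=1.0)
    model.fit(X, y)
    preds = model.predict(X)  # shape (100,)
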
diff --git a/python/cuml/kernel_ridge/kernel_ridge.pyx b/python/cuml/kernel_ridge/kernel_ridge.pyx
index 4625096db8..447e52e2dd 100644
--- a/python/cuml/kernel_ridge/kernel_ridge.pyx
+++ b/python/cuml/kernel_ridge/kernel_ridge.pyx
@@ -118,9 +118,7 @@ class KernelRidge(Base, RegressorMixin):
         in `cuml.metrics.PAIRWISE_KERNEL_FUNCTIONS` or "precomputed".
         If `kernel` is "precomputed", X is assumed to be a kernel matrix.
         `kernel` may be a callable numba device function. If so, is called on
-        each pair of instances (rows) and the resulting value recorded. The
-        callable should take two rows from X as input and return the
-        corresponding kernel value as a single number.
+        each pair of instances (rows) and the resulting value recorded.
     gamma : float, default=None
         Gamma parameter for the RBF, laplacian, polynomial, exponential chi2
         and sigmoid kernels. Interpretation of the default value is left to
@@ -149,8 +147,10 @@ class KernelRidge(Base, RegressorMixin):
     verbose : int or boolean, default=False
         Sets logging level. It must be one of `cuml.common.logger.level_*`.
         See :ref:`verbosity-levels` for more info.
+
     Attributes
     ----------
+
     dual_coef_ : ndarray of shape (n_samples,) or (n_samples, n_targets)
         Representation of weight vector(s) in kernel space
     X_fit_ : ndarray of shape (n_samples, n_features)
@@ -271,18 +271,21 @@ class KernelRidge(Base, RegressorMixin):
         return self
 
     def predict(self, X):
-        """Predict using the kernel ridge model.
-        Parameters
-        ----------
-        X : array-like of shape (n_samples, n_features)
-            Samples. If kernel == "precomputed" this is instead a
-            precomputed kernel matrix, shape = [n_samples,
-            n_samples_fitted], where n_samples_fitted is the number of
-            samples used in the fitting for this estimator.
-        Returns
-        -------
-        C : array of shape (n_samples,) or (n_samples, n_targets)
-            Returns predicted values.
+        """
+        Predict using the kernel ridge model.
+
+        Parameters
+        ----------
+        X : array-like of shape (n_samples, n_features)
+            Samples. If kernel == "precomputed" this is instead a
+            precomputed kernel matrix, shape = [n_samples,
+            n_samples_fitted], where n_samples_fitted is the number of
+            samples used in the fitting for this estimator.
+
+        Returns
+        -------
+        C : array of shape (n_samples,) or (n_samples, n_targets)
+            Returns predicted values.
         """
         X_m, _, _, _ = input_to_cuml_array(
             X, check_dtype=[np.float32, np.float64])
diff --git a/python/cuml/metrics/pairwise_kernels.py b/python/cuml/metrics/pairwise_kernels.py
index ef2d5e2e41..5416ce36ce 100644
--- a/python/cuml/metrics/pairwise_kernels.py
+++ b/python/cuml/metrics/pairwise_kernels.py
@@ -19,6 +19,7 @@ import numpy as np
 
 import cuml.internals
 from cuml.metrics import pairwise_distances
+from cuml.common.input_utils import input_to_cupy_array
 
 
 def linear_kernel(X, Y):
@@ -191,7 +192,8 @@ def evaluate_pairwise_kernels(X, Y, K):
 @cuml.internals.api_return_array(get_output_type=True)
 def pairwise_kernels(X, Y=None, metric="linear", *,
                      filter_params=False, convert_dtype=True, **kwds):
-    """Compute the kernel between arrays X and optional array Y.
+    """
+    Compute the kernel between arrays X and optional array Y.
     This method takes either a vector array or a kernel matrix, and returns
     a kernel matrix. If the input is a vector array, the kernels are
     computed. If the input is a kernel matrix, it is returned instead.
@@ -203,6 +205,7 @@ def pairwise_kernels(X, Y=None, metric="linear", *,
     Valid values for metric are:
         ['additive_chi2', 'chi2', 'linear', 'poly', 'polynomial', 'rbf',
         'laplacian', 'sigmoid', 'cosine']
+
     Parameters
     ----------
     X : Dense matrix (device or host) of shape (n_samples_X, n_samples_X) or \
         (n_samples_X, n_features)
@@ -233,6 +236,7 @@ def pairwise_kernels(X, Y=None, metric="linear", *,
         will increase memory used for the method.
     **kwds : optional keyword parameters
         Any further parameters are passed directly to the kernel function.
+
     Returns
     -------
     K : ndarray of shape (n_samples_X, n_samples_X) or \
         (n_samples_X, n_samples_Y)
         A kernel matrix K such that K_{i, j} is the kernel between the
         ith and jth vectors of the given matrix X, if Y is None.
         If Y is not None, then K_{i, j} is the kernel between the ith array
         from X and the jth array from Y.
+
     Notes
     -----
     If metric is 'precomputed', Y is ignored and X is returned.
@@ -272,11 +277,11 @@ def custom_rbf_kernel(x, y, gamma=None):
 
         pairwise_kernels(X, Y, metric=custom_rbf_kernel)
     """
-    X = cp.asarray(X)
+    X = input_to_cupy_array(X).array
     if Y is None:
         Y = X
     else:
-        Y = cp.asarray(Y)
+        Y = input_to_cupy_array(Y).array
     if X.shape[1] != Y.shape[1]:
         raise ValueError("X and Y have different dimensions.")
 
@@ -292,4 +297,5 @@ def custom_rbf_kernel(x, y, gamma=None):
     else:
         kwds = _filter_params(
             metric, filter_params, **kwds)
+
     return custom_kernel(X, Y, metric, **kwds)
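The switch above from `cp.asarray` to `input_to_cupy_array` means `pairwise_kernels` should now accept host arrays and dataframes as well as device arrays. A small sketch (the data, metric, and `gamma` value are illustrative):

.. code-block:: python

    import numpy as np
    from cuml.metrics import pairwise_kernels

    X = np.random.rand(5, 3)  # host array, converted on the fly
    K = pairwise_kernels(X, metric="rbf", gamma=0.5)  # (5, 5) kernel matrix
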
diff --git a/python/cuml/neighbors/__init__.py b/python/cuml/neighbors/__init__.py
index 2cece8b2f7..7362dceb0f 100644
--- a/python/cuml/neighbors/__init__.py
+++ b/python/cuml/neighbors/__init__.py
@@ -1,5 +1,5 @@
 #
-# Copyright (c) 2019-2021, NVIDIA CORPORATION.
+# Copyright (c) 2019-2022, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -20,6 +20,8 @@ from cuml.neighbors.nearest_neighbors import kneighbors_graph
 from cuml.neighbors.kneighbors_classifier import KNeighborsClassifier
 from cuml.neighbors.kneighbors_regressor import KNeighborsRegressor
+from cuml.neighbors.kernel_density import (
+    KernelDensity, VALID_KERNELS, logsumexp_kernel)
 
 VALID_METRICS = {
     "brute": set([
diff --git a/python/cuml/neighbors/kernel_density.py b/python/cuml/neighbors/kernel_density.py
new file mode 100644
index 0000000000..aa6101887c
--- /dev/null
+++ b/python/cuml/neighbors/kernel_density.py
@@ -0,0 +1,416 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import cupy as cp
+import numpy as np
+import math
+from numba import cuda
+from cuml.common.input_utils import input_to_cupy_array
+from cuml.common.input_utils import input_to_cuml_array
+from cuml.common.base import Base
+from cuml.metrics import pairwise_distances
+from cuml.common.import_utils import has_scipy
+from cuml.common.exceptions import NotFittedError
+
+if has_scipy():
+    from scipy.special import gammainc
+
+
+VALID_KERNELS = [
+    "gaussian",
+    "tophat",
+    "epanechnikov",
+    "exponential",
+    "linear",
+    "cosine",
+]
+
+
+# Log of each (unnormalised) kernel; the normalisation constant is applied
+# later in norm_log_probabilities.
+@cp.fuse()
+def gaussian_log_kernel(x, h):
+    return -(x * x) / (2 * h * h)
+
+
+@cp.fuse()
+def tophat_log_kernel(x, h):
+    """
+    if x < h:
+        return 0.0
+    else:
+        return FLOAT_MIN  # the most negative float, standing in for log(0)
+    """
+    y = (x >= h) * np.finfo(x.dtype).min
+    return y
+
+
+@cp.fuse()
+def epanechnikov_log_kernel(x, h):
+    # don't call log(0) otherwise we get NaNs
+    z = cp.maximum(1.0 - (x * x) / (h * h), 1e-30)
+    y = (x < h) * cp.log(z)
+    y += (x >= h) * np.finfo(y.dtype).min
+    return y
+
+
+@cp.fuse()
+def exponential_log_kernel(x, h):
+    return -x / h
+
+
+@cp.fuse()
+def linear_log_kernel(x, h):
+    # don't call log(0) otherwise we get NaNs
+    z = cp.maximum(1.0 - x / h, 1e-30)
+    y = (x < h) * cp.log(z)
+    y += (x >= h) * np.finfo(y.dtype).min
+    return y
+
+
+@cp.fuse()
+def cosine_log_kernel(x, h):
+    # don't call log(0) otherwise we get NaNs
+    z = cp.maximum(cp.cos(0.5 * np.pi * x / h), 1e-30)
+    y = (x < h) * cp.log(z)
+    y += (x >= h) * np.finfo(y.dtype).min
+    return y
+
+
+log_probability_kernels_ = {"gaussian": gaussian_log_kernel,
+                            "tophat": tophat_log_kernel,
+                            "epanechnikov": epanechnikov_log_kernel,
+                            "exponential": exponential_log_kernel,
+                            "linear": linear_log_kernel,
+                            "cosine": cosine_log_kernel}
+
+
+def logVn(n):
+    return 0.5 * n * np.log(np.pi) - math.lgamma(0.5 * n + 1)
+
+
+def logSn(n):
+    return np.log(2 * np.pi) + logVn(n - 1)
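+
+
+# logVn(n) above is the log-volume of the unit n-ball and logSn(n) the
+# log-surface-area of the unit n-sphere; e.g. logVn(2) == log(pi), the area
+# of the unit disc, and logSn(1) == log(2 * pi), the circumference of the
+# unit circle. They supply the normalisation constants used below.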
+
+
+def norm_log_probabilities(log_probabilities, kernel, h, d):
+    if kernel == "gaussian":
+        factor = 0.5 * d * np.log(2 * np.pi)
+    elif kernel == "tophat":
+        factor = logVn(d)
+    elif kernel == "epanechnikov":
+        factor = logVn(d) + np.log(2.0 / (d + 2.0))
+    elif kernel == "exponential":
+        factor = logSn(d - 1) + math.lgamma(d)
+    elif kernel == "linear":
+        factor = logVn(d) - np.log(d + 1.0)
+    elif kernel == "cosine":
+        factor = 0.0
+        tmp = 2.0 / np.pi
+        for k in range(1, d + 1, 2):
+            factor += tmp
+            tmp *= -(d - k) * (d - k - 1) * (2.0 / np.pi) ** 2
+        factor = np.log(factor) + logSn(d - 1)
+    else:
+        raise ValueError("Unsupported kernel.")
+
+    return log_probabilities - (factor + d * np.log(h))
+
+
+@cuda.jit()
+def logsumexp_kernel(distances, log_probabilities):
+    # Row-wise log-sum-exp: log(sum_j(exp(distances[i, j]))), shifted by the
+    # row maximum so that the exponentials cannot overflow.
+    i = cuda.grid(1)
+    if i >= log_probabilities.size:
+        return
+    max_exp = distances[i, 0]
+    for j in range(1, distances.shape[1]):
+        if distances[i, j] > max_exp:
+            max_exp = distances[i, j]
+    sum = 0.0
+    for j in range(0, distances.shape[1]):
+        sum += math.exp(distances[i, j] - max_exp)
+    log_probabilities[i] = math.log(sum) + max_exp
+
+
+class KernelDensity(Base):
+    """
+    Kernel Density Estimation. Computes a non-parametric density estimate
+    from a finite data sample, smoothing the estimate according to a
+    bandwidth parameter.
+
+    Parameters
+    ----------
+    bandwidth : float, default=1.0
+        The bandwidth of the kernel.
+    kernel : {'gaussian', 'tophat', 'epanechnikov', 'exponential', 'linear', \
+        'cosine'}, default='gaussian'
+        The kernel to use.
+    metric : str, default='euclidean'
+        The distance metric to use. Note that not all metrics are
+        valid with all algorithms. Note that the normalization of the density
+        output is correct only for the Euclidean distance metric. Default
+        is 'euclidean'.
+    metric_params : dict, default=None
+        Additional parameters to be passed through to the distance metric
+        (only a single extra metric argument is currently supported).
+    output_type : {'input', 'cudf', 'cupy', 'numpy', 'numba'}, default=None
+        Variable to control output type of the results and attributes of
+        the estimator. If None, it'll inherit the output type set at the
+        module level, `cuml.global_settings.output_type`.
+        See :ref:`output-data-type-configuration` for more info.
+    handle : cuml.Handle
+        Specifies the cuml.handle that holds internal CUDA state for
+        computations in this model. Most importantly, this specifies the
+        CUDA stream that will be used for the model's computations, so
+        users can run different models concurrently in different streams
+        by creating handles in several streams.
+        If it is None, a new one is created.
+    verbose : int or boolean, default=False
+        Sets logging level. It must be one of `cuml.common.logger.level_*`.
+        See :ref:`verbosity-levels` for more info.
+
+    Examples
+    --------
+
+    .. code-block:: python
+
+        from cuml.neighbors import KernelDensity
+        import numpy as np
+        rng = np.random.RandomState(42)
+        X = rng.random_sample((100, 3))
+        kde = KernelDensity(kernel='gaussian', bandwidth=0.5).fit(X)
+        log_density = kde.score_samples(X[:3])
+
+    """
+
+    def __init__(
+        self,
+        *,
+        bandwidth=1.0,
+        kernel="gaussian",
+        metric="euclidean",
+        metric_params=None,
+        output_type=None,
+        handle=None,
+        verbose=False
+    ):
+        super(KernelDensity, self).__init__(
+            verbose=verbose, handle=handle, output_type=output_type
+        )
+        self.bandwidth = bandwidth
+        self.kernel = kernel
+        self.metric = metric
+        self.metric_params = metric_params
+
+        if bandwidth <= 0:
+            raise ValueError("bandwidth must be positive")
+        if kernel not in VALID_KERNELS:
+            raise ValueError("invalid kernel: '{0}'".format(kernel))
+
+    def get_param_names(self):
+        return super().get_param_names() + [
+            "bandwidth",
+            "kernel",
+            "metric",
+            "metric_params",
+        ]
+
+    def fit(self, X, y=None, sample_weight=None):
+        """Fit the Kernel Density model on the data.
+
+        Parameters
+        ----------
+
+        X : array-like of shape (n_samples, n_features)
+            List of n_features-dimensional data points. Each row
+            corresponds to a single data point.
+        y : None
+            Ignored.
+        sample_weight : array-like of shape (n_samples,), default=None
+            List of sample weights attached to the data X.
+
+        Returns
+        -------
+
+        self : object
+            Returns the instance itself.
+        """
+        if sample_weight is not None:
+            self.sample_weight_ = input_to_cupy_array(
+                sample_weight, check_dtype=[cp.float32, cp.float64]
+            ).array
+            if self.sample_weight_.min() <= 0:
+                raise ValueError("sample_weight must have positive values")
+        else:
+            self.sample_weight_ = None
+
+        self.X_ = input_to_cupy_array(
+            X, order="C", check_dtype=[cp.float32, cp.float64]
+        ).array
+
+        return self
+
+    def score_samples(self, X):
+        """Compute the log-likelihood of each sample under the model.
+
+        Parameters
+        ----------
+
+        X : array-like of shape (n_samples, n_features)
+            An array of points to query. Last dimension should match dimension
+            of training data (n_features).
+
+        Returns
+        -------
+
+        density : ndarray of shape (n_samples,)
+            Log-likelihood of each sample in `X`. These are normalized to be
+            probability densities, so values will be low for high-dimensional
+            data.
+        """
+        if not hasattr(self, "X_"):
+            raise NotFittedError()
+        X_cuml = input_to_cuml_array(X)
+        if self.metric_params:
+            if len(self.metric_params) != 1:
+                raise ValueError(
+                    "cuML only supports metrics with a single argument.")
+            metric_arg = list(self.metric_params.values())[0]
+            distances = pairwise_distances(X_cuml.array, self.X_,
+                                           metric=self.metric,
+                                           metric_arg=metric_arg)
+        else:
+            distances = pairwise_distances(
+                X_cuml.array, self.X_, metric=self.metric)
+
+        distances = cp.asarray(distances)
+
+        h = self.bandwidth
+        if self.kernel in log_probability_kernels_:
+            distances = log_probability_kernels_[self.kernel](distances, h)
+        else:
+            raise ValueError("Unsupported kernel.")
+
+        log_probabilities = cp.zeros(distances.shape[0])
+        if self.sample_weight_ is not None:
+            distances += cp.log(self.sample_weight_)
+
+        logsumexp_kernel.forall(log_probabilities.size)(
+            distances, log_probabilities)
+        # Note that sklearn's user guide is wrong: it says the (unnormalised)
+        # probability output for the kernel density is sum(K(x, h)), but what
+        # it actually implements is (1/n) * sum(K(x, h)). Here we divide by n
+        # in normal probability space, which becomes -log(n) in log
+        # probability space.
+        sum_weights = (
+            cp.sum(self.sample_weight_)
+            if self.sample_weight_ is not None
+            else distances.shape[1]
+        )
+        log_probabilities -= np.log(sum_weights)
+
+        # norm
+        if len(X_cuml.array.shape) == 1:
+            # if X is one dimensional, we have 1 feature
+            dimension = 1
+        else:
+            dimension = X_cuml.array.shape[1]
+        log_probabilities = norm_log_probabilities(
+            log_probabilities, self.kernel, h, dimension
+        )
+
+        return log_probabilities
+
+    def score(self, X, y=None):
+        """Compute the total log-likelihood under the model.
+
+        Parameters
+        ----------
+
+        X : array-like of shape (n_samples, n_features)
+            List of n_features-dimensional data points. Each row
+            corresponds to a single data point.
+        y : None
+            Ignored.
+
+        Returns
+        -------
+
+        logprob : float
+            Total log-likelihood of the data in X. This is normalized to be a
+            probability density, so the value will be low for high-dimensional
+            data.
+        """
+        return cp.sum(self.score_samples(X))
+
+    def sample(self, n_samples=1, random_state=None):
+        """
+        Generate random samples from the model.
+        Currently, this is implemented only for gaussian and tophat kernels,
+        and the Euclidean metric.
+
+        Parameters
+        ----------
+        n_samples : int, default=1
+            Number of samples to generate.
+        random_state : int, cupy RandomState instance or None, default=None
+            Determines the random number generation used to draw samples.
+
+        Returns
+        -------
+        X : cupy array of shape (n_samples, n_features)
+            List of samples.
+        """
+        if not hasattr(self, "X_"):
+            raise NotFittedError()
+
+        supported_kernels = ["gaussian", "tophat"]
+        if (self.kernel not in supported_kernels
+                or self.metric != "euclidean"):
+            raise NotImplementedError(
+                "Only {} kernels, and the euclidean"
+                " metric are supported.".format(supported_kernels))
+
+        if isinstance(random_state, cp.random.RandomState):
+            rng = random_state
+        else:
+            rng = cp.random.RandomState(random_state)
+
+        u = rng.uniform(0, 1, size=n_samples)
+        if self.sample_weight_ is None:
+            i = (u * self.X_.shape[0]).astype(np.int64)
+        else:
+            # inverse-CDF sampling: mapping u through the cumulative weights
+            # picks row i with probability proportional to its weight
+            cumsum_weight = cp.cumsum(self.sample_weight_)
+            sum_weight = cumsum_weight[-1]
+            i = cp.searchsorted(cumsum_weight, u * sum_weight)
+        if self.kernel == "gaussian":
+            return cp.atleast_2d(rng.normal(self.X_[i], self.bandwidth))
+
+        elif self.kernel == "tophat":
+            # we first draw points from a d-dimensional normal distribution,
+            # then use an incomplete gamma function to map them to a uniform
+            # d-dimensional tophat distribution.
+            has_scipy(raise_if_unavailable=True)
+            dim = self.X_.shape[1]
+            X = rng.normal(size=(n_samples, dim))
+            s_sq = cp.einsum("ij,ij->i", X, X).get()
+
+            # do this on the CPU because we don't have
+            # a gammainc function readily available
+            correction = cp.array(
+                gammainc(0.5 * dim, 0.5 * s_sq) ** (1.0 / dim)
+                * self.bandwidth
+                / np.sqrt(s_sq)
+            )
+            return self.X_[i] + X * correction[:, np.newaxis]
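A short end-to-end sketch of the estimator added above. The data and parameters are illustrative; `sample` requires a gaussian or tophat kernel with the euclidean metric, per the check in the method:

.. code-block:: python

    import cupy as cp
    from cuml.neighbors import KernelDensity

    X = cp.random.rand(200, 3)
    w = cp.random.rand(200) + 0.1  # sample weights must be positive

    kde = KernelDensity(kernel="gaussian", bandwidth=0.5)
    kde.fit(X, sample_weight=w)
    log_p = kde.score_samples(X[:5])        # per-point log-densities
    total = kde.score(X)                    # sum of log-densities
    draws = kde.sample(10, random_state=0)  # new points from the density
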
diff --git a/python/cuml/test/test_kernel_density.py b/python/cuml/test/test_kernel_density.py
new file mode 100644
index 0000000000..75a2c82993
--- /dev/null
+++ b/python/cuml/test/test_kernel_density.py
@@ -0,0 +1,157 @@
+#
+# Copyright (c) 2022, NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from cuml.neighbors import KernelDensity, VALID_KERNELS, logsumexp_kernel
+from cuml.common.exceptions import NotFittedError
+from sklearn.metrics import pairwise_distances as skl_pairwise_distances
+from sklearn.neighbors._ball_tree import kernel_norm
+import numpy as np
+from hypothesis import given, settings, assume, strategies as st
+from hypothesis.extra.numpy import arrays
+import pytest
+from sklearn.model_selection import GridSearchCV
+from cuml.test.utils import as_type
+
+
+# not in log probability space
+def compute_kernel_naive(Y, X, kernel, metric, h, sample_weight):
+    d = skl_pairwise_distances(Y, X, metric)
+    norm = kernel_norm(h, X.shape[1], kernel)
+
+    if kernel == "gaussian":
+        k = np.exp(-0.5 * (d * d) / (h * h))
+    elif kernel == "tophat":
+        k = (d < h)
+    elif kernel == "epanechnikov":
+        k = ((1.0 - (d * d) / (h * h)) * (d < h))
+    elif kernel == "exponential":
+        k = (np.exp(-d / h))
+    elif kernel == "linear":
+        k = ((1 - d / h) * (d < h))
+    elif kernel == "cosine":
+        k = (np.cos(0.5 * np.pi * d / h) * (d < h))
+    else:
+        raise ValueError("kernel not recognized")
+    return norm * np.average(k, -1, sample_weight)
+
+
+@st.composite
+def array_strategy(draw):
+    n = draw(st.integers(1, 100))
+    m = draw(st.integers(1, 100))
+    dtype = draw(st.sampled_from([np.float64, np.float32]))
+    rng = np.random.RandomState(34)
+    X = rng.randn(n, m).astype(dtype)
+    n_test = draw(st.integers(1, 100))
+    X_test = rng.randn(n_test, m).astype(dtype)
+
+    if draw(st.booleans()):
+        sample_weight = None
+    else:
+        sample_weight = draw(arrays(dtype=np.float64, shape=n,
+                                    elements=st.floats(0.1, 2.0),))
+    type = draw(st.sampled_from(['numpy', 'cupy', 'cudf', 'pandas']))
+    if type == 'cupy':
+        assume(n > 1 and n_test > 1)
+    return as_type(type, X, X_test, sample_weight)
+
+
+metrics_strategy = st.sampled_from(
+    ['euclidean', 'manhattan',
+     'chebyshev', 'minkowski', 'hamming', 'canberra'])
+
+
+@settings(deadline=None)
+@given(array_strategy(), st.sampled_from(VALID_KERNELS),
+       metrics_strategy, st.floats(0.2, 10))
+def test_kernel_density(arrays, kernel, metric, bandwidth):
+    X, X_test, sample_weight = arrays
+    X_np, X_test_np, sample_weight_np = as_type("numpy", *arrays)
+
+    if kernel == 'cosine':
+        # cosine is numerically unstable at high dimensions
+        # for both cuml and sklearn
+        assume(X.shape[1] <= 20)
+    kde = KernelDensity(kernel=kernel, metric=metric,
+                        bandwidth=bandwidth).fit(X,
+                                                 sample_weight=sample_weight)
+    cuml_prob = kde.score_samples(X)
+    cuml_prob_test = kde.score_samples(X_test)
+
+    if X_np.dtype == np.float64:
+        ref_prob = compute_kernel_naive(
+            X_np, X_np, kernel, metric, bandwidth, sample_weight_np)
+        ref_prob_test = compute_kernel_naive(
+            X_test_np, X_np, kernel, metric, bandwidth, sample_weight_np)
+        tol = 1e-3
+        assert np.allclose(np.exp(as_type("numpy", cuml_prob)), ref_prob,
+                           rtol=tol, atol=tol, equal_nan=True)
+        assert np.allclose(np.exp(as_type("numpy", cuml_prob_test)),
+                           ref_prob_test, rtol=tol, atol=tol, equal_nan=True)
+
+    if kernel in ["gaussian", "tophat"] and metric == "euclidean":
+        sample = kde.sample(100, random_state=32).get()
+        nearest = skl_pairwise_distances(sample, X_np, metric=metric)
+        nearest = nearest.min(axis=1)
+        if kernel == "gaussian":
+            assert np.all(nearest < 5 * bandwidth)
+        elif kernel == "tophat":
+            assert np.all(nearest <= bandwidth)
+    else:
+        with pytest.raises(NotImplementedError,
+                           match=r"Only \['gaussian', 'tophat'\] kernels,"
+                           " and the euclidean metric are supported."):
+            kde.sample(100)
+
+
+def test_logaddexp():
+    X = np.array([[0.0, 0.0], [0.0, 0.0]])
+    out = np.zeros(X.shape[0])
+    logsumexp_kernel.forall(out.size)(X, out)
+    assert np.allclose(out, np.logaddexp.reduce(X, axis=1))
+
+    X = np.array([[3.0, 1.0], [0.2, 0.7]])
+    logsumexp_kernel.forall(out.size)(X, out)
+    assert np.allclose(out, np.logaddexp.reduce(X, axis=1))
+
+
+def test_metric_params():
+    X = np.array([[0.0, 1.0], [2.0, 0.5]])
+    kde = KernelDensity(metric='minkowski', metric_params={'p': 1.0}
+                        ).fit(X)
+    kde2 = KernelDensity(metric='minkowski', metric_params={'p': 2.0}
+                         ).fit(X)
+    assert not np.allclose(kde.score_samples(X), kde2.score_samples(X))
+
+
+def test_grid_search():
+    rs = np.random.RandomState(3)
+    X = rs.normal(size=(30, 5))
+    params = {"bandwidth": np.logspace(-1, 1, 20)}
+    grid = GridSearchCV(KernelDensity(), params)
+    grid.fit(X)
+
+
+def test_not_fitted():
+    rs = np.random.RandomState(3)
+    kde = KernelDensity()
+    X = rs.normal(size=(30, 5))
+    with pytest.raises(NotFittedError):
+        kde.score(X)
+    with pytest.raises(NotFittedError):
+        kde.sample()
+    with pytest.raises(NotFittedError):
+        kde.score_samples(X)
diff --git a/python/cuml/test/test_kernel_ridge.py b/python/cuml/test/test_kernel_ridge.py
index b9d08bac55..ae24e04a96 100644
--- a/python/cuml/test/test_kernel_ridge.py
+++ b/python/cuml/test/test_kernel_ridge.py
@@ -25,9 +25,10 @@ from sklearn.kernel_ridge import KernelRidge as sklKernelRidge
 from hypothesis import given, settings, assume, strategies as st
 from hypothesis.extra.numpy import arrays
+from cuml.test.utils import as_type
 
 
-def gradient_norm(X, y, model, K, sw=None):
+def gradient_norm(model, X, y, K, sw=None):
     if sw is None:
         sw = cp.ones(X.shape[0])
     else:
@@ -36,7 +37,8 @@ def gradient_norm(X, y, model, K, sw=None):
     X = cp.array(X, dtype=np.float64)
     y = cp.array(y, dtype=np.float64)
     K = cp.array(K, dtype=np.float64)
-    betas = cp.array(model.dual_coef_, dtype=np.float64).reshape(y.shape)
+    betas = cp.array(as_type('cupy', model.dual_coef_),
+                     dtype=np.float64).reshape(y.shape)
 
     # initialise to NaN in case below loop has 0 iterations
     grads = cp.full_like(y, np.NAN)
@@ -162,7 +164,13 @@ def array_strategy(draw):
         )
     else:
         Y = None
-    return (X, Y)
+    type = draw(st.sampled_from(['numpy', 'cupy', 'cudf', 'pandas']))
+
+    if type == 'cudf':
+        assume(X_m > 1)
+        if Y is not None:
+            assume(Y_m > 1)
+    return as_type(type, X, Y)
 
 
 @given(kernel_arg_strategy(), array_strategy())
@@ -172,8 +180,9 @@ def test_pairwise_kernels(kernel_arg, XY):
     kernel, args = kernel_arg
     K = pairwise_kernels(X, Y, metric=kernel, **args)
     skl_kernel = kernel.py_func if hasattr(kernel, "py_func") else kernel
-    K_sklearn = skl_pairwise_kernels(X, Y, metric=skl_kernel, **args)
-    assert np.allclose(K, K_sklearn, atol=0.01, rtol=0.01)
+    K_sklearn = skl_pairwise_kernels(
+        *as_type('numpy', X, Y), metric=skl_kernel, **args)
+    assert np.allclose(as_type('numpy', K), K_sklearn, atol=0.01, rtol=0.01)
 
 
 @st.composite
@@ -207,7 +216,8 @@ def estimator_array_strategy(draw):
             ]
         )
     )
-    return (X, y, X_test, alpha, sample_weight)
+    type = draw(st.sampled_from(['numpy', 'cupy', 'cudf', 'pandas']))
+    return (*as_type(type, X, y, X_test, alpha, sample_weight), dtype)
 
 
 @given(
@@ -220,7 +230,7 @@ def estimator_array_strategy(draw):
 @settings(deadline=None)
 def test_estimator(kernel_arg, arrays, gamma, degree, coef0):
     kernel, args = kernel_arg
-    X, y, X_test, alpha, sample_weight = arrays
+    X, y, X_test, alpha, sample_weight, dtype = arrays
     model = cuKernelRidge(
         kernel=kernel,
@@ -232,7 +242,7 @@ def test_estimator(kernel_arg, arrays, gamma, degree, coef0):
     skl_kernel = kernel.py_func if hasattr(kernel, "py_func") else kernel
     skl_model = sklKernelRidge(
         kernel=skl_kernel,
-        alpha=alpha,
+        alpha=as_type('numpy', alpha),
         gamma=gamma,
         degree=degree,
         coef0=coef0,
     )
@@ -240,25 +250,27 @@ def test_estimator(kernel_arg, arrays, gamma, degree, coef0):
     if kernel == "chi2" or kernel == "additive_chi2":
         # X must be positive
-        X = X + abs(X.min()) + 1.0
+        X = (X - as_type('numpy', X).min()) + 1.0
 
     model.fit(X, y, sample_weight)
     pred = model.predict(X_test).get()
-    if X.dtype == np.float64:
+    if dtype == np.float64:
         # For a convex optimisation problem we should arrive at gradient norm 0
         # If the solution has converged correctly
         K = model._get_kernel(X)
-        grad_norm = gradient_norm(X, y, model, K, sample_weight)
+        grad_norm = gradient_norm(
+            model, *as_type('cupy', X, y, K, sample_weight))
         assert grad_norm < 0.1
 
     try:
-        skl_model.fit(X, y, sample_weight)
+        skl_model.fit(*as_type('numpy', X, y, sample_weight))
     except np.linalg.LinAlgError:
         # sklearn can fail to fit multiclass models
         # with singular kernel matrices
         assume(False)
-    skl_pred = skl_model.predict(X_test)
-    assert np.allclose(pred, skl_pred, atol=1e-2, rtol=1e-2)
+    skl_pred = skl_model.predict(as_type('numpy', X_test))
+    assert np.allclose(as_type('numpy', pred),
+                       skl_pred, atol=1e-2, rtol=1e-2)
 
 
 def test_precomputed():
diff --git a/python/cuml/test/test_pickle.py b/python/cuml/test/test_pickle.py
index 19d49e2d2f..fd24be3457 100644
--- a/python/cuml/test/test_pickle.py
+++ b/python/cuml/test/test_pickle.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2019-2021, NVIDIA CORPORATION.
+# Copyright (c) 2019-2022, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -53,7 +53,8 @@
 decomposition_config_xfail = ClassEnumerator(module=cuml.random_projection)
 decomposition_models_xfail = decomposition_config_xfail.get_models()
 
-neighbor_config = ClassEnumerator(module=cuml.neighbors)
+neighbor_config = ClassEnumerator(module=cuml.neighbors, exclude_classes=[
+    cuml.neighbors.KernelDensity])
 neighbor_models = neighbor_config.get_models()
 
 dbscan_model = {"DBSCAN": cuml.DBSCAN}
@@ -68,7 +69,7 @@
 rf_models = rf_module.get_models()
 
 k_neighbors_config = ClassEnumerator(module=cuml.neighbors, exclude_classes=[
-    cuml.neighbors.NearestNeighbors])
+    cuml.neighbors.NearestNeighbors, cuml.neighbors.KernelDensity])
 k_neighbors_models = k_neighbors_config.get_models()
 
 unfit_pickle_xfail = [
diff --git a/python/cuml/test/utils.py b/python/cuml/test/utils.py
index bf1ce21701..8d53ecc3d5 100644
--- a/python/cuml/test/utils.py
+++ b/python/cuml/test/utils.py
@@ -30,6 +30,7 @@
 import cudf
 import cuml
+from cuml.common.input_utils import input_to_cuml_array
 import pytest
 
 
@@ -110,6 +111,28 @@ def normalize_clusters(a0, b0, n_clusters):
     return a, b
 
 
+def as_type(type, *args):
+    # Convert array args to type supported by
+    # CumlArray.to_output ('numpy','cudf','cupy'...)
+    # Ensure 2 dimensional inputs are not converted to 1 dimension
+    # None remains as None
+    # Scalar remains a scalar
+    result = []
+    for arg in args:
+        if arg is None or np.isscalar(arg):
+            result.append(arg)
+        else:
+            # make sure X with a single feature remains 2 dimensional
+            if type == 'cudf' and len(arg.shape) > 1:
+                result.append(input_to_cuml_array(
+                    arg).array.to_output('dataframe'))
+            else:
+                result.append(input_to_cuml_array(arg).array.to_output(type))
+    if len(result) == 1:
+        return result[0]
+    return tuple(result)
+
+
 def to_nparray(x):
     if isinstance(x, Number):
         return np.asarray([x])
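For reference, a quick sketch of how the `as_type` helper above behaves in the tests (shapes and types are illustrative):

.. code-block:: python

    import numpy as np
    from cuml.test.utils import as_type

    X = np.random.rand(10, 2)
    y = np.random.rand(10)

    X_cudf, y_cudf = as_type('cudf', X, y)             # DataFrame and Series
    X_back, y_back = as_type('numpy', X_cudf, y_cudf)  # round-trip to numpy
    assert X_back.shape == (10, 2)                     # 2D input stays 2D
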