diff --git a/python/cuml/common/array_descriptor.py b/python/cuml/common/array_descriptor.py index ba10578c96..f3737a4d0e 100644 --- a/python/cuml/common/array_descriptor.py +++ b/python/cuml/common/array_descriptor.py @@ -54,8 +54,14 @@ class CumlArrayDescriptor(): Python descriptor object to control getting/setting `CumlArray` attributes on `Base` objects. See the Estimator Guide for an in depth guide. """ + def __init__(self, order='K'): + # order corresponds to the order that the CumlArray attribute + # should be in to work with the C++ algorithms. + self.order = order + def __set_name__(self, owner, name): self.name = name + setattr(owner, name + '_order', self.order) def _get_meta(self, instance, diff --git a/python/cuml/common/input_utils.py b/python/cuml/common/input_utils.py index d450d3eb5b..d46d5a5903 100644 --- a/python/cuml/common/input_utils.py +++ b/python/cuml/common/input_utils.py @@ -535,7 +535,7 @@ def input_to_host_array(X, if isinstance(X, (int, float, complex, bool, str, type(None), dict, set, list, tuple)): - return X + return (X,) if isinstance(X, np.ndarray): if len(X.shape) > 1: diff --git a/python/cuml/common/mixins.py b/python/cuml/common/mixins.py index a12dfebb4e..a7ec72c841 100644 --- a/python/cuml/common/mixins.py +++ b/python/cuml/common/mixins.py @@ -1,5 +1,5 @@ # -# Copyright (c) 2021, NVIDIA CORPORATION. +# Copyright (c) 2021-2022, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -214,6 +214,33 @@ def score(self, X, y, **kwargs): preds = self.predict(X, **kwargs) return r2_score(y, preds, handle=handle) + # TODO : remove score function duplicate + # once updated CPU/GPU interoperability class is ready + @generate_docstring( + return_values={ + 'name': 'score', + 'type': 'float', + 'description': 'R^2 of self.predict(X) ' + 'wrt. y.' 
+ }) + @cuml.internals.api_base_return_any_skipall + def _score(self, X, y, **kwargs): + """ + Scoring function for regression estimators + + Returns the coefficient of determination R^2 of the prediction. + + """ + from cuml.metrics.regression import r2_score + + if hasattr(self, 'handle'): + handle = self.handle + else: + handle = None + + preds = self._predict(X, **kwargs) + return r2_score(y, preds, handle=handle) + @staticmethod def _more_static_tags(): return { @@ -253,6 +280,33 @@ def score(self, X, y, **kwargs): preds = self.predict(X, **kwargs) return accuracy_score(y, preds, handle=handle) + # TODO : remove score function duplicate + # once updated CPU/GPU interoperability class is ready + @generate_docstring( + return_values={ + 'name': + 'score', + 'type': + 'float', + 'description': ('Accuracy of self.predict(X) wrt. y ' + '(fraction where y == pred_y)') + }) + @cuml.internals.api_base_return_any_skipall + def _score(self, X, y, **kwargs): + """ + Scoring function for classifier estimators based on mean accuracy. 
+ + """ + from cuml.metrics.accuracy import accuracy_score + + if hasattr(self, 'handle'): + handle = self.handle + else: + handle = None + + preds = self._predict(X, **kwargs) + return accuracy_score(y, preds, handle=handle) + @staticmethod def _more_static_tags(): return { diff --git a/python/cuml/dask/common/base.py b/python/cuml/dask/common/base.py index c0d727ff16..3b58639283 100644 --- a/python/cuml/dask/common/base.py +++ b/python/cuml/dask/common/base.py @@ -23,7 +23,6 @@ from cuml.dask.common.utils import get_client from cuml.common.base import Base -from cuml.experimental.common.base import Base as experimentalBase from cuml.common.array import CumlArray from cuml.dask.common.utils import wait_and_raise_from_futures from raft_dask.common.comms import Comms @@ -129,7 +128,7 @@ def _check_internal_model(model): if model.type is None: wait_and_raise_from_futures([model]) - if not issubclass(model.type, (Base, experimentalBase)): + if not issubclass(model.type, Base): raise ValueError("Dask Future expected to contain cuml.Base " "but found %s instead." 
% model.type) diff --git a/python/cuml/decomposition/base_mg.pyx b/python/cuml/decomposition/base_mg.pyx index 3410d96668..9dea34c63f 100644 --- a/python/cuml/decomposition/base_mg.pyx +++ b/python/cuml/decomposition/base_mg.pyx @@ -67,9 +67,9 @@ class BaseDecompositionMG(object): self._set_n_features_in(n_cols) if self.n_components is None: - self._n_components = min(total_rows, n_cols) + self.n_components_ = min(total_rows, n_cols) else: - self._n_components = self.n_components + self.n_components_ = self.n_components X_arys = [] for i in range(len(X)): @@ -102,11 +102,11 @@ class BaseDecompositionMG(object): trans_arg = opg.build_data_t(trans_arys) trans_part_desc = opg.build_part_descriptor(total_rows, - self._n_components, + self.n_components_, rank_to_sizes, rank) - self._initialize_arrays(self._n_components, total_rows, n_cols) + self._initialize_arrays(self.n_components_, total_rows, n_cols) decomp_params = self._build_params(total_rows, n_cols) if _transform: diff --git a/python/cuml/decomposition/incremental_pca.py b/python/cuml/decomposition/incremental_pca.py index d4cd251e3a..f76babc4d0 100644 --- a/python/cuml/decomposition/incremental_pca.py +++ b/python/cuml/decomposition/incremental_pca.py @@ -349,7 +349,7 @@ def partial_fit(self, X, y=None, check_input=True) -> "IncrementalPCA": explained_variance = S ** 2 / (n_total_samples - 1) explained_variance_ratio = S ** 2 / cp.sum(col_var * n_total_samples) - self.n_rows = n_total_samples + self.n_samples_ = n_total_samples self.n_samples_seen_ = n_total_samples self.components_ = V[:self.n_components_] self.singular_values_ = S[:self.n_components_] diff --git a/python/cuml/decomposition/pca.pyx b/python/cuml/decomposition/pca.pyx index 6dbd13400d..c67ccc1bb0 100644 --- a/python/cuml/decomposition/pca.pyx +++ b/python/cuml/decomposition/pca.pyx @@ -33,7 +33,9 @@ from cython.operator cimport dereference as deref import cuml.internals from cuml.common.array import CumlArray -from cuml.common.base import Base 
+from cuml.experimental.common.base import Base +from cuml.common.mixins import FMajorInputTagMixin, \ + SparseInputTagMixin from cuml.common.doc_utils import generate_docstring from pylibraft.common.handle cimport handle_t from pylibraft.common.handle import Handle @@ -46,8 +48,7 @@ from cuml.common import using_output_type from cuml.prims.stats import cov from cuml.common.input_utils import sparse_scipy_to_cp from cuml.common.exceptions import NotFittedError -from cuml.common.mixins import FMajorInputTagMixin -from cuml.common.mixins import SparseInputTagMixin +from cuml.internals.api_decorators import device_interop_preparation cdef extern from "cuml/decomposition/pca.hpp" namespace "ML": @@ -270,14 +271,16 @@ class PCA(Base, `_. """ - components_ = CumlArrayDescriptor() - explained_variance_ = CumlArrayDescriptor() - explained_variance_ratio_ = CumlArrayDescriptor() - singular_values_ = CumlArrayDescriptor() - mean_ = CumlArrayDescriptor() - noise_variance_ = CumlArrayDescriptor() - trans_input_ = CumlArrayDescriptor() + _cpu_estimator_import_path = 'sklearn.decomposition.PCA' + components_ = CumlArrayDescriptor(order='F') + explained_variance_ = CumlArrayDescriptor(order='F') + explained_variance_ratio_ = CumlArrayDescriptor(order='F') + singular_values_ = CumlArrayDescriptor(order='F') + mean_ = CumlArrayDescriptor(order='F') + noise_variance_ = CumlArrayDescriptor(order='F') + trans_input_ = CumlArrayDescriptor(order='F') + @device_interop_preparation def __init__(self, *, copy=True, handle=None, iterated_power=15, n_components=None, random_state=None, svd_solver='auto', tol=1e-7, verbose=False, whiten=False, @@ -325,7 +328,7 @@ class PCA(Base, def _build_params(self, n_rows, n_cols): cpdef paramsPCA *params = new paramsPCA() - params.n_components = self._n_components + params.n_components = self.n_components_ params.n_rows = n_rows params.n_cols = n_cols params.whiten = self.whiten @@ -354,8 +357,9 @@ class PCA(Base, self._sparse_model = True - self.n_rows 
= X.shape[0] - self.n_cols = X.shape[1] + self.n_samples_ = X.shape[0] + self.n_features_ = X.shape[1] if X.ndim == 2 else 1 + self.n_features_in_ = self.n_features_ self.dtype = X.dtype # NOTE: All intermediate calculations are done using cupy.ndarray and @@ -374,34 +378,34 @@ class PCA(Base, self.components_ = cp.flip(self.components_, axis=1) - self.components_ = self.components_.T[:self._n_components, :] + self.components_ = self.components_.T[:self.n_components_, :] self.explained_variance_ratio_ = self.explained_variance_ / cp.sum( self.explained_variance_) - if self._n_components < min(self.n_rows, self.n_cols): + if self.n_components_ < min(self.n_samples_, self.n_features_): self.noise_variance_ = \ - self.explained_variance_[self._n_components:].mean() + self.explained_variance_[self.n_components_:].mean() else: self.noise_variance_ = cp.array([0.0]) self.explained_variance_ = \ - self.explained_variance_[:self._n_components] + self.explained_variance_[:self.n_components_] self.explained_variance_ratio_ = \ - self.explained_variance_ratio_[:self._n_components] + self.explained_variance_ratio_[:self.n_components_] # Truncating negative explained variance values to 0 self.singular_values_ = \ cp.where(self.explained_variance_ < 0, 0, self.explained_variance_) self.singular_values_ = \ - cp.sqrt(self.singular_values_ * (self.n_rows - 1)) + cp.sqrt(self.singular_values_ * (self.n_samples_ - 1)) return self @generate_docstring(X='dense_sparse') - def fit(self, X, y=None) -> "PCA": + def _fit(self, X, y=None) -> "PCA": """ Fit the model with X. y is currently ignored. 
@@ -414,9 +418,9 @@ class PCA(Base, ) n_rows = X.shape[0] n_cols = X.shape[1] - self._n_components = min(n_rows, n_cols) + self.n_components_ = min(n_rows, n_cols) else: - self._n_components = self.n_components + self.n_components_ = self.n_components if cupyx.scipy.sparse.issparse(X): return self._sparse_fit(X) @@ -424,14 +428,16 @@ class PCA(Base, X = sparse_scipy_to_cp(X, dtype=None) return self._sparse_fit(X) - X_m, self.n_rows, self.n_cols, self.dtype = \ + X_m, self.n_samples_, self.n_features_, self.dtype = \ input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) cdef uintptr_t input_ptr = X_m.ptr + self.n_features_in_ = self.n_features_ + self.feature_names_in_ = X_m.index cdef paramsPCA *params = \ - self._build_params(self.n_rows, self.n_cols) + self._build_params(self.n_samples_, self.n_features_) - if params.n_components > self.n_cols: + if params.n_components > self.n_features_: raise ValueError('Number of components should not be greater than' 'the number of columns in the data') @@ -489,7 +495,7 @@ class PCA(Base, 'description': 'Transformed values', 'shape': '(n_samples, n_components)'}) @cuml.internals.api_base_return_array_skipall - def fit_transform(self, X, y=None) -> CumlArray: + def _fit_transform(self, X, y=None) -> CumlArray: """ Fit the model with X and apply the dimensionality reduction on X. 
@@ -507,7 +513,8 @@ class PCA(Base, if self.whiten: cp.multiply(self.components_, - (1 / cp.sqrt(self.n_rows - 1)), out=self.components_) + (1 / cp.sqrt(self.n_samples_ - 1)), + out=self.components_) cp.multiply(self.components_, self.singular_values_.reshape((-1, 1)), out=self.components_) @@ -517,7 +524,7 @@ class PCA(Base, if self.whiten: self.components_ /= self.singular_values_.reshape((-1, 1)) - self.components_ *= cp.sqrt(self.n_rows - 1) + self.components_ *= cp.sqrt(self.n_samples_ - 1) if return_sparse: X_inv = cp.where(X_inv < sparse_tol, 0, X_inv) @@ -533,16 +540,17 @@ class PCA(Base, 'type': 'dense_sparse', 'description': 'Transformed values', 'shape': '(n_samples, n_features)'}) - def inverse_transform(self, X, convert_dtype=False, - return_sparse=False, sparse_tol=1e-10) -> CumlArray: + def _inverse_transform(self, X, convert_dtype=False, + return_sparse=False, sparse_tol=1e-10) -> CumlArray: """ Transform data back to its original space. In other words, return an input X_original whose transform would be X. 
""" - self._check_is_fitted('components_') + dtype = self.components_.dtype + if cupyx.scipy.sparse.issparse(X): return self._sparse_inverse_transform(X, return_sparse=return_sparse, @@ -561,8 +569,8 @@ class PCA(Base, sparse_tol=sparse_tol) X_m, n_rows, _, dtype = \ - input_to_cuml_array(X, check_dtype=self.dtype, - convert_to_dtype=(self.dtype if convert_dtype + input_to_cuml_array(X, check_dtype=dtype, + convert_to_dtype=(dtype if convert_dtype else None) ) @@ -570,9 +578,9 @@ class PCA(Base, # todo: check n_cols and dtype cpdef paramsPCA params - params.n_components = self._n_components + params.n_components = self.n_components_ params.n_rows = n_rows - params.n_cols = self.n_cols + params.n_cols = self.n_features_ params.whiten = self.whiten input_data = CumlArray.zeros((params.n_rows, params.n_cols), @@ -616,7 +624,7 @@ class PCA(Base, with using_output_type("cupy"): if self.whiten: - self.components_ *= cp.sqrt(self.n_rows - 1) + self.components_ *= cp.sqrt(self.n_samples_ - 1) self.components_ /= self.singular_values_.reshape((-1, 1)) X = X - self.mean_ @@ -624,7 +632,7 @@ class PCA(Base, if self.whiten: self.components_ *= self.singular_values_.reshape((-1, 1)) - self.components_ *= (1 / cp.sqrt(self.n_rows - 1)) + self.components_ *= (1 / cp.sqrt(self.n_samples_ - 1)) return X_transformed @@ -633,7 +641,7 @@ class PCA(Base, 'type': 'dense_sparse', 'description': 'Transformed values', 'shape': '(n_samples, n_components)'}) - def transform(self, X, convert_dtype=False) -> CumlArray: + def _transform(self, X, convert_dtype=False) -> CumlArray: """ Apply dimensionality reduction to X. @@ -641,8 +649,9 @@ class PCA(Base, from a training set. 
""" - self._check_is_fitted('components_') + dtype = self.components_.dtype + if cupyx.scipy.sparse.issparse(X): return self._sparse_transform(X) elif scipy.sparse.issparse(X): @@ -655,16 +664,16 @@ class PCA(Base, return self._sparse_transform(X) X_m, n_rows, n_cols, dtype = \ - input_to_cuml_array(X, check_dtype=self.dtype, - convert_to_dtype=(self.dtype if convert_dtype + input_to_cuml_array(X, check_dtype=dtype, + convert_to_dtype=(dtype if convert_dtype else None), - check_cols=self.n_cols) + check_cols=self.n_features_) cdef uintptr_t input_ptr = X_m.ptr # todo: check dtype cpdef paramsPCA params - params.n_components = self._n_components + params.n_components = self.n_components_ params.n_rows = n_rows params.n_cols = n_cols params.whiten = self.whiten @@ -713,3 +722,10 @@ class PCA(Base, msg = ("This instance is not fitted yet. Call 'fit' " "with appropriate arguments before using this estimator.") raise NotFittedError(msg) + + def get_attr_names(self): + return ['components_', 'explained_variance_', + 'explained_variance_ratio_', 'singular_values_', + 'mean_', 'n_components_', 'noise_variance_', + 'n_samples_', 'n_features_', 'n_features_in_', + 'feature_names_in_'] diff --git a/python/cuml/decomposition/pca_mg.pyx b/python/cuml/decomposition/pca_mg.pyx index 111cd77d3c..148f555875 100644 --- a/python/cuml/decomposition/pca_mg.pyx +++ b/python/cuml/decomposition/pca_mg.pyx @@ -90,13 +90,14 @@ class PCAMG(BaseDecompositionMG, PCA): def _build_params(self, n_rows, n_cols): cpdef paramsPCAMG *params = new paramsPCAMG() - params.n_components = self._n_components + params.n_components = self.n_components_ params.n_rows = n_rows params.n_cols = n_cols params.whiten = self.whiten params.tol = self.tol params.algorithm = ( ( self.c_algorithm)) + self.n_features_ = n_cols return params diff --git a/python/cuml/decomposition/tsvd.pyx b/python/cuml/decomposition/tsvd.pyx index 6831c4c28e..c7da0d6cc4 100644 --- a/python/cuml/decomposition/tsvd.pyx +++ 
b/python/cuml/decomposition/tsvd.pyx @@ -27,13 +27,14 @@ from libc.stdint cimport uintptr_t from cuml.common.array import CumlArray -from cuml.common.base import Base +from cuml.experimental.common.base import Base +from cuml.common.mixins import FMajorInputTagMixin from cuml.common.doc_utils import generate_docstring from pylibraft.common.handle cimport handle_t from cuml.decomposition.utils cimport * from cuml.common import input_to_cuml_array from cuml.common.array_descriptor import CumlArrayDescriptor -from cuml.common.mixins import FMajorInputTagMixin +from cuml.internals.api_decorators import device_interop_preparation from cython.operator cimport dereference as deref @@ -233,11 +234,13 @@ class TruncatedSVD(Base, """ - components_ = CumlArrayDescriptor() - explained_variance_ = CumlArrayDescriptor() - explained_variance_ratio_ = CumlArrayDescriptor() - singular_values_ = CumlArrayDescriptor() + _cpu_estimator_import_path = 'sklearn.decomposition.TruncatedSVD' + components_ = CumlArrayDescriptor(order='F') + explained_variance_ = CumlArrayDescriptor(order='F') + explained_variance_ratio_ = CumlArrayDescriptor(order='F') + singular_values_ = CumlArrayDescriptor(order='F') + @device_interop_preparation def __init__(self, *, algorithm='full', handle=None, n_components=1, n_iter=15, random_state=None, tol=1e-7, verbose=False, output_type=None): @@ -295,7 +298,7 @@ class TruncatedSVD(Base, dtype=self.dtype) @generate_docstring() - def fit(self, X, y=None) -> "TruncatedSVD": + def _fit(self, X, y=None) -> "TruncatedSVD": """ Fit LSI model on training cudf DataFrame X. y is currently ignored. @@ -309,20 +312,21 @@ class TruncatedSVD(Base, 'type': 'dense', 'description': 'Reduced version of X', 'shape': '(n_samples, n_components)'}) - def fit_transform(self, X, y=None) -> CumlArray: + def _fit_transform(self, X, y=None) -> CumlArray: """ Fit LSI model to X and perform dimensionality reduction on X. y is currently ignored. 
""" - X_m, self.n_rows, self.n_cols, self.dtype = \ + X_m, self.n_rows, self.n_features_in_, self.dtype = \ input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) cdef uintptr_t input_ptr = X_m.ptr cdef paramsTSVD *params = \ - self._build_params(self.n_rows, self.n_cols) + self._build_params(self.n_rows, self.n_features_in_) - self._initialize_arrays(self.n_components, self.n_rows, self.n_cols) + self._initialize_arrays(self.n_components, self.n_rows, + self.n_features_in_) cdef uintptr_t comp_ptr = self.components_.ptr @@ -339,7 +343,7 @@ class TruncatedSVD(Base, dtype=self.dtype, index=X_m.index) cdef uintptr_t t_input_ptr = _trans_input_.ptr - if self.n_components> self.n_cols: + if self.n_components> self.n_features_in_: raise ValueError(' n_components must be < n_features') cdef handle_t* handle_ = self.handle.getHandle() @@ -372,25 +376,25 @@ class TruncatedSVD(Base, 'type': 'dense', 'description': 'X in original space', 'shape': '(n_samples, n_features)'}) - def inverse_transform(self, X, convert_dtype=False) -> CumlArray: + def _inverse_transform(self, X, convert_dtype=False) -> CumlArray: """ Transform X back to its original space. Returns X_original whose transform would be X. 
""" - + dtype = self.components_.dtype X_m, n_rows, _, dtype = \ - input_to_cuml_array(X, check_dtype=self.dtype, - convert_to_dtype=(self.dtype if convert_dtype + input_to_cuml_array(X, check_dtype=dtype, + convert_to_dtype=(dtype if convert_dtype else None)) cpdef paramsTSVD params params.n_components = self.n_components params.n_rows = n_rows - params.n_cols = self.n_cols + params.n_cols = self.n_features_in_ input_data = CumlArray.zeros((params.n_rows, params.n_cols), - dtype=self.dtype, index=X_m.index) + dtype=dtype, index=X_m.index) cdef uintptr_t trans_input_ptr = X_m.ptr cdef uintptr_t input_ptr = input_data.ptr @@ -421,25 +425,28 @@ class TruncatedSVD(Base, 'type': 'dense', 'description': 'Reduced version of X', 'shape': '(n_samples, n_components)'}) - def transform(self, X, convert_dtype=False) -> CumlArray: + def _transform(self, X, convert_dtype=False) -> CumlArray: """ Perform dimensionality reduction on X. """ + dtype = self.components_.dtype + self.n_features_in_ = self.components_.shape[1] + X_m, n_rows, _, dtype = \ - input_to_cuml_array(X, check_dtype=self.dtype, - convert_to_dtype=(self.dtype if convert_dtype + input_to_cuml_array(X, check_dtype=dtype, + convert_to_dtype=(dtype if convert_dtype else None), - check_cols=self.n_cols) + check_cols=self.n_features_in_) cpdef paramsTSVD params params.n_components = self.n_components params.n_rows = n_rows - params.n_cols = self.n_cols + params.n_cols = self.n_features_in_ t_input_data = \ CumlArray.zeros((params.n_rows, params.n_components), - dtype=self.dtype, index=X_m.index) + dtype=dtype, index=X_m.index) cdef uintptr_t input_ptr = X_m.ptr cdef uintptr_t trans_input_ptr = t_input_data.ptr @@ -469,3 +476,8 @@ class TruncatedSVD(Base, def get_param_names(self): return super().get_param_names() + \ ["algorithm", "n_components", "n_iter", "random_state", "tol"] + + def get_attr_names(self): + return ['components_', 'explained_variance_', + 'explained_variance_ratio_', 'singular_values_', + 
'n_features_in_', 'feature_names_in_'] diff --git a/python/cuml/decomposition/tsvd_mg.pyx b/python/cuml/decomposition/tsvd_mg.pyx index 45cf8dbd14..f19d2bb16c 100644 --- a/python/cuml/decomposition/tsvd_mg.pyx +++ b/python/cuml/decomposition/tsvd_mg.pyx @@ -73,7 +73,7 @@ class TSVDMG(BaseDecompositionMG, TruncatedSVD): def _build_params(self, n_rows, n_cols): cpdef paramsTSVDMG *params = new paramsTSVDMG() - params.n_components = self._n_components + params.n_components = self.n_components_ params.n_rows = n_rows params.n_cols = n_cols params.n_iterations = self.n_iter diff --git a/python/cuml/experimental/common/__init__.py b/python/cuml/experimental/common/__init__.py index 58f67b199b..2a4d63655e 100644 --- a/python/cuml/experimental/common/__init__.py +++ b/python/cuml/experimental/common/__init__.py @@ -14,5 +14,4 @@ # limitations under the License. # -from cuml.experimental.common.base import UniversalBase -from cuml.experimental.common.base import enable_cpu +from cuml.experimental.common.base import Base diff --git a/python/cuml/experimental/common/base.pyx b/python/cuml/experimental/common/base.pyx index d5d24226ae..d8266b8fea 100644 --- a/python/cuml/experimental/common/base.pyx +++ b/python/cuml/experimental/common/base.pyx @@ -16,31 +16,25 @@ # distutils: language = c++ -import functools -from importlib import import_module -import numpy as np +import typing +import numpy as np +import cupy as cp import cuml -from cuml.common.device_selection import DeviceType +import cuml.common.logger as logger from cuml.common.input_utils import input_to_cuml_array from cuml.common.input_utils import input_to_host_array -from cuml.common.base import Base - - -def enable_cpu(gpu_func): - @functools.wraps(gpu_func) - def dispatch(self, *args, **kwargs): - func_name = gpu_func.func_name - return self.dispatch_func(func_name, gpu_func, *args, **kwargs) - return dispatch +from cuml.common.array import CumlArray +from cuml.common.device_selection import DeviceType +from 
cuml.common.base import Base as originalBase -class UniversalBase(Base): +class Base(originalBase): """ Experimental base class to implement CPU/GPU interoperability. """ - def dispatch_func(self, func_name, gpu_func, *args, **kwargs): + def dispatch_func(self, func_name, *args, **kwargs): """ This function will dispatch calls to training and inference according to the global configuration. It should work for all estimators @@ -51,8 +45,6 @@ class UniversalBase(Base): ---------- func_name : string name of the function to be dispatched - gpu_func : function - original cuML function args : arguments arguments to be passed to the function for the call kwargs : keyword arguments @@ -62,35 +54,41 @@ class UniversalBase(Base): device_type = cuml.global_settings.device_type if device_type == DeviceType.device: # call the original cuml method - return gpu_func(self, *args, **kwargs) + cuml_func_name = '_' + func_name + if hasattr(self, cuml_func_name): + cuml_func = getattr(self, cuml_func_name) + return cuml_func(*args, **kwargs) + else: + raise ValueError('Function "{}" could not be found in' + ' the cuML estimator'.format(cuml_func_name)) + elif device_type == DeviceType.host: # check if the sklean model already set as attribute of the cuml # estimator its presence should signify that CPU execution was # used previously - if not hasattr(self, 'sk_model_'): - # import model in sklearn - if hasattr(self, 'sk_import_path_'): - # if import path differs from the one of sklearn - # look for sk_import_path_ - model_path = self.sk_import_path_ - else: - # import from similar path to the current estimator - # class - model_path = 'sklearn' + self.__class__.__module__[4:] - model_name = self.__class__.__name__ - sk_model = getattr(import_module(model_path), model_name) + if not hasattr(self, '_cpu_model'): + filtered_kwargs = {} + for keyword, arg in self._full_kwargs.items(): + if keyword in self._cpu_hyperparams: + filtered_kwargs[keyword] = arg + else: + logger.info("Unused 
keyword parameter: {} " + "during CPU estimator " + "initialization".format(keyword)) + # initialize model - self.sk_model_ = sk_model() - # transfer params set during cuml estimator initialization - for param in self.get_param_names(): - self.sk_model_.__dict__[param] = self.__dict__[param] + self._cpu_model = self._cpu_model_class(**filtered_kwargs) # transfer attributes trained with cuml - for attr in self.get_attributes_names(): + for attr in self.get_attr_names(): # check presence of attribute - if hasattr(self, attr): + if hasattr(self, attr) or \ + isinstance(getattr(type(self), attr, None), property): # get the cuml attribute - cu_attr = self.__dict__[attr] + if hasattr(self, attr): + cu_attr = getattr(self, attr) + else: + cu_attr = getattr(type(self), attr).fget(self) # if the cuml attribute is a CumlArrayDescriptorMeta if hasattr(cu_attr, 'get_input_value'): # extract the actual value from the @@ -100,20 +98,27 @@ class UniversalBase(Base): if cu_attr_value is not None: if cu_attr.input_type == 'cuml': # transform cumlArray to numpy and set it - # as an attribute in the sklearn model - self.sk_model_.__dict__[attr] = \ - cu_attr_value.to_output('numpy') + # as an attribute in the CPU estimator + setattr(self._cpu_model, attr, + cu_attr_value.to_output('numpy')) else: # transfer all other types of attributes # directly - self.sk_model_.__dict__[attr] = \ - cu_attr_value + setattr(self._cpu_model, attr, + cu_attr_value) + elif isinstance(cu_attr, CumlArray): + # transform cumlArray to numpy and set it + # as an attribute in the CPU estimator + setattr(self._cpu_model, attr, + cu_attr.to_output('numpy')) + elif isinstance(cu_attr, cp.ndarray): + # transform cupy to numpy and set it + # as an attribute in the CPU estimator + setattr(self._cpu_model, attr, + cp.asnumpy(cu_attr)) else: # transfer all other types of attributes directly - self.sk_model_.__dict__[attr] = cu_attr - else: - raise ValueError('Attribute "{}" could not be found in' - ' the cuML 
estimator'.format(attr)) + setattr(self._cpu_model, attr, cu_attr) # converts all the args args = tuple(input_to_host_array(arg)[0] for arg in args) @@ -122,25 +127,75 @@ class UniversalBase(Base): kwargs[key] = input_to_host_array(kwarg)[0] # call the method from the sklearn model - sk_func = getattr(self.sk_model_, func_name) - res = sk_func(*args, **kwargs) - if func_name == 'fit': + cpu_func = getattr(self._cpu_model, func_name) + res = cpu_func(*args, **kwargs) + + if func_name in ['fit', 'fit_transform', 'fit_predict']: # need to do this to mirror input type self._set_output_type(args[0]) # always return the cuml estimator while training # mirror sk attributes to cuml after training - for attribute in self.get_attributes_names(): - sk_attr = self.sk_model_.__dict__[attribute] - # if the sklearn attribute is an array - if isinstance(sk_attr, np.ndarray): - # transfer array to gpu and set it as a cuml - # attribute - cuml_array = input_to_cuml_array(sk_attr)[0] - setattr(self, attribute, cuml_array) - else: - # transfer all other types of attributes directly - setattr(self, attribute, sk_attr) - return self - else: - # return method result - return res + for attr in self.get_attr_names(): + # check presence of attribute + if hasattr(self._cpu_model, attr) or \ + isinstance(getattr(type(self._cpu_model), + attr, None), property): + # get the cpu attribute + if hasattr(self._cpu_model, attr): + cpu_attr = getattr(self._cpu_model, attr) + else: + cpu_attr = getattr(type(self._cpu_model), + attr).fget(self._cpu_model) + # if the cpu attribute is an array + if isinstance(cpu_attr, np.ndarray): + # get data order wished for by CumlArrayDescriptor + if hasattr(self, attr + '_order'): + order = getattr(self, attr + '_order') + else: + order = 'K' + # transfer array to gpu and set it as a cuml + # attribute + cuml_array = input_to_cuml_array(cpu_attr, + order=order)[0] + setattr(self, attr, cuml_array) + else: + # transfer all other types of attributes directly + 
setattr(self, attr, cpu_attr) + if func_name == 'fit': + return self + # return method result + return res + + def fit(self, *args, **kwargs): + return self.dispatch_func('fit', *args, **kwargs) + + def predict(self, *args, **kwargs) -> CumlArray: + return self.dispatch_func('predict', *args, **kwargs) + + def transform(self, *args, **kwargs) -> CumlArray: + return self.dispatch_func('transform', *args, **kwargs) + + def kneighbors(self, X, *args, **kwargs) \ + -> typing.Union[CumlArray, typing.Tuple[CumlArray, CumlArray]]: + return self.dispatch_func('kneighbors', X, *args, **kwargs) + + def fit_transform(self, *args, **kwargs) -> CumlArray: + return self.dispatch_func('fit_transform', *args, **kwargs) + + def fit_predict(self, *args, **kwargs) -> CumlArray: + return self.dispatch_func('fit_predict', *args, **kwargs) + + def inverse_transform(self, *args, **kwargs) -> CumlArray: + return self.dispatch_func('inverse_transform', *args, **kwargs) + + def score(self, *args, **kwargs): + return self.dispatch_func('score', *args, **kwargs) + + def decision_function(self, *args, **kwargs) -> CumlArray: + return self.dispatch_func('decision_function', *args, **kwargs) + + def predict_proba(self, *args, **kwargs) -> CumlArray: + return self.dispatch_func('predict_proba', *args, **kwargs) + + def predict_log_proba(self, *args, **kwargs) -> CumlArray: + return self.dispatch_func('predict_log_proba', *args, **kwargs) diff --git a/python/cuml/explainer/common.py b/python/cuml/explainer/common.py index 049d84b751..df33a9b11b 100644 --- a/python/cuml/explainer/common.py +++ b/python/cuml/explainer/common.py @@ -19,7 +19,6 @@ from cuml.common.input_utils import input_to_cupy_array from cuml.common.base import Base -from cuml.experimental.common.base import Base as experimentalBase def get_tag_from_model_func(func, tag, default=None): @@ -67,7 +66,7 @@ def get_handle_from_cuml_model_func(func, create_new=False): """ owner = getattr(func, '__self__', None) - if owner is not None 
def device_interop_preparation(init_func):
    """
    Decorator for the ``__init__`` of cuML estimators that implement the
    CPU/GPU interoperability feature.

    The first decorated ``__init__`` that runs for an instance:

    * imports the matching CPU estimator class and stores it as
      ``self._cpu_model_class`` (taken from ``_cpu_estimator_import_path``
      when the instance has one, otherwise derived from the estimator's own
      module path and class name);
    * stores every hyperparameter supplied by the caller, whether passed
      positionally or by keyword, in ``self._full_kwargs``;
    * stores the parameter names of the CPU estimator's ``__init__`` in
      ``self._cpu_hyperparams``;
    * calls ``init_func`` with only the keyword arguments that appear in
      its signature, logging the ones that are dropped.

    Parameters
    ----------
    init_func : function
        The estimator ``__init__`` to wrap.

    Returns
    -------
    processor : function
        The wrapped ``__init__``.
    """

    @functools.wraps(init_func)
    def processor(self, *args, **kwargs):
        # `_cpu_model_class` is only set further down, so finding it here
        # means an outer decorated __init__ (a child class) already did the
        # bookkeeping for this instance: just forward the call untouched.
        if hasattr(self, '_cpu_model_class'):
            return init_func(self, *args, **kwargs)

        if hasattr(self, '_cpu_estimator_import_path'):
            # explicit "package.module.ClassName" path
            model_path, _, model_name = \
                self._cpu_estimator_import_path.rpartition('.')
        else:
            # swap the first four characters of the estimator's module path
            # for 'sklearn' and reuse the class name
            model_path = 'sklearn' + self.__class__.__module__[4:]
            model_name = self.__class__.__name__
        self._cpu_model_class = getattr(import_module(model_path), model_name)

        gpu_params = inspect.signature(init_func).parameters

        # Record hyperparameters given positionally under their parameter
        # names; otherwise they would be missing from `_full_kwargs`.
        # The first signature entry is `self`, hence the [1:].
        full_kwargs = {}
        for param, value in zip(list(gpu_params.values())[1:], args):
            if param.kind is inspect.Parameter.VAR_POSITIONAL:
                # remaining positionals are swallowed by *args: no names
                break
            if param.kind is inspect.Parameter.POSITIONAL_OR_KEYWORD:
                full_kwargs[param.name] = value
        full_kwargs.update(kwargs)

        # Save all hyperparameters supplied by the caller
        self._full_kwargs = full_kwargs
        # Save list of available CPU estimator hyperparameters
        self._cpu_hyperparams = list(
            inspect.signature(self._cpu_model_class.__init__).parameters.keys()
        )

        # Keep only the keyword arguments named in init_func's signature
        filtered_kwargs = {}
        for keyword, arg in kwargs.items():
            if keyword in gpu_params:
                filtered_kwargs[keyword] = arg
            else:
                logger.info("Unused keyword parameter: {} "
                            "during cuML estimator "
                            "initialization".format(keyword))

        return init_func(self, *args, **filtered_kwargs)
    return processor
""" + self.dtype = self.coef_.dtype + if self.coef_ is None: raise ValueError( "LinearModel.predict() cannot be called before fit(). " "Please fit the model first." ) - if len(self.coef_.shape) == 2 and self.coef_.shape[1] > 1: + n_targets = self.coef_.shape[0] + if len(self.coef_.shape) == 2 and n_targets > 1: # Handle multi-target prediction in Python. coef_cp = input_to_cupy_array(self.coef_).array X_cp = input_to_cupy_array( X, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype else None), - check_cols=self.n_cols + check_cols=self.n_features_in_ ).array intercept_cp = input_to_cupy_array(self.intercept_).array preds_cp = X_cp @ coef_cp + intercept_cp @@ -92,7 +95,7 @@ class LinearPredictMixin: input_to_cuml_array(X, check_dtype=self.dtype, convert_to_dtype=(self.dtype if convert_dtype else None), - check_cols=self.n_cols) + check_cols=self.n_features_in_) cdef uintptr_t X_ptr = X_m.ptr cdef uintptr_t coef_ptr = self.coef_.ptr diff --git a/python/cuml/linear_model/elastic_net.pyx b/python/cuml/linear_model/elastic_net.pyx index 79ed0bd6e6..455b7b92bf 100644 --- a/python/cuml/linear_model/elastic_net.pyx +++ b/python/cuml/linear_model/elastic_net.pyx @@ -19,14 +19,14 @@ from inspect import signature from cuml.solvers import CD, QN -from cuml.common.base import Base -from cuml.common.mixins import RegressorMixin +from cuml.experimental.common.base import Base +from cuml.common.mixins import RegressorMixin, FMajorInputTagMixin from cuml.common.doc_utils import generate_docstring from cuml.common.array import CumlArray from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.logger import warn -from cuml.common.mixins import FMajorInputTagMixin from cuml.linear_model.base import LinearPredictMixin +from cuml.internals.api_decorators import device_interop_preparation class ElasticNet(Base, @@ -145,8 +145,10 @@ class ElasticNet(Base, `_. 
""" - coef_ = CumlArrayDescriptor() + _cpu_estimator_import_path = 'sklearn.linear_model.ElasticNet' + coef_ = CumlArrayDescriptor(order='F') + @device_interop_preparation def __init__(self, *, alpha=1.0, l1_ratio=0.5, fit_intercept=True, normalize=False, max_iter=1000, tol=1e-3, solver='cd', selection='cyclic', @@ -234,21 +236,26 @@ class ElasticNet(Base, raise ValueError(msg.format(l1_ratio)) @generate_docstring() - def fit(self, X, y, convert_dtype=True, - sample_weight=None) -> "ElasticNet": + def _fit(self, X, y, convert_dtype=True, + sample_weight=None) -> "ElasticNet": """ Fit the model with X and y. """ + self.n_features_in_ = X.shape[1] if X.ndim == 2 else 1 + if hasattr(X, 'index'): + self.feature_names_in_ = X.index + self.solver_model.fit(X, y, convert_dtype=convert_dtype, sample_weight=sample_weight) if isinstance(self.solver_model, QN): + coefs = self.solver_model.coef_ self.coef_ = CumlArray( - data=self.solver_model.coef_, - index=self.solver_model.coef_._index, - dtype=self.solver_model.coef_.dtype, - order=self.solver_model.coef_.order, - shape=(self.solver_model.coef_.shape[0],) + data=coefs, + index=coefs._index, + dtype=coefs.dtype, + order=coefs.order, + shape=(coefs.shape[1],) ) self.intercept_ = self.solver_model.intercept_.item() @@ -273,3 +280,6 @@ class ElasticNet(Base, "solver", "selection", ] + + def get_attr_names(self): + return ['intercept_', 'coef_', 'n_features_in_', 'feature_names_in_'] diff --git a/python/cuml/linear_model/lasso.py b/python/cuml/linear_model/lasso.py index 143d3013f0..db16f4366a 100644 --- a/python/cuml/linear_model/lasso.py +++ b/python/cuml/linear_model/lasso.py @@ -15,6 +15,7 @@ # from cuml.linear_model.elastic_net import ElasticNet +from cuml.internals.api_decorators import device_interop_preparation class Lasso(ElasticNet): @@ -127,6 +128,9 @@ class Lasso(ElasticNet): `_. 
""" + _cpu_estimator_import_path = 'sklearn.linear_model.Lasso' + + @device_interop_preparation def __init__(self, *, alpha=1.0, fit_intercept=True, normalize=False, max_iter=1000, tol=1e-3, solver='cd', selection='cyclic', diff --git a/python/cuml/linear_model/linear_regression.pyx b/python/cuml/linear_model/linear_regression.pyx index cfac311d74..eb64ac6d43 100644 --- a/python/cuml/linear_model/linear_regression.pyx +++ b/python/cuml/linear_model/linear_regression.pyx @@ -30,17 +30,17 @@ from libc.stdlib cimport calloc, malloc, free from cuml.common.array import CumlArray from cuml.common.array_descriptor import CumlArrayDescriptor -from cuml.experimental.common.base import UniversalBase -from cuml.common.mixins import RegressorMixin +from cuml.experimental.common.base import Base +from cuml.common.mixins import RegressorMixin, FMajorInputTagMixin from cuml.common.doc_utils import generate_docstring from cuml.linear_model.base import LinearPredictMixin from pylibraft.common.handle cimport handle_t from pylibraft.common.handle import Handle from cuml.common import input_to_cuml_array +from cuml.internals.api_decorators import device_interop_preparation + from cuml.common.mixins import FMajorInputTagMixin from cuml.common.input_utils import input_to_cupy_array -from cuml.experimental.common import enable_cpu - cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM": @@ -118,7 +118,7 @@ def fit_multi_target(X, y, fit_intercept=True, sample_weight=None): return coef, intercept -class LinearRegression(UniversalBase, +class LinearRegression(Base, RegressorMixin, LinearPredictMixin, FMajorInputTagMixin): @@ -240,10 +240,11 @@ class LinearRegression(UniversalBase, `__. 
""" - sk_import_path_ = 'sklearn.linear_model' - coef_ = CumlArrayDescriptor() - intercept_ = CumlArrayDescriptor() + _cpu_estimator_import_path = 'sklearn.linear_model.LinearRegression' + coef_ = CumlArrayDescriptor(order='F') + intercept_ = CumlArrayDescriptor(order='F') + @device_interop_preparation def __init__(self, *, algorithm='eig', fit_intercept=True, normalize=False, handle=None, verbose=False, output_type=None): if handle is None and algorithm == 'eig': @@ -279,17 +280,17 @@ class LinearRegression(UniversalBase, }[algorithm] @generate_docstring() - @enable_cpu - def fit(self, X, y, convert_dtype=True, - sample_weight=None) -> "LinearRegression": + def _fit(self, X, y, convert_dtype=True, + sample_weight=None) -> "LinearRegression": """ Fit the model with X and y. """ cdef uintptr_t X_ptr, y_ptr, sample_weight_ptr - X_m, n_rows, self.n_cols, self.dtype = \ + X_m, n_rows, self.n_features_in_, self.dtype = \ input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) X_ptr = X_m.ptr + self.feature_names_in_ = X_m.index y_m, _, y_cols, _ = \ input_to_cuml_array(y, check_dtype=self.dtype, @@ -308,7 +309,7 @@ class LinearRegression(UniversalBase, else: sample_weight_ptr = 0 - if self.n_cols < 1: + if self.n_features_in_ < 1: msg = "X matrix must have at least a column" raise TypeError(msg) @@ -316,7 +317,7 @@ class LinearRegression(UniversalBase, msg = "X matrix must have at least two rows" raise TypeError(msg) - if self.n_cols == 1 and self.algo != 0: + if self.n_features_in_ == 1 and self.algo != 0: warnings.warn("Changing solver from 'eig' to 'svd' as eig " + "solver does not support training data with 1 " + "column currently.", UserWarning) @@ -330,7 +331,7 @@ class LinearRegression(UniversalBase, X_m, y_m, convert_dtype, sample_weight_m ) - self.coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype) + self.coef_ = CumlArray.zeros(self.n_features_in_, dtype=self.dtype) cdef uintptr_t coef_ptr = self.coef_.ptr cdef float c_intercept1 @@ -342,7 +343,7 @@ 
class LinearRegression(UniversalBase, olsFit(handle_[0], X_ptr, n_rows, - self.n_cols, + self.n_features_in_, y_ptr, coef_ptr, &c_intercept1, @@ -356,7 +357,7 @@ class LinearRegression(UniversalBase, olsFit(handle_[0], X_ptr, n_rows, - self.n_cols, + self.n_features_in_, y_ptr, coef_ptr, &c_intercept2, @@ -416,7 +417,7 @@ class LinearRegression(UniversalBase, self.coef_, _, _, _ = input_to_cuml_array( coef, check_dtype=self.dtype, - check_rows=self.n_cols, + check_rows=self.n_features_in_, check_cols=y_cols ) if self.fit_intercept: @@ -431,21 +432,12 @@ class LinearRegression(UniversalBase, return self - @enable_cpu - def predict(self, X, convert_dtype=True) -> CumlArray: - self.dtype = self.coef_.dtype - self.n_cols = self.coef_.shape[0] - # Adding Base here skips it in the Method Resolution Order (MRO) - # Since Base and LinearPredictMixin now both have a `predict` method - return super(UniversalBase, self).predict(X, - convert_dtype=convert_dtype) - def get_param_names(self): return super().get_param_names() + \ ['algorithm', 'fit_intercept', 'normalize'] - def get_attributes_names(self): - return ['coef_', 'intercept_'] + def get_attr_names(self): + return ['coef_', 'intercept_', 'n_features_in_', 'feature_names_in_'] @staticmethod def _more_static_tags(): diff --git a/python/cuml/linear_model/logistic_regression.pyx b/python/cuml/linear_model/logistic_regression.pyx index 80e2bc9de5..3b4c468906 100644 --- a/python/cuml/linear_model/logistic_regression.pyx +++ b/python/cuml/linear_model/logistic_regression.pyx @@ -22,15 +22,16 @@ import pprint import cuml.internals from cuml.solvers import QN -from cuml.common.base import Base -from cuml.common.mixins import ClassifierMixin +from cuml.experimental.common.base import Base +from cuml.common.mixins import ClassifierMixin, \ + FMajorInputTagMixin from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.array import CumlArray from cuml.common.doc_utils import generate_docstring import 
cuml.common.logger as logger from cuml.common import input_to_cuml_array from cuml.common import using_output_type -from cuml.common.mixins import FMajorInputTagMixin +from cuml.internals.api_decorators import device_interop_preparation supported_penalties = ["l1", "l2", "none", "elasticnet"] @@ -85,9 +86,8 @@ class LogisticRegression(Base, >>> reg.fit(X,y) LogisticRegression() >>> print(reg.coef_) - 0 0.698... - 1 0.570... - dtype: float32 + 0 1 + 0 0.69861 0.570058 >>> print(reg.intercept_) 0 -2.188... dtype: float32 @@ -123,8 +123,6 @@ class LogisticRegression(Base, fit_intercept : boolean (default = True) If True, the model tries to correct for the global mean of y. If False, the model expects that you have centered the data. - class_weight : None - Custom class weighs are currently not supported. class_weight : dict or 'balanced', default=None By default all classes have a weight one. However, a dictionary can be provided with weights associated with classes @@ -182,10 +180,12 @@ class LogisticRegression(Base, `_. 
""" - classes_ = CumlArrayDescriptor() - class_weight_ = CumlArrayDescriptor() - expl_spec_weights_ = CumlArrayDescriptor() + _cpu_estimator_import_path = 'sklearn.linear_model.LogisticRegression' + classes_ = CumlArrayDescriptor(order='F') + class_weight = CumlArrayDescriptor(order='F') + expl_spec_weights_ = CumlArrayDescriptor(order='F') + @device_interop_preparation def __init__( self, *, @@ -239,7 +239,7 @@ class LogisticRegression(Base, if class_weight is not None: self._build_class_weights(class_weight) else: - self.class_weight_ = None + self.class_weight = None self.solver_model = QN( loss=loss, @@ -262,12 +262,16 @@ class LogisticRegression(Base, @generate_docstring(X='dense_sparse') @cuml.internals.api_base_return_any(set_output_dtype=True) - def fit(self, X, y, sample_weight=None, - convert_dtype=True) -> "LogisticRegression": + def _fit(self, X, y, sample_weight=None, + convert_dtype=True) -> "LogisticRegression": """ Fit the model with X and y. """ + self.n_features_in_ = X.shape[1] if X.ndim == 2 else 1 + if hasattr(X, 'index'): + self.feature_names_in_ = X.index + # Converting y to device array here to use `unique` function # since calling input_to_cuml_array again in QN has no cost # Not needed to check dtype since qn class checks it already @@ -280,7 +284,7 @@ class LogisticRegression(Base, raise ValueError("Only values of 0 and 1 are" " supported for binary classification.") - if sample_weight is not None or self.class_weight_ is not None: + if sample_weight is not None or self.class_weight is not None: if sample_weight is None: sample_weight = cp.ones(n_rows) @@ -299,22 +303,22 @@ class LogisticRegression(Base, msg = "Class label {} not present.".format(c) raise ValueError(msg) - if self.class_weight_ is not None: - if self.class_weight_ == 'balanced': + if self.class_weight is not None: + if self.class_weight == 'balanced': class_weight = n_rows / \ (self._num_classes * cp.bincount(y_m.to_output('cupy'))) class_weight = CumlArray(class_weight) 
else: check_expl_spec_weights() - n_explicit = self.class_weight_.shape[0] + n_explicit = self.class_weight.shape[0] if n_explicit != self._num_classes: class_weight = cp.ones(self._num_classes) - class_weight[:n_explicit] = self.class_weight_ + class_weight[:n_explicit] = self.class_weight class_weight = CumlArray(class_weight) - self.class_weight_ = class_weight + self.class_weight = class_weight else: - class_weight = self.class_weight_ + class_weight = self.class_weight out = y_m.to_output('cupy') sample_weight *= class_weight[out].to_output('cupy') sample_weight = CumlArray(sample_weight) @@ -359,7 +363,7 @@ class LogisticRegression(Base, 'type': 'dense', 'description': 'Confidence score', 'shape': '(n_samples, n_classes)'}) - def decision_function(self, X, convert_dtype=False) -> CumlArray: + def _decision_function(self, X, convert_dtype=True) -> CumlArray: """ Gives confidence score for X @@ -375,7 +379,7 @@ class LogisticRegression(Base, 'description': 'Predicted values', 'shape': '(n_samples, 1)'}) @cuml.internals.api_base_return_array(get_output_dtype=True) - def predict(self, X, convert_dtype=True) -> CumlArray: + def _predict(self, X, convert_dtype=True) -> CumlArray: """ Predicts the y for X. 
@@ -388,7 +392,7 @@ class LogisticRegression(Base, 'description': 'Predicted class \ probabilities', 'shape': '(n_samples, n_classes)'}) - def predict_proba(self, X, convert_dtype=True) -> CumlArray: + def _predict_proba(self, X, convert_dtype=True) -> CumlArray: """ Predicts the class probabilities for each class in X """ @@ -404,7 +408,7 @@ class LogisticRegression(Base, 'description': 'Logaright of predicted \ class probabilities', 'shape': '(n_samples, n_classes)'}) - def predict_log_proba(self, X, convert_dtype=True) -> CumlArray: + def _predict_log_proba(self, X, convert_dtype=True) -> CumlArray: """ Predicts the log class probabilities for each class in X @@ -419,14 +423,16 @@ class LogisticRegression(Base, X, convert_dtype=False, log_proba=False) -> CumlArray: + _num_classes = self.classes_.shape[0] + scores = cp.asarray( self.decision_function(X, convert_dtype=convert_dtype), order="F" ).T - if self._num_classes == 2: + if _num_classes == 2: proba = cp.zeros((scores.shape[0], 2)) proba[:, 1] = 1 / (1 + cp.exp(-scores.ravel())) proba[:, 0] = 1 - proba[:, 1] - elif self._num_classes > 2: + elif _num_classes > 2: max_scores = cp.max(scores, axis=1).reshape((-1, 1)) scores -= max_scores proba = cp.exp(scores) @@ -459,14 +465,14 @@ class LogisticRegression(Base, def _build_class_weights(self, class_weight): if class_weight == 'balanced': - self.class_weight_ = 'balanced' + self.class_weight = 'balanced' else: classes = list(class_weight.keys()) weights = list(class_weight.values()) max_class = sorted(classes)[-1] class_weight = cp.ones(max_class + 1) class_weight[classes] = weights - self.class_weight_, _, _, _ = input_to_cuml_array(class_weight) + self.class_weight, _, _, _ = input_to_cuml_array(class_weight) self.expl_spec_weights_, _, _, _ = \ input_to_cuml_array(np.array(classes)) @@ -492,6 +498,24 @@ class LogisticRegression(Base, self.solver_model.set_params(**params) return self + @property + @cuml.internals.api_base_return_array_skipall + def 
coef_(self): + return self.solver_model.coef_ + + @coef_.setter + def coef_(self, value): + self.solver_model.coef_ = value + + @property + @cuml.internals.api_base_return_array_skipall + def intercept_(self): + return self.solver_model.intercept_ + + @intercept_.setter + def intercept_(self, value): + self.solver_model.intercept_ = value + def get_param_names(self): return super().get_param_names() + [ "penalty", @@ -513,3 +537,7 @@ class LogisticRegression(Base, super().__init__(handle=None, verbose=state["verbose"]) self.__dict__.update(state) + + def get_attr_names(self): + return ['classes_', 'intercept_', 'coef_', 'n_features_in_', + 'feature_names_in_'] diff --git a/python/cuml/linear_model/ridge.pyx b/python/cuml/linear_model/ridge.pyx index 8bb37620b3..d85bab533a 100644 --- a/python/cuml/linear_model/ridge.pyx +++ b/python/cuml/linear_model/ridge.pyx @@ -27,14 +27,14 @@ from libc.stdint cimport uintptr_t from libc.stdlib cimport calloc, malloc, free from cuml.common.array_descriptor import CumlArrayDescriptor -from cuml.common.base import Base -from cuml.common.mixins import RegressorMixin +from cuml.experimental.common.base import Base +from cuml.common.mixins import RegressorMixin, FMajorInputTagMixin from cuml.common.array import CumlArray from cuml.common.doc_utils import generate_docstring from cuml.linear_model.base import LinearPredictMixin from pylibraft.common.handle cimport handle_t from cuml.common import input_to_cuml_array -from cuml.common.mixins import FMajorInputTagMixin +from cuml.internals.api_decorators import device_interop_preparation cdef extern from "cuml/linear_model/glm.hpp" namespace "ML::GLM": @@ -185,9 +185,11 @@ class Ridge(Base, `_. 
""" - coef_ = CumlArrayDescriptor() - intercept_ = CumlArrayDescriptor() + _cpu_estimator_import_path = 'sklearn.linear_model.Ridge' + coef_ = CumlArrayDescriptor(order='F') + intercept_ = CumlArrayDescriptor(order='F') + @device_interop_preparation def __init__(self, *, alpha=1.0, solver='eig', fit_intercept=True, normalize=False, handle=None, output_type=None, verbose=False): @@ -238,15 +240,17 @@ class Ridge(Base, }[algorithm] @generate_docstring() - def fit(self, X, y, convert_dtype=True, sample_weight=None) -> "Ridge": + def _fit(self, X, y, convert_dtype=True, sample_weight=None) -> "Ridge": """ Fit the model with X and y. """ cdef uintptr_t X_ptr, y_ptr, sample_weight_ptr - X_m, n_rows, self.n_cols, self.dtype = \ - input_to_cuml_array(X, check_dtype=[np.float32, np.float64]) + X_m, n_rows, self.n_features_in_, self.dtype = \ + input_to_cuml_array(X, deepcopy=True, + check_dtype=[np.float32, np.float64]) X_ptr = X_m.ptr + self.feature_names_in_ = X_m.index y_m, _, _, _ = \ input_to_cuml_array(y, check_dtype=self.dtype, @@ -265,7 +269,7 @@ class Ridge(Base, else: sample_weight_ptr = 0 - if self.n_cols < 1: + if self.n_features_in_ < 1: msg = "X matrix must have at least a column" raise TypeError(msg) @@ -273,7 +277,7 @@ class Ridge(Base, msg = "X matrix must have at least two rows" raise TypeError(msg) - if self.n_cols == 1 and self.algo != 0: + if self.n_features_in_ == 1 and self.algo != 0: warnings.warn("Changing solver to 'svd' as 'eig' or 'cd' " + "solvers do not support training data with 1 " + "column currently.", UserWarning) @@ -281,7 +285,7 @@ class Ridge(Base, self.n_alpha = 1 - self.coef_ = CumlArray.zeros(self.n_cols, dtype=self.dtype) + self.coef_ = CumlArray.zeros(self.n_features_in_, dtype=self.dtype) cdef uintptr_t coef_ptr = self.coef_.ptr cdef float c_intercept1 @@ -295,7 +299,7 @@ class Ridge(Base, ridgeFit(handle_[0], X_ptr, n_rows, - self.n_cols, + self.n_features_in_, y_ptr, &c_alpha1, self.n_alpha, @@ -309,11 +313,10 @@ class 
Ridge(Base, self.intercept_ = c_intercept1 else: c_alpha2 = self.alpha - ridgeFit(handle_[0], X_ptr, n_rows, - self.n_cols, + self.n_features_in_, y_ptr, &c_alpha2, self.n_alpha, @@ -348,3 +351,6 @@ class Ridge(Base, def get_param_names(self): return super().get_param_names() + \ ['solver', 'fit_intercept', 'normalize', 'alpha'] + + def get_attr_names(self): + return ['intercept_', 'coef_', 'n_features_in_', 'feature_names_in_'] diff --git a/python/cuml/manifold/umap.pyx b/python/cuml/manifold/umap.pyx index 5f6a9af5c9..23e1a88407 100644 --- a/python/cuml/manifold/umap.pyx +++ b/python/cuml/manifold/umap.pyx @@ -38,7 +38,8 @@ from cupyx.scipy.sparse import csr_matrix as cp_csr_matrix,\ import cuml.internals from cuml.common import using_output_type -from cuml.common.base import Base +from cuml.experimental.common.base import Base +from cuml.common.mixins import CMajorInputTagMixin from pylibraft.common.handle cimport handle_t from cuml.common.doc_utils import generate_docstring from cuml.common import logger @@ -47,7 +48,6 @@ from cuml.common.memory_utils import using_output_type from cuml.common.import_utils import has_scipy from cuml.common.array import CumlArray from cuml.common.array_sparse import SparseCumlArray -from cuml.common.mixins import CMajorInputTagMixin from cuml.common.sparse_utils import is_sparse from cuml.metrics.distance_type cimport DistanceType @@ -58,6 +58,7 @@ if has_scipy(True): import scipy.sparse from cuml.common.array_descriptor import CumlArrayDescriptor +from cuml.internals.api_decorators import device_interop_preparation import rmm @@ -299,9 +300,10 @@ class UMAP(Base, `_ """ - X_m = CumlArrayDescriptor() - embedding_ = CumlArrayDescriptor() + _cpu_estimator_import_path = 'umap.UMAP' + embedding_ = CumlArrayDescriptor(order='C') + @device_interop_preparation def __init__(self, *, n_neighbors=15, n_components=2, @@ -339,7 +341,7 @@ class UMAP(Base, self.n_components = n_components self.metric = metric self.metric_kwds = metric_kwds - 
self.n_epochs = n_epochs if n_epochs else 0 + self.n_epochs = n_epochs if init == "spectral" or init == "random": self.init = init @@ -379,8 +381,8 @@ class UMAP(Base, rs = np.random.RandomState(random_state) self.random_state = rs.randint(low=0, - high=np.iinfo(np.uint64).max, - dtype=np.uint64) + high=np.iinfo(np.uint32).max, + dtype=np.uint32) if target_metric == "euclidean" or target_metric == "categorical": self.target_metric = target_metric @@ -388,12 +390,13 @@ class UMAP(Base, raise Exception("Invalid target metric: {}" % target_metric) self.callback = callback # prevent callback destruction - self.X_m = None self.embedding_ = None self.validate_hyperparams() self.sparse_fit = False + self._input_hash = None + self._small_data = False def validate_hyperparams(self): @@ -405,7 +408,7 @@ class UMAP(Base, cdef UMAPParams* umap_params = new UMAPParams() umap_params.n_neighbors = cls.n_neighbors umap_params.n_components = cls.n_components - umap_params.n_epochs = cls.n_epochs + umap_params.n_epochs = cls.n_epochs if cls.n_epochs else 0 umap_params.learning_rate = cls.learning_rate umap_params.min_dist = cls.min_dist umap_params.spread = cls.spread @@ -480,8 +483,8 @@ class UMAP(Base, @generate_docstring(convert_dtype_cast='np.float32', X='dense_sparse', skip_parameters_heading=True) - def fit(self, X, y=None, convert_dtype=True, - knn_graph=None) -> "UMAP": + def _fit(self, X, y=None, convert_dtype=True, + knn_graph=None) -> "UMAP": """ Fit X into an embedded space. 
@@ -517,14 +520,14 @@ class UMAP(Base, # Handle sparse inputs if is_sparse(X): - self.X_m = SparseCumlArray(X, convert_to_dtype=cupy.float32, - convert_format=False) - self.n_rows, self.n_dims = self.X_m.shape + self._raw_data = SparseCumlArray(X, convert_to_dtype=cupy.float32, + convert_format=False) + self.n_rows, self.n_dims = self._raw_data.shape self.sparse_fit = True # Handle dense inputs else: - self.X_m, self.n_rows, self.n_dims, dtype = \ + self._raw_data, self.n_rows, self.n_dims, dtype = \ input_to_cuml_array(X, order='C', check_dtype=np.float32, convert_to_dtype=(np.float32 if convert_dtype @@ -545,11 +548,10 @@ class UMAP(Base, self.embedding_ = CumlArray.zeros((self.n_rows, self.n_components), order="C", dtype=np.float32, - index=self.X_m.index) + index=self._raw_data.index) if self.hash_input: - with using_output_type("numpy"): - self.input_hash = joblib.hash(self.X_m) + self._input_hash = joblib.hash(self._raw_data.to_output('numpy')) cdef handle_t * handle_ = \ self.handle.getHandle() @@ -572,10 +574,10 @@ class UMAP(Base, fss_graph = GraphHolder.new_graph(handle_.get_stream()) if self.sparse_fit: fit_sparse(handle_[0], - self.X_m.indptr.ptr, - self.X_m.indices.ptr, - self.X_m.data.ptr, - self.X_m.nnz, + self._raw_data.indptr.ptr, + self._raw_data.indices.ptr, + self._raw_data.data.ptr, + self._raw_data.nnz, y_raw, self.n_rows, self.n_dims, @@ -585,7 +587,7 @@ class UMAP(Base, else: fit(handle_[0], - self.X_m.ptr, + self._raw_data.ptr, y_raw, self.n_rows, self.n_dims, @@ -612,8 +614,8 @@ class UMAP(Base, low-dimensional space.', 'shape': '(n_samples, n_components)'}) @cuml.internals.api_base_fit_transform() - def fit_transform(self, X, y=None, convert_dtype=True, - knn_graph=None) -> CumlArray: + def _fit_transform(self, X, y=None, convert_dtype=True, + knn_graph=None) -> CumlArray: """ Fit X into an embedded space and return that transformed output. 
@@ -658,7 +660,7 @@ class UMAP(Base, data in \ low-dimensional space.', 'shape': '(n_samples, n_components)'}) - def transform(self, X, convert_dtype=True, knn_graph=None) -> CumlArray: + def _transform(self, X, convert_dtype=True, knn_graph=None) -> CumlArray: """ Transform X into the existing embedded space and return that transformed output. @@ -721,15 +723,14 @@ class UMAP(Base, n_rows = X_m.shape[0] n_cols = X_m.shape[1] - if n_cols != self.n_dims: + if n_cols != self._raw_data.shape[1]: raise ValueError("n_features of X must match n_features of " "training data") - if self.hash_input and joblib.hash(X_m.to_output('numpy')) == \ - self.input_hash: - - del X_m - return self.embedding_ + if self.hash_input: + if joblib.hash(X_m.to_output('numpy')) == self._input_hash: + del X_m + return self.embedding_ embedding = CumlArray.zeros((X_m.shape[0], self.n_components), @@ -759,13 +760,13 @@ class UMAP(Base, X_m.nnz, X_m.shape[0], X_m.shape[1], - self.X_m.indptr.ptr, - self.X_m.indices.ptr, - self.X_m.data.ptr, - self.X_m.nnz, - self.X_m.shape[0], + self._raw_data.indptr.ptr, + self._raw_data.indices.ptr, + self._raw_data.data.ptr, + self._raw_data.nnz, + self._raw_data.shape[0], embed_ptr, - self.n_rows, + self._raw_data.shape[0], umap_params, xformed_ptr) else: @@ -775,10 +776,10 @@ class UMAP(Base, X_m.shape[1], knn_indices_raw, knn_dists_raw, - self.X_m.ptr, - self.n_rows, + self._raw_data.ptr, + self._raw_data.shape[0], embed_ptr, - self.n_rows, + self._raw_data.shape[0], umap_params, xformed_ptr) self.handle.sync() @@ -813,3 +814,6 @@ class UMAP(Base, "metric", "metric_kwds" ] + + def get_attr_names(self): + return ['_raw_data', 'embedding_', '_input_hash', '_small_data'] diff --git a/python/cuml/neighbors/kneighbors_classifier.pyx b/python/cuml/neighbors/kneighbors_classifier.pyx index c73d7551da..44da30ea15 100644 --- a/python/cuml/neighbors/kneighbors_classifier.pyx +++ b/python/cuml/neighbors/kneighbors_classifier.pyx @@ -75,9 +75,9 @@ cdef extern from 
"cuml/neighbors/knn.hpp" namespace "ML": ) except + -class KNeighborsClassifier(NearestNeighbors, - ClassifierMixin, - FMajorInputTagMixin): +class KNeighborsClassifier(ClassifierMixin, + FMajorInputTagMixin, + NearestNeighbors): """ K-Nearest Neighbors Classifier is an instance-based learning technique, that keeps training samples around for prediction, rather than trying @@ -222,7 +222,7 @@ class KNeighborsClassifier(NearestNeighbors, classes_ptr, inds_ctype, deref(y_vec), - self.n_rows, + self.n_samples_fit_, n_rows, self.n_neighbors ) @@ -286,7 +286,7 @@ class KNeighborsClassifier(NearestNeighbors, deref(out_vec), inds_ctype, deref(y_vec), - self.n_rows, + self.n_samples_fit_, n_rows, self.n_neighbors ) diff --git a/python/cuml/neighbors/kneighbors_regressor.pyx b/python/cuml/neighbors/kneighbors_regressor.pyx index 4b372b98df..862eef2372 100644 --- a/python/cuml/neighbors/kneighbors_regressor.pyx +++ b/python/cuml/neighbors/kneighbors_regressor.pyx @@ -63,9 +63,9 @@ cdef extern from "cuml/neighbors/knn.hpp" namespace "ML": ) except + -class KNeighborsRegressor(NearestNeighbors, - RegressorMixin, - FMajorInputTagMixin): +class KNeighborsRegressor(RegressorMixin, + FMajorInputTagMixin, + NearestNeighbors): """ K-Nearest Neighbors Regressor is an instance-based learning technique, @@ -230,7 +230,7 @@ class KNeighborsRegressor(NearestNeighbors, results_ptr, inds_ctype, deref(y_vec), - self.n_rows, + self.n_samples_fit_, n_rows, self.n_neighbors ) diff --git a/python/cuml/neighbors/nearest_neighbors.pyx b/python/cuml/neighbors/nearest_neighbors.pyx index e3f25215b2..3c6d3ffeaa 100644 --- a/python/cuml/neighbors/nearest_neighbors.pyx +++ b/python/cuml/neighbors/nearest_neighbors.pyx @@ -26,19 +26,20 @@ import warnings import math import cuml.internals -from cuml.common.base import Base +from cuml.experimental.common.base import Base +from cuml.common.mixins import CMajorInputTagMixin from cuml.common.array_descriptor import CumlArrayDescriptor from cuml.common.array 
import CumlArray from cuml.common.array_sparse import SparseCumlArray from cuml.common.doc_utils import generate_docstring from cuml.common.doc_utils import insert_into_docstring from cuml.common.import_utils import has_scipy -from cuml.common.mixins import CMajorInputTagMixin from cuml.common.input_utils import input_to_cupy_array from cuml.common import input_to_cuml_array from cuml.common.sparse_utils import is_sparse from cuml.common.sparse_utils import is_dense from cuml.metrics.distance_type cimport DistanceType +from cuml.internals.api_decorators import device_interop_preparation from cuml.neighbors.ann cimport * from pylibraft.common.handle cimport handle_t @@ -305,8 +306,10 @@ class NearestNeighbors(Base, """ - X_m = CumlArrayDescriptor() + _cpu_estimator_import_path = 'sklearn.neighbors.NearestNeighbors' + _fit_X = CumlArrayDescriptor(order='C') + @device_interop_preparation def __init__(self, *, n_neighbors=5, verbose=False, @@ -325,18 +328,18 @@ class NearestNeighbors(Base, self.n_neighbors = n_neighbors self.n_indices = 0 - self.metric = metric - self.metric_params = metric_params + self.effective_metric_ = metric + self.effective_metric_params_ = metric_params if metric_params else {} self.algo_params = algo_params self.p = p self.algorithm = algorithm - self.working_algorithm_ = self.algorithm + self._fit_method = self.algorithm self.selected_algorithm_ = algorithm self.algo_params = algo_params self.knn_index = None @generate_docstring(X='dense_sparse') - def fit(self, X, convert_dtype=True) -> "NearestNeighbors": + def _fit(self, X, convert_dtype=True) -> "NearestNeighbors": """ Fit GPU index for performing nearest neighbor queries. 
@@ -344,50 +347,51 @@ class NearestNeighbors(Base, if len(X.shape) != 2: raise ValueError("data should be two dimensional") - self.n_dims = X.shape[1] + self.n_samples_fit_, self.n_features_in_ = X.shape if self.algorithm == "auto": - if (self.n_dims == 2 or self.n_dims == 3) and \ - not is_sparse(X) and \ - self.metric in cuml.neighbors.VALID_METRICS["rbc"] and \ + if (self.n_features_in_ == 2 or self.n_features_in_ == 3) and \ + not is_sparse(X) and self.effective_metric_ in \ + cuml.neighbors.VALID_METRICS["rbc"] and \ math.sqrt(X.shape[0]) >= self.n_neighbors: - self.working_algorithm_ = "rbc" + self._fit_method = "rbc" else: - self.working_algorithm_ = "brute" + self._fit_method = "brute" - if self.algorithm == "rbc" and self.n_dims > 3: + if self.algorithm == "rbc" and self.n_features_in_ > 3: raise ValueError("The rbc algorithm is not supported for" " >3 dimensions currently.") if is_sparse(X): valid_metrics = cuml.neighbors.VALID_METRICS_SPARSE valid_metric_str = "_SPARSE" - self.X_m = SparseCumlArray(X, convert_to_dtype=cp.float32, - convert_format=False) - self.n_rows = self.X_m.shape[0] + self._fit_X = SparseCumlArray(X, convert_to_dtype=cp.float32, + convert_format=False) else: valid_metrics = cuml.neighbors.VALID_METRICS valid_metric_str = "" - self.X_m, self.n_rows, n_cols, dtype = \ + self._fit_X, _, _, dtype = \ input_to_cuml_array(X, order='C', check_dtype=np.float32, convert_to_dtype=(np.float32 if convert_dtype else None)) - self._output_index = self.X_m.index - if self.metric not in \ - valid_metrics[self.working_algorithm_]: + self._output_index = self._fit_X.index + self.feature_names_in_ = self._fit_X.index + + if self.effective_metric_ not in \ + valid_metrics[self._fit_method]: raise ValueError("Metric %s is not valid. " "Use sorted(cuml.neighbors.VALID_METRICS%s[%s]) " "to get valid options." 
% (valid_metric_str, - self.metric, - self.working_algorithm_)) + self.effective_metric_, + self._fit_method)) cdef handle_t* handle_ = self.handle.getHandle() cdef knnIndexParam* algo_params = 0 - if self.working_algorithm_ in ['ivfflat', 'ivfpq', 'ivfsq']: + if self._fit_method in ['ivfflat', 'ivfpq', 'ivfsq']: warnings.warn("\nWarning: Approximate Nearest Neighbor methods " "might be unstable in this version of cuML. " "This is due to a known issue in the FAISS " @@ -398,34 +402,34 @@ class NearestNeighbors(Base, raise ValueError("Approximate Nearest Neigbors methods " "require dense data") - additional_info = {'n_samples': self.n_rows, - 'n_features': n_cols} + additional_info = {'n_samples': self.n_samples_fit_, + 'n_features': self.n_features_in_} knn_index = new knnIndex() self.knn_index = knn_index algo_params = \ - build_algo_params(self.working_algorithm_, self.algo_params, + build_algo_params(self._fit_method, self.algo_params, additional_info) - metric = self._build_metric_type(self.metric) + metric = self._build_metric_type(self.effective_metric_) approx_knn_build_index(handle_[0], knn_index, algo_params, metric, self.p, - self.X_m.ptr, - self.n_rows, - n_cols) + self._fit_X.ptr, + self.n_samples_fit_, + self.n_features_in_) self.handle.sync() destroy_algo_params(algo_params) - del self.X_m - elif self.working_algorithm_ == "rbc": - metric = self._build_metric_type(self.metric) + del self._fit_X + elif self._fit_method == "rbc": + metric = self._build_metric_type(self.effective_metric_) rbc_index = new BallCoverIndex[int64_t, float, uint32_t]( - handle_[0], self.X_m.ptr, - self.n_rows, n_cols, + handle_[0], self._fit_X.ptr, + self.n_samples_fit_, self.n_features_in_, metric) rbc_build_index(handle_[0], deref(rbc_index)) @@ -439,6 +443,11 @@ class NearestNeighbors(Base, ["n_neighbors", "algorithm", "metric", "p", "metric_params", "algo_params", "n_jobs"] + def get_attr_names(self): + return ['_fit_X', 'effective_metric_', 'effective_metric_params_', + 
'n_samples_fit_', 'n_features_in_', 'feature_names_in_', + '_fit_method'] + @staticmethod def _build_metric_type(metric): if metric == "euclidean" or metric == "l2": @@ -478,7 +487,7 @@ class NearestNeighbors(Base, return_values=[('dense', '(n_samples, n_features)'), ('dense', '(n_samples, n_features)')]) - def kneighbors( + def _kneighbors( self, X=None, n_neighbors=None, @@ -533,12 +542,13 @@ class NearestNeighbors(Base, The indices of the k-nearest neighbors for each column vector in X """ - return self._kneighbors(X, n_neighbors, return_distance, convert_dtype, - two_pass_precision=two_pass_precision) + return self._kneighbors_internal(X, n_neighbors, return_distance, + convert_dtype, + two_pass_precision=two_pass_precision) - def _kneighbors(self, X=None, n_neighbors=None, return_distance=True, - convert_dtype=True, _output_type=None, - two_pass_precision=False): + def _kneighbors_internal(self, X=None, n_neighbors=None, + return_distance=True, convert_dtype=True, + _output_type=None, two_pass_precision=False): """ Query the GPU index for the k nearest neighbors of column vectors in X. 
@@ -596,14 +606,14 @@ class NearestNeighbors(Base, use_training_data = X is None if X is None: - X = self.X_m + X = self._fit_X n_neighbors += 1 if (n_neighbors is None and self.n_neighbors is None) \ or n_neighbors <= 0: raise ValueError("k or n_neighbors must be a positive integers") - if n_neighbors > self.n_rows: + if n_neighbors > self.n_samples_fit_: raise ValueError("n_neighbors must be <= number of " "samples in index") @@ -611,11 +621,12 @@ class NearestNeighbors(Base, raise ValueError("Model needs to be trained " "before calling kneighbors()") - if X.shape[1] != self.n_dims: + if X.shape[1] != self.n_features_in_: raise ValueError("Dimensions of X need to match dimensions of " - "indices (%d)" % self.n_dims) + "indices (%d)" % self.n_features_in_) - if hasattr(self, 'X_m') and isinstance(self.X_m, SparseCumlArray): + if hasattr(self, '_fit_X') and isinstance(self._fit_X, + SparseCumlArray): D_ndarr, I_ndarr = self._kneighbors_sparse(X, n_neighbors) else: D_ndarr, I_ndarr = self._kneighbors_dense(X, n_neighbors, @@ -627,7 +638,7 @@ class NearestNeighbors(Base, if _output_type is not None else self._get_output_type(X) if two_pass_precision: - metric = self._build_metric_type(self.metric) + metric = self._build_metric_type(self.effective_metric_) metric_is_l2_based = ( metric == DistanceType.L2SqrtExpanded or metric == DistanceType.L2Expanded or @@ -680,7 +691,7 @@ class NearestNeighbors(Base, raise ValueError("A NearestNeighbors model trained on dense " "data requires dense input to kneighbors()") - metric = self._build_metric_type(self.metric) + metric = self._build_metric_type(self.effective_metric_) X_m, N, _, dtype = \ input_to_cuml_array(X, order='C', check_dtype=np.float32, @@ -705,24 +716,24 @@ class NearestNeighbors(Base, cdef BallCoverIndex[int64_t, float, uint32_t]* rbc_index = \ 0 - fallback_to_brute = self.working_algorithm_ == "rbc" and \ - n_neighbors > math.sqrt(self.X_m.shape[0]) + fallback_to_brute = self._fit_method == "rbc" and \ + 
n_neighbors > math.sqrt(self.n_samples_fit_) if fallback_to_brute: warnings.warn("algorithm='rbc' requires sqrt(%s) be " "> n_neighbors (%s). falling back to " "brute force search" % - (self.X_m.shape[0], n_neighbors)) + (self.n_samples_fit_, n_neighbors)) - if self.working_algorithm_ == 'brute' or fallback_to_brute: - inputs.push_back(self.X_m.ptr) - sizes.push_back(self.X_m.shape[0]) + if self._fit_method == 'brute' or fallback_to_brute: + inputs.push_back(self._fit_X.ptr) + sizes.push_back(self.n_samples_fit_) brute_force_knn( handle_[0], deref(inputs), deref(sizes), - self.n_dims, + self.n_features_in_, X_m.ptr, N, I_ptr, @@ -734,7 +745,7 @@ class NearestNeighbors(Base, # minkowski order is currently the only metric argument. self.p ) - elif self.working_algorithm_ == "rbc": + elif self._fit_method == "rbc": rbc_index = \ self.knn_index rbc_knn_query(handle_[0], @@ -760,7 +771,7 @@ class NearestNeighbors(Base, def _kneighbors_sparse(self, X, n_neighbors): - if isinstance(self.X_m, SparseCumlArray) and not is_sparse(X): + if isinstance(self._fit_X, SparseCumlArray) and not is_sparse(X): raise ValueError("A NearestNeighbors model trained on sparse " "data requires sparse input to kneighbors()") @@ -776,11 +787,11 @@ class NearestNeighbors(Base, X_m = SparseCumlArray(X, convert_to_dtype=cp.float32, convert_format=False) - metric = self._build_metric_type(self.metric) + metric = self._build_metric_type(self.effective_metric_) - cdef uintptr_t idx_indptr = self.X_m.indptr.ptr - cdef uintptr_t idx_indices = self.X_m.indices.ptr - cdef uintptr_t idx_data = self.X_m.data.ptr + cdef uintptr_t idx_indptr = self._fit_X.indptr.ptr + cdef uintptr_t idx_indices = self._fit_X.indices.ptr + cdef uintptr_t idx_data = self._fit_X.data.ptr cdef uintptr_t search_indptr = X_m.indptr.ptr cdef uintptr_t search_indices = X_m.indices.ptr @@ -802,9 +813,9 @@ class NearestNeighbors(Base, idx_indptr, idx_indices, idx_data, - self.X_m.nnz, - self.X_m.shape[0], - self.X_m.shape[1], + 
self._fit_X.nnz, + self.n_samples_fit_, + self.n_features_in_, search_indptr, search_indices, search_data, @@ -853,7 +864,7 @@ class NearestNeighbors(Base, numpy's CSR sparse graph (host) """ - if not self.X_m: + if not self._fit_X: raise ValueError('This NearestNeighbors instance has not been ' 'fitted yet, call "fit" before using this ' 'estimator') @@ -862,16 +873,16 @@ class NearestNeighbors(Base, n_neighbors = self.n_neighbors if mode == 'connectivity': - indices = self._kneighbors(X, n_neighbors, - return_distance=False, - _output_type="cupy") + indices = self._kneighbors_internal(X, n_neighbors, + return_distance=False, + _output_type="cupy") n_samples = indices.shape[0] distances = cp.ones(n_samples * n_neighbors, dtype=np.float32) elif mode == 'distance': - distances, indices = self._kneighbors(X, n_neighbors, - _output_type="cupy") + distances, indices = self._kneighbors_internal(X, n_neighbors, + _output_type="cupy") distances = cp.ravel(distances) else: @@ -881,7 +892,6 @@ class NearestNeighbors(Base, n_samples = indices.shape[0] indices = cp.ravel(indices) - n_samples_fit = self.X_m.shape[0] n_nonzero = n_samples * n_neighbors rowptr = cp.arange(0, n_nonzero + 1, n_neighbors) @@ -890,7 +900,7 @@ class NearestNeighbors(Base, cp.asarray(indices)), rowptr), shape=(n_samples, - n_samples_fit)) + self.n_samples_fit_)) return sparse_csr @@ -901,7 +911,7 @@ class NearestNeighbors(Base, kidx = self.__dict__['knn_index'] \ if 'knn_index' in self.__dict__ else None if kidx is not None: - if self.working_algorithm_ in ["ivfflat", "ivfpq", "ivfsq"]: + if self._fit_method in ["ivfflat", "ivfpq", "ivfsq"]: knn_index = kidx del knn_index else: @@ -1015,6 +1025,6 @@ def kneighbors_graph(X=None, n_neighbors=5, mode='connectivity', verbose=False, if not include_self: query = None else: - query = X.X_m + query = X._fit_X return X.kneighbors_graph(X=query, n_neighbors=n_neighbors, mode=mode) diff --git a/python/cuml/solvers/qn.pyx b/python/cuml/solvers/qn.pyx index 
9ab1f87dff..b55bb444da 100644 --- a/python/cuml/solvers/qn.pyx +++ b/python/cuml/solvers/qn.pyx @@ -431,10 +431,23 @@ class QN(Base, @property @cuml.internals.api_base_return_array_skipall def coef_(self): + if self._coef_ is None: + return None if self.fit_intercept: - return self._coef_[0:-1] + val = self._coef_[0:-1] else: - return self._coef_ + val = self._coef_ + val = val.to_output('cupy') + val = val.T + return val + + @coef_.setter + def coef_(self, value): + value = value.to_output('cupy').T + if self.fit_intercept: + value = cp.vstack([value, self.intercept_]) + value, _, _, _ = input_to_cuml_array(value) + self._coef_ = value @generate_docstring(X='dense_sparse') def fit(self, X, y, sample_weight=None, convert_dtype=False) -> "QN": @@ -640,36 +653,60 @@ class QN(Base, y: array-like (device) Dense matrix (floats or doubles) of shape (n_samples, n_classes) """ + coefs = self.coef_ + dtype = coefs.dtype + _num_classes_dim, n_cols = coefs.shape + sparse_input = is_sparse(X) # Handle sparse inputs if sparse_input: X_m = SparseCumlArray( X, - convert_to_dtype=(self.dtype if convert_dtype else None), + convert_to_dtype=(dtype if convert_dtype else None), convert_index=np.int32 ) n_rows, n_cols = X_m.shape - self.dtype = X_m.dtype + dtype = X_m.dtype # Handle dense inputs else: - X_m, n_rows, n_cols, self.dtype = input_to_cuml_array( - X, check_dtype=self.dtype, - convert_to_dtype=(self.dtype if convert_dtype else None), - check_cols=self.n_cols, + X_m, n_rows, n_cols, dtype = input_to_cuml_array( + X, check_dtype=dtype, + convert_to_dtype=(dtype if convert_dtype else None), + check_cols=n_cols, order='K' ) - scores = CumlArray.zeros(shape=(self._num_classes_dim, n_rows), - dtype=self.dtype, order='F') + if _num_classes_dim > 1: + shape = (_num_classes_dim, n_rows) + else: + shape = (n_rows,) + scores = CumlArray.zeros(shape=shape, dtype=dtype, order='F') cdef uintptr_t coef_ptr = self._coef_.ptr cdef uintptr_t scores_ptr = scores.ptr cdef handle_t* handle_ = 
self.handle.getHandle() + if not hasattr(self, 'qnparams'): + self.qnparams = QNParams( + loss=self.loss, + penalty_l1=self.l1_strength, + penalty_l2=self.l2_strength, + grad_tol=self.tol, + change_tol=self.delta + if self.delta is not None else (self.tol * 0.01), + max_iter=self.max_iter, + linesearch_max_iter=self.linesearch_max_iter, + lbfgs_memory=self.lbfgs_memory, + verbose=self.verbose, + fit_intercept=self.fit_intercept, + penalty_normalized=self.penalty_normalized + ) + + _num_classes = self.get_num_classes(_num_classes_dim) cdef qn_params qnpams = self.qnparams.params - if self.dtype == np.float32: + if dtype == np.float32: if sparse_input: qnDecisionFunctionSparse[float, int]( handle_[0], @@ -680,7 +717,7 @@ class QN(Base, X_m.nnz, n_rows, n_cols, - self._num_classes, + _num_classes, coef_ptr, scores_ptr) else: @@ -691,7 +728,7 @@ class QN(Base, __is_col_major(X_m), n_rows, n_cols, - self._num_classes, + _num_classes, coef_ptr, scores_ptr) @@ -706,7 +743,7 @@ class QN(Base, X_m.nnz, n_rows, n_cols, - self._num_classes, + _num_classes, coef_ptr, scores_ptr) else: @@ -717,7 +754,7 @@ class QN(Base, __is_col_major(X_m), n_rows, n_cols, - self._num_classes, + _num_classes, coef_ptr, scores_ptr) @@ -743,27 +780,31 @@ class QN(Base, Predicts the y for X. 
""" + coefs = self.coef_ + dtype = coefs.dtype + _num_classes_dim, n_cols = coefs.shape + sparse_input = is_sparse(X) + # Handle sparse inputs if sparse_input: X_m = SparseCumlArray( X, - convert_to_dtype=(self.dtype if convert_dtype else None), + convert_to_dtype=(dtype if convert_dtype else None), convert_index=np.int32 ) n_rows, n_cols = X_m.shape - self.dtype = X_m.dtype # Handle dense inputs else: - X_m, n_rows, n_cols, self.dtype = input_to_cuml_array( - X, check_dtype=self.dtype, - convert_to_dtype=(self.dtype if convert_dtype else None), - check_cols=self.n_cols, + X_m, n_rows, n_cols, dtype = input_to_cuml_array( + X, check_dtype=dtype, + convert_to_dtype=(dtype if convert_dtype else None), + check_cols=n_cols, order='K' ) - preds = CumlArray.zeros(shape=n_rows, dtype=self.dtype, + preds = CumlArray.zeros(shape=n_rows, dtype=dtype, index=X_m.index) cdef uintptr_t coef_ptr = self._coef_.ptr cdef uintptr_t pred_ptr = preds.ptr @@ -774,8 +815,25 @@ class QN(Base, cdef handle_t* handle_ = self.handle.getHandle() + if not hasattr(self, 'qnparams'): + self.qnparams = QNParams( + loss=self.loss, + penalty_l1=self.l1_strength, + penalty_l2=self.l2_strength, + grad_tol=self.tol, + change_tol=self.delta + if self.delta is not None else (self.tol * 0.01), + max_iter=self.max_iter, + linesearch_max_iter=self.linesearch_max_iter, + lbfgs_memory=self.lbfgs_memory, + verbose=self.verbose, + fit_intercept=self.fit_intercept, + penalty_normalized=self.penalty_normalized + ) + + _num_classes = self.get_num_classes(_num_classes_dim) cdef qn_params qnpams = self.qnparams.params - if self.dtype == np.float32: + if dtype == np.float32: if sparse_input: qnPredictSparse[float, int]( handle_[0], @@ -786,7 +844,7 @@ class QN(Base, X_m.nnz, n_rows, n_cols, - self._num_classes, + _num_classes, coef_ptr, pred_ptr) else: @@ -797,7 +855,7 @@ class QN(Base, __is_col_major(X_m), n_rows, n_cols, - self._num_classes, + _num_classes, coef_ptr, pred_ptr) @@ -812,7 +870,7 @@ class QN(Base, 
X_m.nnz, n_rows, n_cols, - self._num_classes, + _num_classes, coef_ptr, pred_ptr) else: @@ -823,7 +881,7 @@ class QN(Base, __is_col_major(X_m), n_rows, n_cols, - self._num_classes, + _num_classes, coef_ptr, pred_ptr) @@ -838,6 +896,27 @@ class QN(Base, def score(self, X, y): return accuracy_score(y, self.predict(X)) + def get_num_classes(self, _num_classes_dim): + """ + Retrieves the number of classes from the classes dimension + in the coefficients. + """ + cdef qn_params qnpams = self.qnparams.params + solves_classification = qnpams.loss in { + qn_loss_type.QN_LOSS_LOGISTIC, + qn_loss_type.QN_LOSS_SOFTMAX, + qn_loss_type.QN_LOSS_SVC_L1, + qn_loss_type.QN_LOSS_SVC_L2 + } + solves_multiclass = qnpams.loss in { + qn_loss_type.QN_LOSS_SOFTMAX + } + if solves_classification and not solves_multiclass: + _num_classes = _num_classes_dim + 1 + else: + _num_classes = _num_classes_dim + return _num_classes + def _calc_intercept(self): """ If `fit_intercept == True`, then the last row of `coef_` contains @@ -845,7 +924,7 @@ class QN(Base, `coef_` """ - if (self.fit_intercept): + if self.fit_intercept: self.intercept_ = self._coef_[-1] else: self.intercept_ = CumlArray.zeros(shape=1) diff --git a/python/cuml/testing/test_preproc_utils.py b/python/cuml/testing/test_preproc_utils.py index a72589ee6d..7757bb1360 100644 --- a/python/cuml/testing/test_preproc_utils.py +++ b/python/cuml/testing/test_preproc_utils.py @@ -209,12 +209,21 @@ def int_dataset(request, random_seed): int(randint.size * 0.3), replace=False) - randint.ravel()[random_loc] = 0 - zero_filled = convert(randint, request.param) - randint.ravel()[random_loc] = 1 - one_filled = convert(randint, request.param) - randint.ravel()[random_loc] = cp.nan - nan_filled = convert(randint, request.param) + zero_filled = randint.copy().ravel() + zero_filled[random_loc] = 0 + zero_filled = zero_filled.reshape(randint.shape) + zero_filled = convert(zero_filled, request.param) + + one_filled = randint.copy().ravel() + 
one_filled[random_loc] = 1 + one_filled = one_filled.reshape(randint.shape) + one_filled = convert(one_filled, request.param) + + nan_filled = randint.copy().ravel() + nan_filled[random_loc] = cp.nan + nan_filled = nan_filled.reshape(randint.shape) + nan_filled = convert(nan_filled, request.param) + return zero_filled, one_filled, nan_filled diff --git a/python/cuml/testing/utils.py b/python/cuml/testing/utils.py index 4dc79c3580..aa952d1d4f 100644 --- a/python/cuml/testing/utils.py +++ b/python/cuml/testing/utils.py @@ -36,7 +36,6 @@ import cuml from cuml.common.input_utils import input_to_cuml_array, is_array_like from cuml.common.base import Base -from cuml.experimental.common.base import Base as experimentalBase import pytest @@ -403,7 +402,7 @@ def get_models(self): for name, cls in classes if cls not in self.exclude_classes and - issubclass(cls, (Base, experimentalBase)) + issubclass(cls, Base) } models.update(self.custom_constructors) return models diff --git a/python/cuml/tests/test_api.py b/python/cuml/tests/test_api.py index 07f8a2075a..6a50c4fa27 100644 --- a/python/cuml/tests/test_api.py +++ b/python/cuml/tests/test_api.py @@ -255,7 +255,15 @@ def test_fit_function(dataset, model_name): # and the inspect module doesn't work with Cython. 
Therefore we need # to register the number of arguments manually if `fit` is decorated pos_args_spec = { - "ARIMA": 1 + "ARIMA": 1, + "ElasticNet": 3, + "Lasso": 3, + "LinearRegression": 3, + "LogisticRegression": 3, + "NearestNeighbors": 2, + "PCA": 2, + "Ridge": 3, + "UMAP": 2 } n_pos_args_fit = ( pos_args_spec[model_name] diff --git a/python/cuml/tests/test_device_selection.py b/python/cuml/tests/test_device_selection.py index 62df94fc7a..3bec00a4a3 100644 --- a/python/cuml/tests/test_device_selection.py +++ b/python/cuml/tests/test_device_selection.py @@ -20,13 +20,36 @@ import numpy as np import pandas as pd import cudf -from sklearn.datasets import make_regression -from sklearn.linear_model import LinearRegression as skLinearRegression -from cuml.linear_model import LinearRegression +import pickle +import inspect +from importlib import import_module + +from pytest_cases import fixture_union, pytest_fixture_plus +from sklearn.datasets import make_regression, make_blobs +from cuml.metrics import trustworthiness + from cuml.testing.test_preproc_utils import to_output_type from cuml.common.device_selection import DeviceType, using_device_type from cuml.common.memory_utils import MemoryType, using_memory_type -import pickle + +from sklearn.linear_model import LinearRegression as skLinearRegression +from cuml.linear_model import LinearRegression +from sklearn.linear_model import LogisticRegression as skLogisticRegression +from cuml.linear_model import LogisticRegression +from sklearn.linear_model import Lasso as skLasso +from cuml.linear_model import Lasso +from sklearn.linear_model import ElasticNet as skElasticNet +from cuml.linear_model import ElasticNet +from sklearn.linear_model import Ridge as skRidge +from cuml.linear_model import Ridge +from umap import UMAP as refUMAP +from cuml.manifold import UMAP +from cuml.decomposition import PCA +from sklearn.decomposition import PCA as skPCA +from cuml.decomposition import TruncatedSVD +from sklearn.decomposition 
import TruncatedSVD as skTruncatedSVD +from cuml.neighbors import NearestNeighbors +from sklearn.neighbors import NearestNeighbors as skNearestNeighbors @pytest.mark.parametrize('input', [('cpu', DeviceType.host), @@ -61,9 +84,58 @@ def test_memory_type_exception(): assert True -X, y = make_regression(n_samples=2000, n_features=20, n_informative=15) -X_train, X_test = X[:1800], X[1800:] -y_train, _ = y[:1800], y[1800:] +def make_reg_dataset(): + X, y = make_regression(n_samples=2000, n_features=20, + n_informative=18, random_state=0) + X_train, X_test = X[:1800], X[1800:] + y_train, _ = y[:1800], y[1800:] + return X_train.astype(np.float32), y_train.astype(np.float32), \ + X_test.astype(np.float32) + + +def make_blob_dataset(): + X, y = make_blobs(n_samples=2000, n_features=20, + centers=20, random_state=0) + X_train, X_test = X[:1800], X[1800:] + y_train, _ = y[:1800], y[1800:] + return X_train.astype(np.float32), y_train.astype(np.float32), \ + X_test.astype(np.float32) + + +X_train_reg, y_train_reg, X_test_reg = make_reg_dataset() +X_train_blob, y_train_blob, X_test_blob = make_blob_dataset() + + +def check_trustworthiness(cuml_embedding, test_data): + X_test = to_output_type(test_data['X_test'], 'numpy') + cuml_embedding = to_output_type(cuml_embedding, 'numpy') + trust = trustworthiness(X_test, cuml_embedding, n_neighbors=10) + ref_trust = test_data['ref_trust'] + tol = 0.02 + assert trust >= ref_trust - tol + + +def check_allclose(cuml_output, test_data): + ref_output = to_output_type(test_data['ref_y_test'], 'numpy') + cuml_output = to_output_type(cuml_output, 'numpy') + np.testing.assert_allclose(ref_output, cuml_output, rtol=0.15) + + +def check_allclose_without_sign(cuml_output, test_data): + ref_output = to_output_type(test_data['ref_y_test'], 'numpy') + cuml_output = to_output_type(cuml_output, 'numpy') + assert ref_output.shape == cuml_output.shape + ref_output, cuml_output = np.abs(ref_output), np.abs(cuml_output) + 
np.testing.assert_allclose(ref_output, cuml_output, rtol=0.15) + + +def check_nn(cuml_output, test_data): + ref_dists = to_output_type(test_data['ref_y_test'][0], 'numpy') + ref_indices = to_output_type(test_data['ref_y_test'][1], 'numpy') + cuml_dists = to_output_type(cuml_output[0], 'numpy') + cuml_indices = to_output_type(cuml_output[1], 'numpy') + np.testing.assert_allclose(ref_indices, cuml_indices) + np.testing.assert_allclose(ref_dists, cuml_dists, rtol=0.15) def fixture_generation_helper(params): @@ -80,96 +152,418 @@ def fixture_generation_helper(params): } -@pytest.fixture(**fixture_generation_helper({ +@pytest_fixture_plus(**fixture_generation_helper({ 'input_type': ['numpy', 'dataframe', 'cupy', 'cudf', 'numba'], 'fit_intercept': [False, True], 'normalize': [False, True] })) -def lr_data(request): - sk_model = skLinearRegression(fit_intercept=request.param['fit_intercept'], - normalize=request.param['normalize']) - sk_model.fit(X_train, y_train) +def linreg_test_data(request): + kwargs = { + 'fit_intercept': request.param['fit_intercept'], + 'normalize': request.param['normalize'], + } + + sk_model = skLinearRegression(**kwargs) + sk_model.fit(X_train_reg, y_train_reg) input_type = request.param['input_type'] if input_type == 'dataframe': - modified_y_train = pd.Series(y_train) + modified_y_train = pd.Series(y_train_reg) elif input_type == 'cudf': - modified_y_train = cudf.Series(y_train) + modified_y_train = cudf.Series(y_train_reg) else: - modified_y_train = to_output_type(y_train, input_type) + modified_y_train = to_output_type(y_train_reg, input_type) return { + 'cuEstimator': LinearRegression, + 'kwargs': kwargs, + 'infer_func': 'predict', + 'assert_func': check_allclose, + 'X_train': to_output_type(X_train_reg, input_type), + 'y_train': modified_y_train, + 'X_test': to_output_type(X_test_reg, input_type), + 'ref_y_test': sk_model.predict(X_test_reg) + } + + +@pytest_fixture_plus(**fixture_generation_helper({ + 'input_type': ['numpy', 'dataframe', 
'cupy', + 'cudf', 'numba'], + 'penalty': ['none', 'l2'], + 'fit_intercept': [False, True] + })) +def logreg_test_data(request): + kwargs = { + 'penalty': request.param['penalty'], 'fit_intercept': request.param['fit_intercept'], - 'normalize': request.param['normalize'], - 'X_train': to_output_type(X_train, input_type), + 'max_iter': 1000 + } + + y_train_logreg = (y_train_reg > np.median(y_train_reg)).astype(np.int32) + + sk_model = skLogisticRegression(**kwargs) + sk_model.fit(X_train_reg, y_train_logreg) + + input_type = request.param['input_type'] + + if input_type == 'dataframe': + y_train_logreg = pd.Series(y_train_logreg) + elif input_type == 'cudf': + y_train_logreg = cudf.Series(y_train_logreg) + else: + y_train_logreg = to_output_type(y_train_logreg, input_type) + + return { + 'cuEstimator': LogisticRegression, + 'kwargs': kwargs, + 'infer_func': 'predict', + 'assert_func': check_allclose, + 'X_train': to_output_type(X_train_reg, input_type), + 'y_train': y_train_logreg, + 'X_test': to_output_type(X_test_reg, input_type), + 'ref_y_test': sk_model.predict(X_test_reg) + } + + +@pytest_fixture_plus(**fixture_generation_helper({ + 'input_type': ['numpy', 'dataframe', 'cupy', + 'cudf', 'numba'], + 'fit_intercept': [False, True], + 'selection': ['cyclic', 'random'] + })) +def lasso_test_data(request): + kwargs = { + 'fit_intercept': request.param['fit_intercept'], + 'selection': request.param['selection'], + 'tol': 0.0001 + } + + sk_model = skLasso(**kwargs) + sk_model.fit(X_train_reg, y_train_reg) + + input_type = request.param['input_type'] + + if input_type == 'dataframe': + modified_y_train = pd.Series(y_train_reg) + elif input_type == 'cudf': + modified_y_train = cudf.Series(y_train_reg) + else: + modified_y_train = to_output_type(y_train_reg, input_type) + + return { + 'cuEstimator': Lasso, + 'kwargs': kwargs, + 'infer_func': 'predict', + 'assert_func': check_allclose, + 'X_train': to_output_type(X_train_reg, input_type), + 'y_train': modified_y_train, + 
'X_test': to_output_type(X_test_reg, input_type), + 'ref_y_test': sk_model.predict(X_test_reg) + } + + +@pytest_fixture_plus(**fixture_generation_helper({ + 'input_type': ['numpy', 'dataframe', 'cupy', + 'cudf', 'numba'], + 'fit_intercept': [False, True], + 'selection': ['cyclic', 'random'] + })) +def elasticnet_test_data(request): + kwargs = { + 'fit_intercept': request.param['fit_intercept'], + 'selection': request.param['selection'], + 'tol': 0.0001 + } + + sk_model = skElasticNet(**kwargs) + sk_model.fit(X_train_reg, y_train_reg) + + input_type = request.param['input_type'] + + if input_type == 'dataframe': + modified_y_train = pd.Series(y_train_reg) + elif input_type == 'cudf': + modified_y_train = cudf.Series(y_train_reg) + else: + modified_y_train = to_output_type(y_train_reg, input_type) + + return { + 'cuEstimator': ElasticNet, + 'kwargs': kwargs, + 'infer_func': 'predict', + 'assert_func': check_allclose, + 'X_train': to_output_type(X_train_reg, input_type), + 'y_train': modified_y_train, + 'X_test': to_output_type(X_test_reg, input_type), + 'ref_y_test': sk_model.predict(X_test_reg) + } + + +@pytest_fixture_plus(**fixture_generation_helper({ + 'input_type': ['numpy', 'dataframe', 'cupy', + 'cudf', 'numba'], + 'fit_intercept': [False, True] + })) +def ridge_test_data(request): + kwargs = { + 'fit_intercept': request.param['fit_intercept'], + 'solver': 'svd' + } + + sk_model = skRidge(**kwargs) + sk_model.fit(X_train_reg, y_train_reg) + + input_type = request.param['input_type'] + + if input_type == 'dataframe': + modified_y_train = pd.Series(y_train_reg) + elif input_type == 'cudf': + modified_y_train = cudf.Series(y_train_reg) + else: + modified_y_train = to_output_type(y_train_reg, input_type) + + return { + 'cuEstimator': Ridge, + 'kwargs': kwargs, + 'infer_func': 'predict', + 'assert_func': check_allclose, + 'X_train': to_output_type(X_train_reg, input_type), + 'y_train': modified_y_train, + 'X_test': to_output_type(X_test_reg, input_type), + 
'ref_y_test': sk_model.predict(X_test_reg) + } + + +@pytest_fixture_plus(**fixture_generation_helper({ + 'input_type': ['cupy'], + 'n_components': [2, 16], + 'init': ['spectral', 'random'] + })) +def umap_test_data(request): + kwargs = { + 'n_neighbors': 12, + 'n_components': request.param['n_components'], + 'init': request.param['init'], + 'random_state': 42 + } + + ref_model = refUMAP(**kwargs) + ref_model.fit(X_train_blob, y_train_blob) + ref_embedding = ref_model.transform(X_test_blob) + ref_trust = trustworthiness(X_test_blob, ref_embedding, n_neighbors=12) + + input_type = request.param['input_type'] + + if input_type == 'dataframe': + modified_y_train = pd.Series(y_train_blob) + elif input_type == 'cudf': + modified_y_train = cudf.Series(y_train_blob) + else: + modified_y_train = to_output_type(y_train_blob, input_type) + + return { + 'cuEstimator': UMAP, + 'kwargs': kwargs, + 'infer_func': 'transform', + 'assert_func': check_trustworthiness, + 'X_train': to_output_type(X_train_blob, input_type), 'y_train': modified_y_train, - 'X_test': to_output_type(X_test, input_type), - 'ref_y_test': sk_model.predict(X_test) + 'X_test': to_output_type(X_test_blob, input_type), + 'ref_trust': ref_trust } -def test_train_cpu_infer_cpu(lr_data): - model = LinearRegression(fit_intercept=lr_data['fit_intercept'], - normalize=lr_data['normalize']) +@pytest_fixture_plus(**fixture_generation_helper({ + 'input_type': ['numpy', 'dataframe', 'cupy', + 'cudf', 'numba'], + 'n_components': [2, 8] + })) +def pca_test_data(request): + kwargs = { + 'n_components': request.param['n_components'], + 'svd_solver': 'full', + 'tol': 1e-07, + 'iterated_power': 15 + } + + sk_model = skPCA(**kwargs) + sk_model.fit(X_train_blob, y_train_blob) + + input_type = request.param['input_type'] + + if input_type == 'dataframe': + modified_y_train = pd.Series(y_train_blob) + elif input_type == 'cudf': + modified_y_train = cudf.Series(y_train_blob) + else: + modified_y_train = to_output_type(y_train_blob, 
input_type) + + return { + 'cuEstimator': PCA, + 'kwargs': kwargs, + 'infer_func': 'transform', + 'assert_func': check_allclose_without_sign, + 'X_train': to_output_type(X_train_blob, input_type), + 'y_train': modified_y_train, + 'X_test': to_output_type(X_test_blob, input_type), + 'ref_y_test': sk_model.transform(X_test_blob) + } + + +@pytest_fixture_plus(**fixture_generation_helper({ + 'input_type': ['numpy', 'dataframe', 'cupy', + 'cudf', 'numba'], + 'n_components': [2, 8] + })) +def tsvd_test_data(request): + kwargs = { + 'n_components': request.param['n_components'], + 'n_iter': 15, + 'tol': 1e-07 + } + + sk_model = skTruncatedSVD(**kwargs) + sk_model.fit(X_train_blob, y_train_blob) + + input_type = request.param['input_type'] + + if input_type == 'dataframe': + modified_y_train = pd.Series(y_train_blob) + elif input_type == 'cudf': + modified_y_train = cudf.Series(y_train_blob) + else: + modified_y_train = to_output_type(y_train_blob, input_type) + + return { + 'cuEstimator': TruncatedSVD, + 'kwargs': kwargs, + 'infer_func': 'transform', + 'assert_func': check_allclose_without_sign, + 'X_train': to_output_type(X_train_blob, input_type), + 'y_train': modified_y_train, + 'X_test': to_output_type(X_test_blob, input_type), + 'ref_y_test': sk_model.transform(X_test_blob) + } + + +@pytest_fixture_plus(**fixture_generation_helper({ + 'input_type': ['numpy', 'dataframe', 'cupy', + 'cudf', 'numba'], + 'metric': ['euclidean', 'cosine'], + 'n_neighbors': [3, 8] + })) +def nn_test_data(request): + kwargs = { + 'metric': request.param['metric'], + 'n_neighbors': request.param['n_neighbors'] + } + + sk_model = skNearestNeighbors(**kwargs) + sk_model.fit(X_train_blob) + + input_type = request.param['input_type'] + + return { + 'cuEstimator': NearestNeighbors, + 'kwargs': kwargs, + 'infer_func': 'kneighbors', + 'assert_func': check_nn, + 'X_train': to_output_type(X_train_blob, input_type), + 'X_test': to_output_type(X_test_blob, input_type), + 'ref_y_test': 
sk_model.kneighbors(X_test_blob) + } + + +fixture_union('test_data', ['linreg_test_data', + 'logreg_test_data', + 'lasso_test_data', + 'elasticnet_test_data', + 'ridge_test_data', + 'umap_test_data', + 'pca_test_data', + 'tsvd_test_data', + 'nn_test_data']) + + +def test_train_cpu_infer_cpu(test_data): + cuEstimator = test_data['cuEstimator'] + model = cuEstimator(**test_data['kwargs']) with using_device_type('cpu'): - model.fit(lr_data['X_train'], lr_data['y_train']) - cu_pred = model.predict(lr_data['X_test']) + if 'y_train' in test_data: + model.fit(test_data['X_train'], test_data['y_train']) + else: + model.fit(test_data['X_train']) + infer_func = getattr(model, test_data['infer_func']) + cuml_output = infer_func(test_data['X_test']) + + assert_func = test_data['assert_func'] + assert_func(cuml_output, test_data) - cu_pred = to_output_type(cu_pred, 'numpy').flatten() - np.testing.assert_allclose(lr_data['ref_y_test'], cu_pred) +def test_train_gpu_infer_cpu(test_data): + cuEstimator = test_data['cuEstimator'] + if cuEstimator is UMAP: + pytest.skip('UMAP GPU training CPU inference not yet implemented') -def test_train_gpu_infer_cpu(lr_data): - model = LinearRegression(fit_intercept=lr_data['fit_intercept'], - normalize=lr_data['normalize']) + model = cuEstimator(**test_data['kwargs']) with using_device_type('gpu'): - model.fit(lr_data['X_train'], lr_data['y_train']) + if 'y_train' in test_data: + model.fit(test_data['X_train'], test_data['y_train']) + else: + model.fit(test_data['X_train']) with using_device_type('cpu'): - cu_pred = model.predict(lr_data['X_test']) + infer_func = getattr(model, test_data['infer_func']) + cuml_output = infer_func(test_data['X_test']) - cu_pred = to_output_type(cu_pred, 'numpy').flatten() - np.testing.assert_allclose(lr_data['ref_y_test'], cu_pred) + assert_func = test_data['assert_func'] + assert_func(cuml_output, test_data) -def test_train_cpu_infer_gpu(lr_data): - model = LinearRegression(fit_intercept=lr_data['fit_intercept'], 
- normalize=lr_data['normalize']) +def test_train_cpu_infer_gpu(test_data): + cuEstimator = test_data['cuEstimator'] + model = cuEstimator(**test_data['kwargs']) with using_device_type('cpu'): - model.fit(lr_data['X_train'], lr_data['y_train']) + if 'y_train' in test_data: + model.fit(test_data['X_train'], test_data['y_train']) + else: + model.fit(test_data['X_train']) with using_device_type('gpu'): - cu_pred = model.predict(lr_data['X_test']) + infer_func = getattr(model, test_data['infer_func']) + cuml_output = infer_func(test_data['X_test']) - cu_pred = to_output_type(cu_pred, 'numpy').flatten() - np.testing.assert_allclose(lr_data['ref_y_test'], cu_pred) + assert_func = test_data['assert_func'] + assert_func(cuml_output, test_data) -def test_train_gpu_infer_gpu(lr_data): - model = LinearRegression(fit_intercept=lr_data['fit_intercept'], - normalize=lr_data['normalize']) +def test_train_gpu_infer_gpu(test_data): + cuEstimator = test_data['cuEstimator'] + model = cuEstimator(**test_data['kwargs']) with using_device_type('gpu'): - model.fit(lr_data['X_train'], lr_data['y_train']) - cu_pred = model.predict(lr_data['X_test']) + if 'y_train' in test_data: + model.fit(test_data['X_train'], test_data['y_train']) + else: + model.fit(test_data['X_train']) + infer_func = getattr(model, test_data['infer_func']) + cuml_output = infer_func(test_data['X_test']) - cu_pred = to_output_type(cu_pred, 'numpy').flatten() - np.testing.assert_allclose(lr_data['ref_y_test'], cu_pred) - sk_model = skLinearRegression(fit_intercept=lr_data['fit_intercept'], - normalize=lr_data['normalize']) - sk_model.fit(X_train, y_train) - sk_pred = sk_model.predict(X_test) - np.testing.assert_allclose(sk_pred, cu_pred) + assert_func = test_data['assert_func'] + assert_func(cuml_output, test_data) -@pytest.mark.parametrize('fit_intercept', [False, True]) -@pytest.mark.parametrize('normalize', [False, True]) -def test_pickle_interop(fit_intercept, normalize): +def test_pickle_interop(test_data): 
pickle_filepath = '/tmp/model.pickle' - model = LinearRegression(fit_intercept=fit_intercept, - normalize=normalize) + cuEstimator = test_data['cuEstimator'] + if cuEstimator is UMAP: + pytest.skip('UMAP GPU training CPU inference not yet implemented') + model = cuEstimator(**test_data['kwargs']) with using_device_type('gpu'): - model.fit(X_train, y_train) + if 'y_train' in test_data: + model.fit(test_data['X_train'], test_data['y_train']) + else: + model.fit(test_data['X_train']) with open(pickle_filepath, 'wb') as pf: pickle.dump(model, pf) @@ -180,10 +574,221 @@ def test_pickle_interop(fit_intercept, normalize): pickled_model = pickle.load(pf) with using_device_type('cpu'): - cu_pred = pickled_model.predict(X_test) + infer_func = getattr(pickled_model, test_data['infer_func']) + cuml_output = infer_func(test_data['X_test']) + + assert_func = test_data['assert_func'] + assert_func(cuml_output, test_data) + + +@pytest.mark.skip('Hyperparameters defaults understandably different') +@pytest.mark.parametrize('estimator', [LinearRegression, + LogisticRegression, + Lasso, + ElasticNet, + Ridge, + UMAP, + PCA, + TruncatedSVD, + NearestNeighbors]) +def test_hyperparams_defaults(estimator): + model = estimator() + cu_signature = inspect.signature(model.__init__).parameters + + if hasattr(model, '_cpu_estimator_import_path'): + model_path = model._cpu_estimator_import_path + else: + model_path = 'sklearn' + model.__class__.__module__[4:] + model_name = model.__class__.__name__ + cpu_model = getattr(import_module(model_path), model_name) + cpu_signature = inspect.signature(cpu_model.__init__).parameters + + common_hyperparams = list(set(cu_signature.keys()) & + set(cpu_signature.keys())) + error_msg = 'Different default values for hyperparameters:\n' + similar = True + for hyperparam in common_hyperparams: + if cu_signature[hyperparam].default != \ + cpu_signature[hyperparam].default: + similar = False + error_msg += "\t{} with cuML default :" \ + "'{}' and CPU default : 
'{}'" \ + "\n".format(hyperparam, + cu_signature[hyperparam].default, + cpu_signature[hyperparam].default) + + if not similar: + raise ValueError(error_msg) + + +@pytest.mark.parametrize('train_device', ['cpu', 'gpu']) +@pytest.mark.parametrize('infer_device', ['cpu', 'gpu']) +def test_linreg_methods(train_device, infer_device): + ref_model = skLinearRegression() + ref_model.fit(X_train_reg, y_train_reg) + ref_output = ref_model.score(X_train_reg, y_train_reg) + + model = LinearRegression() + with using_device_type(train_device): + model.fit(X_train_reg, y_train_reg) + with using_device_type(infer_device): + output = model.score(X_train_reg, y_train_reg) + + tol = 0.01 + assert ref_output - tol <= output <= ref_output + tol + + +@pytest.mark.parametrize('train_device', ['cpu', 'gpu']) +@pytest.mark.parametrize('infer_device', ['cpu', 'gpu']) +@pytest.mark.parametrize('infer_func_name', ['decision_function', + 'predict_proba', + 'predict_log_proba', + 'score']) +def test_logreg_methods(train_device, infer_device, infer_func_name): + y_train_logreg = (y_train_reg > np.median(y_train_reg)).astype(np.int32) + + ref_model = skLogisticRegression() + ref_model.fit(X_train_reg, y_train_logreg) + infer_func = getattr(ref_model, infer_func_name) + if infer_func_name == 'score': + ref_output = infer_func(X_train_reg, y_train_logreg) + else: + ref_output = infer_func(X_test_reg) + + model = LogisticRegression() + with using_device_type(train_device): + model.fit(X_train_reg, y_train_logreg) + with using_device_type(infer_device): + infer_func = getattr(model, infer_func_name) + if infer_func_name == 'score': + output = infer_func(X_train_reg.astype(np.float64), + y_train_logreg.astype(np.float64)) + else: + output = infer_func(X_test_reg.astype(np.float64)) + + if infer_func_name == 'score': + tol = 0.01 + assert ref_output - tol <= output <= ref_output + tol + else: + output = to_output_type(output, 'numpy') + mask = np.isfinite(output) + 
np.testing.assert_allclose(ref_output[mask], output[mask], + atol=0.1, rtol=0.15) + - sk_model = skLinearRegression(fit_intercept=fit_intercept, - normalize=normalize) - sk_model.fit(X_train, y_train) - sk_pred = sk_model.predict(X_test) - np.testing.assert_allclose(sk_pred, cu_pred) +@pytest.mark.parametrize('train_device', ['cpu', 'gpu']) +@pytest.mark.parametrize('infer_device', ['cpu', 'gpu']) +def test_lasso_methods(train_device, infer_device): + ref_model = skLasso() + ref_model.fit(X_train_reg, y_train_reg) + ref_output = ref_model.score(X_train_reg, y_train_reg) + + model = Lasso() + with using_device_type(train_device): + model.fit(X_train_reg, y_train_reg) + with using_device_type(infer_device): + output = model.score(X_train_reg, y_train_reg) + + tol = 0.01 + assert ref_output - tol <= output <= ref_output + tol + + +@pytest.mark.parametrize('train_device', ['cpu', 'gpu']) +@pytest.mark.parametrize('infer_device', ['cpu', 'gpu']) +def test_elasticnet_methods(train_device, infer_device): + ref_model = skElasticNet() + ref_model.fit(X_train_reg, y_train_reg) + ref_output = ref_model.score(X_train_reg, y_train_reg) + + model = ElasticNet() + with using_device_type(train_device): + model.fit(X_train_reg, y_train_reg) + with using_device_type(infer_device): + output = model.score(X_train_reg, y_train_reg) + + tol = 0.01 + assert ref_output - tol <= output <= ref_output + tol + + +@pytest.mark.parametrize('train_device', ['cpu', 'gpu']) +@pytest.mark.parametrize('infer_device', ['cpu', 'gpu']) +def test_ridge_methods(train_device, infer_device): + ref_model = skRidge() + ref_model.fit(X_train_reg, y_train_reg) + ref_output = ref_model.score(X_train_reg, y_train_reg) + + model = Ridge() + with using_device_type(train_device): + model.fit(X_train_reg, y_train_reg) + with using_device_type(infer_device): + output = model.score(X_train_reg, y_train_reg) + + tol = 0.01 + assert ref_output - tol <= output <= ref_output + tol + + +@pytest.mark.parametrize('device', 
['cpu', 'gpu']) +def test_umap_methods(device): + ref_model = refUMAP(n_neighbors=12) + ref_embedding = ref_model.fit_transform(X_train_blob, y_train_blob) + ref_trust = trustworthiness(X_train_blob, ref_embedding, n_neighbors=12) + + model = UMAP(n_neighbors=12) + with using_device_type(device): + embedding = model.fit_transform(X_train_blob, y_train_blob) + trust = trustworthiness(X_train_blob, embedding, n_neighbors=12) + + tol = 0.02 + assert ref_trust - tol <= trust <= ref_trust + tol + + +@pytest.mark.parametrize('train_device', ['cpu', 'gpu']) +@pytest.mark.parametrize('infer_device', ['cpu', 'gpu']) +def test_pca_methods(train_device, infer_device): + n, p = 500, 5 + rng = np.random.RandomState(0) + X = rng.randn(n, p) * .1 + np.array([3, 4, 2, 3, 5]) + + model = PCA(n_components=3) + with using_device_type(train_device): + transformation = model.fit_transform(X) + with using_device_type(infer_device): + output = model.inverse_transform(transformation) + + output = to_output_type(output, 'numpy') + np.testing.assert_allclose(X, output, rtol=0.15) + + +@pytest.mark.parametrize('train_device', ['cpu', 'gpu']) +@pytest.mark.parametrize('infer_device', ['cpu', 'gpu']) +def test_tsvd_methods(train_device, infer_device): + n, p = 500, 5 + rng = np.random.RandomState(0) + X = rng.randn(n, p) * .1 + np.array([3, 4, 2, 3, 5]) + + model = TruncatedSVD(n_components=3) + with using_device_type(train_device): + transformation = model.fit_transform(X) + with using_device_type(infer_device): + output = model.inverse_transform(transformation) + + output = to_output_type(output, 'numpy') + np.testing.assert_allclose(X, output, rtol=0.15) + + +@pytest.mark.parametrize('train_device', ['cpu', 'gpu']) +@pytest.mark.parametrize('infer_device', ['cpu', 'gpu']) +def test_nn_methods(train_device, infer_device): + ref_model = skNearestNeighbors() + ref_model.fit(X_train_blob) + ref_output = ref_model.kneighbors_graph(X_train_blob) + + model = NearestNeighbors() + with 
using_device_type(train_device): + model.fit(X_train_blob) + with using_device_type(infer_device): + output = model.kneighbors_graph(X_train_blob) + + ref_output = ref_output.todense() + output = output.todense() + np.testing.assert_allclose(ref_output, output, rtol=0.15) diff --git a/python/cuml/tests/test_linear_model.py b/python/cuml/tests/test_linear_model.py index 77d809a9d7..c8f181baeb 100644 --- a/python/cuml/tests/test_linear_model.py +++ b/python/cuml/tests/test_linear_model.py @@ -636,7 +636,7 @@ def test_logistic_regression_decision_function( culog.fit(X_train, y_train) sklog = skLog(fit_intercept=fit_intercept) - sklog.coef_ = culog.coef_.T + sklog.coef_ = culog.coef_ if fit_intercept: sklog.intercept_ = culog.intercept_ else: @@ -683,7 +683,7 @@ def test_logistic_regression_predict_proba( ) else: sklog = skLog(fit_intercept=fit_intercept) - sklog.coef_ = culog.coef_.T + sklog.coef_ = culog.coef_ if fit_intercept: sklog.intercept_ = culog.intercept_ else: @@ -822,7 +822,6 @@ def test_logistic_regression_weighting(regression_dataset, unit_tol = 0.04 total_tol = 0.08 elif regression_type.startswith('multiclass'): - skcoef = skcoef.T skcoef /= np.linalg.norm(skcoef, axis=1)[:, None] cucoef /= np.linalg.norm(cucoef, axis=1)[:, None] unit_tol = 0.2 diff --git a/python/cuml/tests/test_nearest_neighbors.py b/python/cuml/tests/test_nearest_neighbors.py index 72dd725c72..4f67466c65 100644 --- a/python/cuml/tests/test_nearest_neighbors.py +++ b/python/cuml/tests/test_nearest_neighbors.py @@ -370,7 +370,7 @@ def test_knn_separate_index_search(input_type, nrows, n_feats, k, metric): with cuml.using_output_type("numpy"): # Assert the cuml model was properly reverted - np.testing.assert_allclose(knn_cu.X_m, X_orig.get(), + np.testing.assert_allclose(knn_cu._fit_X, X_orig.get(), atol=1e-3, rtol=1e-3) if metric == 'braycurtis': @@ -411,7 +411,7 @@ def test_knn_x_none(input_type, nrows, n_feats, k, metric): D_cuml, I_cuml = knn_cu.kneighbors(X=None, n_neighbors=k) # 
Assert the cuml model was properly reverted - cp.testing.assert_allclose(knn_cu.X_m, X_orig, + cp.testing.assert_allclose(knn_cu._fit_X, X_orig, atol=1e-5, rtol=1e-4) # Allow a max relative diff of 10% and absolute diff of 1% diff --git a/python/cuml/tests/test_pickle.py b/python/cuml/tests/test_pickle.py index bd0ce3ae51..f1c2be7882 100644 --- a/python/cuml/tests/test_pickle.py +++ b/python/cuml/tests/test_pickle.py @@ -92,6 +92,7 @@ 'OneVsOneClassifier', 'OneVsRestClassifier', "SparseRandomProjection", + "UMAP" ] all_models = get_classes_from_package(cuml, import_sub_packages=True) @@ -468,7 +469,7 @@ def assert_model(pickled_model, X_test): assert array_equal(result["neighbors"], D_after) state = pickled_model.__dict__ assert state["n_indices"] == 1 - assert "X_m" in state + assert "_fit_X" in state pickle_save_load(tmpdir, create_mod, assert_model) @@ -495,13 +496,13 @@ def create_mod(): def assert_model(loaded_model, X): state = loaded_model.__dict__ assert state["n_indices"] == 0 - assert "X_m" not in state + assert "_fit_X" not in state loaded_model.fit(X[0]) state = loaded_model.__dict__ assert state["n_indices"] == 1 - assert "X_m" in state + assert "_fit_X" in state pickle_save_load(tmpdir, create_mod, assert_model) diff --git a/python/cuml/tests/test_qn.py b/python/cuml/tests/test_qn.py index a4796b9582..4935510eb3 100644 --- a/python/cuml/tests/test_qn.py +++ b/python/cuml/tests/test_qn.py @@ -75,15 +75,13 @@ def test_qn(loss, dtype, penalty, l1_strength, l2_strength, fit_intercept): assert (qn.objective - 0.40263831615448) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-2.1088872], - [2.4812558]]), + np.array([[-2.1088872, 2.4812558]]), decimal=3) else: assert (qn.objective - 0.4317452311515808) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-2.120777], - [3.056865]]), + np.array([[-2.120777, 3.056865]]), decimal=3) elif penalty == 'l1' and l2_strength == 0.0: if fit_intercept: @@ -91,15 +89,13 @@ def 
test_qn(loss, dtype, penalty, l1_strength, l2_strength, fit_intercept): assert (qn.objective - 0.40263831615448) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-2.1088872], - [2.4812558]]), + np.array([[-2.1088872, 2.4812558]]), decimal=3) else: assert (qn.objective - 0.44295936822891235) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-1.6899368], - [1.9021575]]), + np.array([[-1.6899368, 1.9021575]]), decimal=3) else: @@ -107,16 +103,14 @@ def test_qn(loss, dtype, penalty, l1_strength, l2_strength, fit_intercept): assert (qn.objective - 0.4317452311515808) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-2.120777], - [3.056865]]), + np.array([[-2.120777, 3.056865]]), decimal=3) else: assert (qn.objective - 0.4769895672798157) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-1.6214856], - [2.3650239]]), + np.array([[-1.6214856, 2.3650239]]), decimal=3) # assert False @@ -127,15 +121,13 @@ def test_qn(loss, dtype, penalty, l1_strength, l2_strength, fit_intercept): assert (qn.objective - 0.40263831615448) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-2.1088872], - [2.4812558]]), + np.array([[-2.1088872, 2.4812558]]), decimal=3) else: assert (qn.objective - 0.43780848383903503) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-1.5337948], - [1.678699]]), + np.array([[-1.5337948, 1.678699]]), decimal=3) else: @@ -143,16 +135,14 @@ def test_qn(loss, dtype, penalty, l1_strength, l2_strength, fit_intercept): assert (qn.objective - 0.4317452311515808) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-2.120777], - [3.056865]]), + np.array([[-2.120777, 3.056865]]), decimal=3) else: assert (qn.objective - 0.4750209450721741) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-1.3931049], - [2.0140104]]), + np.array([[-1.3931049, 2.0140104]]), decimal=3) if penalty == 'elasticnet': @@ -161,59 +151,51 @@ def test_qn(loss, dtype, 
penalty, l1_strength, l2_strength, fit_intercept): assert (qn.objective - 0.40263831615448) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-2.1088872], - [2.4812558]]), + np.array([[-2.1088872, 2.4812558]]), decimal=3) elif l1_strength == 0.0: assert (qn.objective - 0.43780848383903503) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-1.5337948], - [1.678699]]), + np.array([[-1.5337948, 1.678699]]), decimal=3) elif l2_strength == 0.0: assert (qn.objective - 0.44295936822891235) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-1.6899368], - [1.9021575]]), + np.array([[-1.6899368, 1.9021575]]), decimal=3) else: assert (qn.objective - 0.467987984418869) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-1.3727235], - [1.4639963]]), + np.array([[-1.3727235, 1.4639963]]), decimal=3) else: if l1_strength == 0.0 and l2_strength == 0.0: assert (qn.objective - 0.4317452311515808) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-2.120777], - [3.056865]]), + np.array([[-2.120777, 3.056865]]), decimal=3) elif l1_strength == 0.0: assert (qn.objective - 0.4750209450721741) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-1.3931049], - [2.0140104]]), + np.array([[-1.3931049, 2.0140104]]), decimal=3) elif l2_strength == 0.0: assert (qn.objective - 0.4769895672798157) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-1.6214856], - [2.3650239]]), + np.array([[-1.6214856, 2.3650239]]), decimal=3) else: assert (qn.objective - 0.5067970156669617) < tol cp.testing.assert_array_almost_equal( qn.coef_, - np.array([[-1.2102532], - [1.752459]]), + np.array([[-1.2102532, 1.752459]]), decimal=3) print() diff --git a/wiki/python/ESTIMATOR_GUIDE.md b/wiki/python/ESTIMATOR_GUIDE.md index 4c8f1fd1f6..626c48ba0c 100644 --- a/wiki/python/ESTIMATOR_GUIDE.md +++ b/wiki/python/ESTIMATOR_GUIDE.md @@ -56,13 +56,13 @@ At a high level, all cuML Estimators must: 
super().__init__(handle=handle, verbose=verbose, output_type=output_type) ... ``` -4. Declare each array-like attribute the new Estimator will compute as a class variable for automatic array type conversion +4. Declare each array-like attribute the new Estimator will compute as a class variable for automatic array type conversion. An order can be specified to serve as an indicator of the order the array should be in for the C++ algorithms to work. ```python from cuml.common.array_descriptor import CumlArrayDescriptor class MyEstimator(Base): - labels_ = CumlArrayDescriptor() + labels_ = CumlArrayDescriptor(order='C') def __init__(self): ... @@ -248,7 +248,7 @@ Performing the arrray conversion lazily (i.e. converting the input array to the #### Defining Array-Like Attributes -To use the `CumlArrayDescriptor` in an estimator, any array-like attributes need to be specified by creating a `CumlArrayDescriptor` as a class variable. +To use the `CumlArrayDescriptor` in an estimator, any array-like attributes need to be specified by creating a `CumlArrayDescriptor` as a class variable. An order can be specified to serve as an indicator of the order the array should be in for the C++ algorithms to work. ```python from cuml.common.array_descriptor import CumlArrayDescriptor @@ -256,7 +256,7 @@ from cuml.common.array_descriptor import CumlArrayDescriptor class TestEstimator(cuml.Base): # Class variables outside of any function - my_cuml_array_ = CumlArrayDescriptor() + my_cuml_array_ = CumlArrayDescriptor(order='C') def __init__(self, ...): ...